/**
eigentrain

Computes the eigen space for a matrix of images.

This function is used in the training procedure of the face recognition pca
algorithm.  It uses the "snap shot" approach: the eigenvectors of the small
matrix M = images' * images are computed first and then mapped back through
the image matrix to recover the high resolution basis.

INPUT:  images: the data matrix of images
OUTPUT: mean: the mean value of the images
        eigen_vals: eigenvalues
        eigen_base: eigenvectors (normalized by basis_normalize)

All three outputs are assigned freshly created matrices; the caller is
expected to release them (presumably with freeMatrix).

The data matrix is mean centered, and this is a side effect.

The last eigenvalue / eigenvector is dropped from the outputs by decrementing
the stored dimensions (see the comment in the body).
*/
void eigentrain(Matrix *mean, Matrix *eigen_vals,
                Matrix *eigen_base, Matrix images) 
{
    double p = 0.0;   /* passed through to cvJacobiEigens_64d unchanged */
    Matrix M, eigenvectors;

    DEBUG(1, "Calculating mean image.");
    *mean = get_mean_image(images);

    DEBUG(1, "Calculating the mean centered images for the training set.");
    mean_subtract_images(images, *mean);

    MESSAGE2ARG("Calculating Covariance Matrix: M = images' * images. M is a %d by %d Matrix.", images->col_dim, images->col_dim);
    M = transposeMultiplyMatrixL(images, images);
    DEBUG_INT(3, "Covariance Matrix Rows", M->row_dim);
    DEBUG_INT(3, "Covariance Matrix Cols", M->col_dim);

    DEBUG(2, "Allocating memory for eigenvectors and eigenvalues.");
    eigenvectors = makeMatrix(M->row_dim, M->col_dim);
    *eigen_vals = makeMatrix(M->row_dim, 1);


    MESSAGE("Computing snap shot eigen vectors using the double precision cv eigensolver.");

    cvJacobiEigens_64d(M->data, eigenvectors->data, (*eigen_vals)->data, images->col_dim, p, 1);

    freeMatrix(M);

    DEBUG(1, "Verifying the eigen vectors");
    if (debuglevel >= 3) {
        /* Reconstruct M because it is destroyed in cvJacobiEigens.  This is
           only needed for verification, so the product is skipped entirely
           at lower debug levels. */
        M = transposeMultiplyMatrixL(images, images);
        eigen_verify(M, *eigen_vals, eigenvectors);
        freeMatrix(M);
    }

    *eigen_base = multiplyMatrix(images, eigenvectors);
    MESSAGE2ARG("Recovered the %d by %d high resolution eigen basis.", (*eigen_base)->row_dim, (*eigen_base)->col_dim);

    DEBUG(1, "Normalizing eigen basis");
    basis_normalize(*eigen_base);

    /*Remove last elements because they are unneeded.  Mean centering the image
    guarantees that the data points define a hyperplane that passes through
    the origin. Therefore all points are in a k - 1 dimensional subspace.
    */
    (*eigen_base)->col_dim -= 1;
    (*eigen_vals)->row_dim -= 1;
    eigenvectors->col_dim -= 1;

    DEBUG(1, "Verifying eigenbasis");
    if (debuglevel >= 3)
        basis_verify(images, *eigen_base);

    /* The eigenvectors for the smaller covariance (snap shot) are not needed */
    freeMatrix(eigenvectors);
}
/*START: 	Changed by Zeeshan: for ICA*/
/**
centerThenProjectImagesICA

Mean centers a set of images and projects them into an ICA subspace.

INPUT:  s: the subspace; s->useICA must be set (asserted)
        images: the image matrix
OUTPUT: returns a newly created matrix holding the projected images; the
        caller is expected to release it (presumably with freeMatrix).

The image matrix is mean centered with s->mean, and this is a side effect.

When s->arch == 1 the projection onto s->basis is returned directly.
Otherwise that projection is centered again and projected onto
s->ica2Basis; the intermediate projection is released before returning.
*/
Matrix
centerThenProjectImagesICA (Subspace *s, Matrix images)
{
    Matrix eigpims, subspims;

    assert(s->useICA);

    mean_subtract_images (images, s->mean);

    eigpims = transposeMultiplyMatrixL (images, s->basis);

    /* in ICA1, s->basis contains the combined basis */
    if (s->arch == 1)
        return eigpims;

    /* NOTE(review): the projected data is re-centered with the same s->mean
       that was used for the raw images -- confirm this is the intended mean
       for this step. */
    mean_subtract_images (eigpims, s->mean);

    subspims = transposeMultiplyMatrixL (eigpims, s->ica2Basis);

    /* The intermediate projection is not returned on this path, so release
       it here; previously it was leaked on every call. */
    freeMatrix (eigpims);

    return subspims;
}
/**
centerThenProjectImages

Mean centers a set of images and projects them onto the subspace basis.

INPUT:  s: the subspace; must NOT be an ICA subspace (asserted)
        images: the image matrix
OUTPUT: returns the result of transposeMultiplyMatrixL(s->basis, images).

The image matrix is mean centered with s->mean, and this is a side effect.
*/
Matrix
centerThenProjectImages (Subspace *s, Matrix images)
{
    /* ICA subspaces are handled by centerThenProjectImagesICA
       (check added by Zeeshan for ICA). */
    assert(!s->useICA);

    /* Center first, then hand the projection straight back to the caller. */
    mean_subtract_images (images, s->mean);

    return transposeMultiplyMatrixL (s->basis, images);
}
/*  findBCSMatrix

 Solves for the between-class scatter matrix using equation 116 on
 p. 122 of Duda, "Pattern Classification":

 St = Sw + Sb  (St is total scatter matrix)

 so that  Sb = St - Sw.  The within-class scatter matrix Sw is supplied
 by the caller.

 The total scatter matrix is what we think of as the covariance matrix
 once the points have been centered, so the input points are centered
 about their mean first.

 INPUT:  imspca: the data points; they are mean centered, and this is a
                 side effect
         Sw:     the within-class scatter matrix
 OUTPUT: returns Sb as produced by subtractMatrix; the temporary mean
         and total scatter matrices are released before returning.
 */
Matrix findBCSMatrix(Matrix imspca, Matrix Sw) {
    Matrix centroid;
    Matrix total_scatter;
    Matrix between_scatter;

    MESSAGE("Finding between-class scatter matrix.");

    /* Center the points so the product below is the total scatter. */
    centroid = get_mean_image(imspca);
    mean_subtract_images(imspca, centroid);

    DEBUG(3, "Producing total scatter matrix");
    total_scatter = transposeMultiplyMatrixR(imspca, imspca);

    DEBUG(2, "Producing between-class scatter matrix");
    between_scatter = subtractMatrix(total_scatter, Sw);

    /* Only the between-class scatter leaves this function. */
    freeMatrix(centroid);
    freeMatrix(total_scatter);

    return between_scatter;
}