void test_direct_trans()
{
	const index_t m = M == 0 ? DM : M;
	const index_t n = N == 0 ? DN : N;

	// M x N ==> N x M

	typedef typename mat_host<Tag1, double, M, N>::cmat_t cmat_t;
	typedef typename mat_host<Tag2, double, N, M>::mat_t mat_t;

	mat_host<Tag1, double, M, N> src(m, n);
	mat_host<Tag2, double, N, M> dst(n, m);

	src.fill_lin();

	cmat_t smat = src.get_cmat();
	mat_t dmat = dst.get_mat();

	transpose(smat, dmat);

	dense_matrix<double, N, M> rmat(n, m);
	for (index_t i = 0; i < m; ++i)
	{
		for (index_t j = 0; j < n; ++j) rmat(j, i) = smat(i, j);
	}

	ASSERT_MAT_EQ(n, m, dmat, rmat);

	// N x M ==> M x N

	typedef typename mat_host<Tag1, double, N, M>::cmat_t cmat2_t;
	typedef typename mat_host<Tag2, double, M, N>::mat_t mat2_t;

	mat_host<Tag1, double, N, M> src2(n, m);
	mat_host<Tag2, double, M, N> dst2(m, n);

	src2.fill_lin();

	cmat2_t smat2 = src2.get_cmat();
	mat2_t dmat2 = dst2.get_mat();

	transpose(smat2, dmat2);

	dense_matrix<double, M, N> rmat2(m, n);
	for (index_t i = 0; i < m; ++i)
	{
		for (index_t j = 0; j < n; ++j) rmat2(i, j) = smat2(j, i);
	}

	ASSERT_MAT_EQ(m, n, dmat2, rmat2);

}
Exemplo n.º 2
0
/*
 * MEX gateway for mexsmat.
 *
 * Inputs (prhs):
 *   prhs[0]  blk     cell array with at least two columns; the entry in row
 *                    `rowidx`, second column, is read as a row of block sizes
 *                    (one block per column of that entry).
 *   prhs[1]  Avec    either a cell array (the entry in row `rowidx`, first
 *                    column, is used) or a plain matrix; the data may be
 *                    sparse and/or complex.
 *   prhs[2]  isspB   optional; nonzero requests a sparse result. It is only
 *                    consulted when there is a single block; with several
 *                    blocks the result is always sparse. When absent, the
 *                    sparsity of Avec is used.
 *   prhs[3]  rowidx  optional 1-based row of blk / Avec, default 1.
 *   prhs[4]  colidx  optional 1-based column of Avec, default 1.
 *
 * Output (plhs[0]):
 *   An n-by-n matrix, n being the sum of the block sizes. The entries are
 *   filled in by smat1 / smat2 (real data) or smat1cmp / smat2cmp (complex
 *   data). For a sparse result the subroutine output B is completed as
 *   B + B' through MATLAB calls.
 *
 * Any violated precondition is reported through mexErrMsgTxt.
 */
void mexFunction(int nlhs,   mxArray  *plhs[],
                 int nrhs,   const mxArray  *prhs[] )

{    mxArray  *rhs[2];
     mxArray  *blk_cell_pr, *A_cell_pr;
     double   *A,  *B,  *blksize;
     /* These pointers are only meaningful for sparse / complex data; keep
        them NULL otherwise so that no indeterminate value is handed to the
        subroutines. */
     double   *AI = NULL, *BI = NULL;
     mwIndex  *irA = NULL, *jcA = NULL, *irB = NULL, *jcB = NULL;
     int      *cumblksize, *blknnz;
     int       mA, nA, m1, n1, rowidx, colidx, isspA, isspB;
     int       iscellA, iscmpA;

     mwIndex  subs[2];
     mwSize   nsubs=2;
     int      n, n2, k, nsub, index, numblk, NZmax;
     double   ir2=1/sqrt(2);

/* CHECK FOR PROPER NUMBER OF ARGUMENTS */

     if (nrhs < 2){
         mexErrMsgTxt("mexsmat: requires at least 2 input arguments."); }
     else if (nlhs>1){
         mexErrMsgTxt("mexsmat: requires 1 output argument."); }

/* CHECK THE DIMENSIONS */

     iscellA = mxIsCell(prhs[1]);
     if (iscellA) { mA = (int)mxGetM(prhs[1]); nA = (int)mxGetN(prhs[1]); }
     else         { mA = 1; nA = 1; }
     if (mxGetM(prhs[0]) != mA) {
         mexErrMsgTxt("mexsmat: blk and Avec not compatible"); }
     /* blk{rowidx,2} is fetched below, so blk has to be a cell array with a
        second column. */
     if (!mxIsCell(prhs[0]) || mxGetN(prhs[0]) < 2) {
         mexErrMsgTxt("mexsmat: blk not properly specified"); }
     /* Avec{rowidx,1} is fetched below, so a cell Avec needs a first column. */
     if (nA < 1) {
         mexErrMsgTxt("mexsmat: Avec not properly specified"); }

/***** main body *****/

     if (nrhs > 3) {rowidx = (int)*mxGetPr(prhs[3]); } else {rowidx = 1;}
     /* rowidx-1 is stored in an unsigned subscript; reject values that would
        wrap around. */
     if (rowidx < 1) {
         mexErrMsgTxt("mexsmat: rowidx must be at least 1."); }
     if (rowidx > mA) {
         mexErrMsgTxt("mexsmat: rowidx exceeds size(Avec,1)."); }
     subs[0] = rowidx-1;  /* subtract 1 to adjust for Matlab index */
     subs[1] = 1;         /* second column of blk */
     index = (int)mxCalcSingleSubscript(prhs[0],nsubs,subs);
     blk_cell_pr = mxGetCell(prhs[0],index);
     if (blk_cell_pr == NULL) {
        mexErrMsgTxt("mexsmat: blk not properly specified"); }
     numblk  = (int)mxGetN(blk_cell_pr);
     blksize = mxGetPr(blk_cell_pr);
     /* Explicit casts keep the file valid when it is compiled as C++. */
     cumblksize = (int *)mxCalloc(numblk+1,sizeof(int));
     blknnz = (int *)mxCalloc(numblk+1,sizeof(int));
     cumblksize[0] = 0; blknnz[0] = 0;
     /* Running totals: n is the sum of the block sizes, n2 the sum of
        nsub*(nsub+1)/2 over the blocks. */
     n = 0;  n2 = 0;
     for (k=0; k<numblk; ++k) {
          nsub = (int) blksize[k];
          n  += nsub;
          n2 += nsub*(nsub+1)/2;
          cumblksize[k+1] = n;
          blknnz[k+1] = n2;
     }
     /***** assign pointers *****/
     if (iscellA) {
         subs[0] = rowidx-1;
         subs[1] = 0;
         index = (int)mxCalcSingleSubscript(prhs[1],nsubs,subs);
         A_cell_pr = mxGetCell(prhs[1],index);
         /* mxGetCell yields NULL for a cell that was never assigned. */
         if (A_cell_pr == NULL) {
             mexErrMsgTxt("mexsmat: Avec not properly specified"); }
         A  = mxGetPr(A_cell_pr);
         m1 = (int)mxGetM(A_cell_pr);
         n1 = (int)mxGetN(A_cell_pr);
         isspA = mxIsSparse(A_cell_pr);
         iscmpA = mxIsComplex(A_cell_pr);
         if (isspA) { irA = mxGetIr(A_cell_pr);
                      jcA = mxGetJc(A_cell_pr); }
         if (iscmpA) { AI = mxGetPi(A_cell_pr); }
     } else {
         A  = mxGetPr(prhs[1]);
         m1 = (int)mxGetM(prhs[1]);
         n1 = (int)mxGetN(prhs[1]);
         isspA = mxIsSparse(prhs[1]);
         iscmpA = mxIsComplex(prhs[1]);
         if (isspA) {  irA = mxGetIr(prhs[1]);
                       jcA = mxGetJc(prhs[1]); }
         if (iscmpA) { AI = mxGetPi(prhs[1]); }
     }
     if (numblk > 1) {
        isspB = 1;
     } else {
        if (nrhs > 2) {isspB = (int)*mxGetPr(prhs[2]);} else {isspB = isspA;}
     }
     /* colidx is zero-based from here on. */
     if (nrhs > 4) {colidx = (int)*mxGetPr(prhs[4]) -1;} else {colidx = 0;}
     if (colidx < 0) {
         mexErrMsgTxt("mexsmat: colidx must be at least 1."); }
     /* Valid zero-based columns are 0 .. n1-1, so colidx == n1 is already
        out of range (it would index jcA[n1+1] or data past the last column).
        The single request still tolerated beyond that range is the
        degenerate one accepted previously: an empty dense Avec with n == 0,
        for which there is presumably nothing to read. */
     if (colidx >= n1 && !(colidx == 0 && n1 == 0 && n == 0 && !isspA)) {
         mexErrMsgTxt("mexsmat: colidx exceeds size(Avec,2).");
     }
     /***** create return argument *****/
     if (isspB) {
        /* The sparse result is first built in the temporary rhs[0]; the
           final output is formed from it after the computation. */
        if (isspA) { NZmax = (int)(jcA[colidx+1]-jcA[colidx]); }
        else       { NZmax = blknnz[numblk]; }
        if (iscmpA) {
           rhs[0] = mxCreateSparse(n,n,2*NZmax,mxCOMPLEX);
        } else {
           rhs[0] = mxCreateSparse(n,n,2*NZmax,mxREAL);
        }
        B = mxGetPr(rhs[0]);
        irB = mxGetIr(rhs[0]);
        jcB = mxGetJc(rhs[0]);
        if (iscmpA) { BI = mxGetPi(rhs[0]); }
     } else {
        if (iscmpA) {
          plhs[0] = mxCreateDoubleMatrix(n,n,mxCOMPLEX);
        } else {
          plhs[0] = mxCreateDoubleMatrix(n,n,mxREAL);
        }
        B = mxGetPr(plhs[0]);
        if (iscmpA) { BI = mxGetPi(plhs[0]); }
     }
     /***** Do the computations in a subroutine *****/
     if (iscmpA) {
       if (numblk == 1) {
          smat1cmp(n,ir2,A,irA,jcA,isspA,m1,colidx,B,irB,jcB,isspB,AI,BI);
       } else {
          smat2cmp(n,numblk,cumblksize,blknnz,ir2,A,irA,jcA,isspA,m1,colidx,B,irB,jcB,isspB,AI,BI);
       }
     } else {
       if (numblk == 1) {
          smat1(n,ir2,A,irA,jcA,isspA,m1,colidx,B,irB,jcB,isspB);
       } else {
          smat2(n,numblk,cumblksize,blknnz,ir2,A,irA,jcA,isspA,m1,colidx,B,irB,jcB,isspB);
       }
     }
     if (isspB) {
        /*** if isspB, (actual B) = B+B' ****/
        mexCallMATLAB(1, &rhs[1], 1, &rhs[0], "ctranspose");
        mexCallMATLAB(1, &plhs[0],2, rhs, "+");
        /* Both operands are temporaries; only plhs[0] is returned. */
        mxDestroyArray(rhs[1]);
        mxDestroyArray(rhs[0]);
     }
     mxFree(blknnz);
     mxFree(cumblksize);
 return;
 }