示例#1
0
文件: mextriang.c 项目: B-Rich/sdpt3
/*************************************************************
*   PROCEDURE mexFunction - Entry for Matlab
*
*   y = mextriang(U, b, options)
*
*   U       : dense n x n double matrix (sparse U is rejected). The error
*             text asks for an upper-triangular U; that is not checked here.
*   b       : dense or sparse array with numel(b) == n, used as a length-n
*             vector in column-major (linear index) order.
*   options : 1 (default) -> y = b, then bwsolve(y,U,n) works on y in place;
*             2           -> fwsolve(y,U,b,n).
*             An omitted or empty options argument selects the default;
*             any other value is an error.
*   y       : n x 1 dense result.
*   NOTE(review): bwsolve/fwsolve are defined elsewhere; presumably they are
*   backward/forward triangular solves with U -- confirm there.
**************************************************************/
 void mexFunction(int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[])

{  const double  *U;
   mwIndex       *irb, *jcb;
   mwIndex        mb, nb, j, kk;
   int            isspb;
   int            n, k, options;
   double        *y, *b, *btmp;

   if (nrhs < 2) {
      mexErrMsgTxt("mextriang requires 2 input arguments."); }
   if (nlhs > 1) {
      mexErrMsgTxt("mextriang generates 1 output argument."); }

   U = mxGetPr(prhs[0]);
   if (mxIsSparse(prhs[0])) {
      mexErrMsgTxt("mextriang: Sparse U not supported."); }
   n = mxGetM(prhs[0]);
   if (mxGetN(prhs[0]) != n) {
      mexErrMsgTxt("mextriang: U should be square and upper triangular."); }
   isspb = mxIsSparse(prhs[1]);
   if ( mxGetM(prhs[1])*mxGetN(prhs[1]) != n ) {
       mexErrMsgTxt("mextriang: size of U,b mismatch."); }
   /* mxGetPr() of an empty array may be NULL, so only dereference it when
      the options argument actually holds a value. */
   if (nrhs > 2 && mxGetM(prhs[2])*mxGetN(prhs[2]) > 0) {
      options = (int)*mxGetPr(prhs[2]); }
   else {
      options = 1;
   }
   /* Reject unknown modes instead of returning an all-zero y. */
   if (options != 1 && options != 2) {
      mexErrMsgTxt("mextriang: options should be 1 or 2."); }
   if (isspb) {
      /* Scatter the nonzeros of sparse b into a dense length-n vector.
         Every column is visited so that a sparse row vector (1 x n) works
         as well as a column vector; irb[kk] + j*mb is the column-major
         linear index of the entry. */
      mb   = mxGetM(prhs[1]);
      nb   = mxGetN(prhs[1]);
      btmp = mxGetPr(prhs[1]);
      irb  = mxGetIr(prhs[1]); jcb = mxGetJc(prhs[1]);
      b = (double*)mxCalloc(n,sizeof(double));
      for (j=0; j<nb; j++) {
         for (kk=jcb[j]; kk<jcb[j+1]; kk++) { b[irb[kk] + j*mb] = btmp[kk]; }
      }
   } else {
      b = mxGetPr(prhs[1]);
   }
   /************************************************/
   plhs[0] = mxCreateDoubleMatrix(n, 1, mxREAL);
   y = mxGetPr(plhs[0]);

   /************************************************/
   if (options==1) {
      for (k=0; k<n; k++) { y[k]=b[k]; }   /* bwsolve works on y in place */
      bwsolve(y,U,n);
   } else {
      fwsolve(y,U,b,n);
   }
   /* Release the dense copy made for a sparse b (the dense case points
      straight into prhs[1] and must not be freed). */
   if (isspb) { mxFree(b); }
   return;
}
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
   y = fwblkslv(L,b, [y])
     y = L.L \ b(L.perm)

   L : structure with fields
         L.perm   - 1-based permutation vector, numel m
         L.L      - sparse m x m matrix, passed to fwsolve/selfwsolve
         L.xsuper - 1-based supernode boundaries, numel nsuper+1
   b : m x n right-hand side, dense or sparse.
   y : required when b is sparse. Its sparsity structure is copied to the
       output, and only entries in that pattern are written.
   NOTE(review): the layout of L.L / L.xsuper is interpreted by
   fwsolve/selfwsolve, which are not visible here -- confirm there.
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
  const int nrhs, const mxArray *prhs[])
{
 const mxArray *L_FIELD;
 int m,n, j, k, nsuper, inz;
 size_t nnzy;
 double *y,*fwork;
 const double *permPr, *b, *xsuperPr;
 const int *yjc, *yir, *bjc, *bir;
 int *perm, *invperm, *snode, *xsuper, *iwork;
 jcir L;
 char bissparse;
 /* ------------------------------------------------------------
    Check for proper number of arguments
    ------------------------------------------------------------ */
 if(nrhs < MINNPARIN)
   mexErrMsgTxt("fwblkslv requires more input arguments.");
 if(nlhs > NPAROUT)
   mexErrMsgTxt("fwblkslv generates only 1 output argument.");
 /* ------------------------------------------------------------
    Disassemble block Cholesky structure L
    ------------------------------------------------------------ */
 if(!mxIsStruct(L_IN))
   mexErrMsgTxt("Parameter `L' should be a structure.");
 if( (L_FIELD = mxGetField(L_IN,0,"perm")) == NULL)      /* L.perm */
   mexErrMsgTxt("Missing field L.perm.");
 m = mxGetM(L_FIELD) * mxGetN(L_FIELD);
 permPr = mxGetPr(L_FIELD);
 if( (L_FIELD = mxGetField(L_IN,0,"L")) == NULL)      /* L.L */
   mexErrMsgTxt("Missing field L.L.");
 if( m != mxGetM(L_FIELD) || m != mxGetN(L_FIELD) )
   mexErrMsgTxt("Size L.L mismatch.");
 if(!mxIsSparse(L_FIELD))
   mexErrMsgTxt("L.L should be sparse.");
 L.jc = mxGetJc(L_FIELD);
 L.ir = mxGetIr(L_FIELD);
 L.pr = mxGetPr(L_FIELD);
 if( (L_FIELD = mxGetField(L_IN,0,"xsuper")) == NULL)      /* L.xsuper */
   mexErrMsgTxt("Missing field L.xsuper.");
 /* numel(L.xsuper) must be at least 1, so that nsuper >= 0 without
    relying on unsigned wraparound. */
 k = mxGetM(L_FIELD) * mxGetN(L_FIELD);
 if( k < 1 || k - 1 > m )
   mexErrMsgTxt("Size L.xsuper mismatch.");
 nsuper = k - 1;
 xsuperPr = mxGetPr(L_FIELD);
 /* ------------------------------------------------------------
    Validate the index vectors before they address work arrays:
    perm entries must lie in 1..m (they index b, and invperm when b is
    sparse), and xsuper entries in 1..m+1 (they bound the snode fill,
    whose slice holds m entries). The negated form also rejects NaN.
    ------------------------------------------------------------ */
 for(k = 0; k < m; k++)
   if( !(permPr[k] >= 1.0 && permPr[k] <= m) )
     mexErrMsgTxt("L.perm out of range.");
 for(k = 0; k <= nsuper; k++)
   if( !(xsuperPr[k] >= 1.0 && xsuperPr[k] <= m + 1.0) )
     mexErrMsgTxt("L.xsuper out of range.");
 /* ------------------------------------------------------------
    Get rhs matrix b.
    If it is sparse, then we also need the sparsity structure of y.
    ------------------------------------------------------------ */
 b = mxGetPr(B_IN);
 if( mxGetM(B_IN) != m )
   mexErrMsgTxt("Size mismatch b.");
 n = mxGetN(B_IN);
 if( (bissparse = mxIsSparse(B_IN)) ){
   bjc = mxGetJc(B_IN);
   bir = mxGetIr(B_IN);
   if(nrhs < NPARIN)
     mexErrMsgTxt("fwblkslv requires more inputs in case of sparse b.");
   if(mxGetM(Y_IN) != m || mxGetN(Y_IN) != n)
     mexErrMsgTxt("Size mismatch y.");
   if(!mxIsSparse(Y_IN))
     mexErrMsgTxt("y should be sparse.");
 }
/* ------------------------------------------------------------
   Allocate output y. If bissparse, then Y_IN gives the sparsity structure.
   ------------------------------------------------------------ */
 if(!bissparse)
   Y_OUT = mxCreateDoubleMatrix(m, n, mxREAL);
 else{
   yjc = mxGetJc(Y_IN);
   yir = mxGetIr(Y_IN);
   /* Read the nonzero count and size the copies through the typed Jc/Ir
      pointers, so the byte counts match the real element size.
      NOTE(review): yjc/yir/bjc/bir and L.jc/L.ir are still read through
      int pointers below, which presumably assumes a build with 32-bit
      sparse indices -- confirm the mex build flags. */
   nnzy = mxGetJc(Y_IN)[n];
   Y_OUT = mxCreateSparse(m,n, nnzy,mxREAL);
   memcpy(mxGetJc(Y_OUT), mxGetJc(Y_IN), (n+1) * sizeof *mxGetJc(Y_IN));
   memcpy(mxGetIr(Y_OUT), mxGetIr(Y_IN), nnzy * sizeof *mxGetIr(Y_IN));
 }
 y = mxGetPr(Y_OUT);
/* ------------------------------------------------------------
   Allocate working arrays fwork(m) and iwork(2*m + nsuper+1).
   iwork is split into [perm or invperm (m) | xsuper (nsuper+1) | snode (m)].
   perm and invperm share storage: only one of them is used, depending
   on whether b is sparse.
   ------------------------------------------------------------ */
 fwork = (double *) mxCalloc(m, sizeof(double));
 iwork = (int *) mxCalloc(2*m+nsuper+1, sizeof(int));
 perm = iwork;
 invperm = perm;
 xsuper = iwork + m;
 snode = xsuper + (nsuper+1);
/* ------------------------------------------------------------
   Convert real to integer array, and from Fortran to C style.
   In case of sparse b, we store the inverse perm, instead of perm itself.
   ------------------------------------------------------------ */
 for(k = 0; k <= nsuper; k++)
   xsuper[k] = xsuperPr[k] - 1;
 if(!bissparse){
   for(k = 0; k < m; k++)               /* Get perm if !bissparse */
     perm[k] = permPr[k] - 1;
 }
 else{
   for(k = 0; k < m; k++){              /* Get invperm if bissparse */
     j = permPr[k];
     invperm[--j] = k;
   }
/* ------------------------------------------------------------
   In case of sparse b, we also create snode, which maps each subnode
   to the supernode containing it.
   ------------------------------------------------------------ */
   for(j = 0, k = 0; k < nsuper; k++)
     while(j < xsuper[k+1])
       snode[j++] = k;
 }
/* ------------------------------------------------------------
   The actual job is done here: y = L\b(perm).
   ------------------------------------------------------------ */
 if(!bissparse){
   for(j = 0; j < n; j++){
     for(k = 0; k < m; k++)            /* y = b(perm) */
       y[k] = b[perm[k]];
     fwsolve(y,L.jc,L.ir,L.pr,xsuper,nsuper,fwork);
     y += m; b += m;
   }
 }
 else{
   /* inz walks the nonzeros of y column by column. */
   for(j = 0, inz = 0; j < n; j++){
     for(k = inz; k < yjc[j+1]; k++)            /* fwork = all-0 */
       fwork[yir[k]] = 0.0;
     /* Row r of b lands at row invperm[r] of b(perm). */
     for(k = bjc[j]; k < bjc[j+1]; k++)            /* fwork = b(perm) */
       fwork[invperm[bir[k]]] = b[k];
     selfwsolve(fwork,L.jc,L.ir,L.pr,xsuper,nsuper, snode,
                yir+inz,yjc[j+1]-inz);
     for(; inz < yjc[j+1]; inz++)
       y[inz] = fwork[yir[inz]];
   }
 }
 /* ------------------------------------------------------------
    RELEASE WORKING ARRAYS.
    ------------------------------------------------------------ */
 mxFree(iwork);
 mxFree(fwork);
}