/*
 * MEX gateway for the 2D multigrid solver fmg2():   x = f(A, b, prm)
 *
 *   A   - numeric, real, full, double array of size [n1 n2 3]
 *   b   - numeric, real, full, double array of size [n1 n2 2]; n1 and n2
 *         must match those of A
 *   prm - numeric, real, full, double array with at most 8 entries:
 *         [rtype vox1 vox2 param1 param2 param3 ncycles relax-its]
 *         Omitted trailing entries keep their defaults:
 *         rtype=2, vox1=vox2=1, param1=1, param2=param3=0,
 *         ncycles=1, relax-its=1.
 *
 * Returns x, a double array with the same dimensions as b, filled in by
 * fmg2().  Note that the voxel sizes are handed to fmg2() as reciprocals
 * (param[0] = 1/vox1, param[1] = 1/vox2).
 */
void fmg2_mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    /* NOTE(review): mxGetDimensions() returns const mwSize *; viewing it as
     * int * is only valid when mwSize is int (e.g. -compatibleArrayDims
     * builds). Confirm the build flags before relying on this on 64-bit. */
    const int *dm;
    int cyc = 1, nit = 1, rtype = 2;
    double *A, *b, *x, *scratch;
    const double *opt;
    size_t nopt;
    /* Deliberately NOT static: a static array kept the values written by a
     * previous call, so options omitted from prm silently reused stale
     * settings instead of these defaults. */
    double param[5] = {1.0, 1.0, 1.0, 0.0, 0.0};

    if (nrhs != 3 || nlhs > 1)
        mexErrMsgTxt("Incorrect usage");

    /* 1st argument: A, size [n1 n2 3]. */
    if (!mxIsNumeric(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) || !mxIsDouble(prhs[0]))
        mexErrMsgTxt("Data must be numeric, real, full and double");
    if (mxGetNumberOfDimensions(prhs[0]) != 3)
        mexErrMsgTxt("Wrong number of dimensions.");
    if (mxGetDimensions(prhs[0])[2] != 3)
        mexErrMsgTxt("3rd dimension of 1st arg must be 3.");

    /* 2nd argument: b, size [n1 n2 2]; its dimensions also size the output. */
    if (!mxIsNumeric(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1]) || !mxIsDouble(prhs[1]))
        mexErrMsgTxt("Data must be numeric, real, full and double");
    if (mxGetNumberOfDimensions(prhs[1]) != 3)
        mexErrMsgTxt("Wrong number of dimensions.");
    dm = mxGetDimensions(prhs[1]);
    if (dm[2] != 2)
        mexErrMsgTxt("3rd dimension of second arg must be 2.");
    if (mxGetDimensions(prhs[0])[0] != dm[0])
        mexErrMsgTxt("Incompatible 1st dimension.");
    if (mxGetDimensions(prhs[0])[1] != dm[1])
        mexErrMsgTxt("Incompatible 2nd dimension.");

    /* 3rd argument: optional settings vector (up to 8 entries). */
    if (!mxIsNumeric(prhs[2]) || mxIsComplex(prhs[2]) || mxIsSparse(prhs[2]) || !mxIsDouble(prhs[2]))
        mexErrMsgTxt("Data must be numeric, real, full and double");
    nopt = mxGetNumberOfElements(prhs[2]);
    if (nopt > 8)
        mexErrMsgTxt("Third argument should contain rtype, vox1, vox2, param1, param2, param3, ncycles and relax-its.");
    opt = mxGetPr(prhs[2]);   /* only dereferenced when nopt >= 1 */
    if (nopt >= 1) rtype    = (int)opt[0];
    if (nopt >= 2) param[0] = 1.0/opt[1];
    if (nopt >= 3) param[1] = 1.0/opt[2];
    if (nopt >= 4) param[2] = opt[3];
    if (nopt >= 5) param[3] = opt[4];
    if (nopt >= 6) param[4] = opt[5];
    if (nopt >= 7) cyc      = (int)opt[6];
    if (nopt >= 8) nit      = (int)opt[7];

    plhs[0] = mxCreateNumericArray(3, dm, mxDOUBLE_CLASS, mxREAL);
    A = mxGetPr(prhs[0]);
    b = mxGetPr(prhs[1]);
    x = mxGetPr(plhs[0]);

    scratch = (double *)mxCalloc(fmg2_scratchsize((int *)dm), sizeof(double));
    fmg2((int *)dm, A, b, rtype, param, cyc, nit, x, scratch);
    mxFree((void *)scratch);
}
/*
 * dartel - build and solve one regularised (Levenberg-Marquardt) update of a
 * 2D velocity field.
 *
 * With m = dm[0]*dm[1], the step solves (via fmg2) for
 *
 *     ov = v - inv(H + L'*L + R) * (d + L'*L*v)
 *
 *   v   : velocity/flow field (2*m values)
 *   H   : matrix of second derivatives, scaled by 1/pow2(k)
 *   d   : vector of first derivatives, scaled by 1/pow2(k)
 *   L'*L: regulariser from LtLf_le (rtype 0), LtLf_me (rtype 1) or
 *         LtLf_be (any other rtype), parameterised by param
 *   R   : lmreg, added to the first 2*m entries of H
 *
 * When issym is non-zero, the derivatives of the g->f term are also built
 * and combined with those of the f->g term.
 *
 * Outputs:
 *   ov    - 2*m updated velocity values. ov also serves as the gradient
 *           buffer while the step is built, so it should not alias v.
 *   ll[0] - sum of the values returned by initialise_objfun
 *   ll[1] - 0.5 * sum over 2*m entries of (L'*L*v) .* v
 *   ll[2] - norm(2*m, ...) of the combined gradient before solving
 *
 * Workspace buf, offsets in units of m doubles:
 *   [0,3)  H            [3,5)   t0        [5,9)   J0
 *   [9,11) t1           [11,15) J1
 *   [15,18) H, g->f term (issym only)   [18,20) d, g->f term (issym only)
 * The solver writes its result to [3,5) and uses buf + 5*m onward as
 * scratch. buf therefore needs at least 20*m doubles, and at least 5*m plus
 * fmg2's scratch requirement (presumably fmg2_scratchsize(dm) -- confirm
 * against the caller).
 */
void dartel(int dm[], int k, double v[], double g[], double f[], double dj[], int rtype, double param[], double lmreg, int cycles, int nits, int issym, double ov[], double ll[], double *buf)
{
    const int nvox = dm[0]*dm[1];
    double *const grad     = ov;
    double *const hess     = buf;
    double *const phi0     = buf + 3*nvox;
    double *const jac0     = buf + 5*nvox;
    double *const phi1     = buf + 9*nvox;
    double *const jac1     = buf + 11*nvox;
    double *const hess_rev = buf + 15*nvox;
    double *const grad_rev = buf + 18*nvox;
    double *const step     = buf + 3*nvox;   /* overlays phi0 once derivatives are done */
    double scale, match, prior, gnorm;
    int i;

    scale = 1.0/pow2(k);

    /* Derivatives of the f->g matching term. */
    expdef(dm, k, v, phi0, phi1, jac0, jac1);
    jac_div_smalldef(dm, scale, v, jac0);
    match = initialise_objfun(dm, f, g, phi0, jac0, dj, grad, hess);
    smalldef_jac(dm, -scale, v, phi0, jac0);
    squaring(dm, k, issym, grad, hess, phi0, phi1, jac0, jac1);

    /* Symmetric variant: add the g->f term, built in the reverse direction. */
    if (issym) {
        jac_div_smalldef(dm, -scale, v, jac0);
        match += initialise_objfun(dm, g, f, phi0, jac0, (double *)0, grad_rev, hess_rev);
        smalldef_jac(dm, scale, v, phi0, jac0);
        squaring(dm, k, 0, grad_rev, hess_rev, phi0, phi1, jac0, jac1);
        for (i = 0; i < nvox*2; i++)
            grad[i] -= grad_rev[i];
        for (i = 0; i < nvox*3; i++)
            hess[i] += hess_rev[i];
    }

    /* phi1 <- L'*L*v for the chosen regulariser. */
    switch (rtype) {
    case 0:
        LtLf_le(dm, v, param, phi1);
        break;
    case 1:
        LtLf_me(dm, v, param, phi1);
        break;
    default:
        LtLf_be(dm, v, param, phi1);
        break;
    }

    /* Total gradient d + L'*L*v, and the prior term v'*L'*L*v. */
    prior = 0.0;
    for (i = 0; i < 2*nvox; i++) {
        grad[i] = grad[i]*scale + phi1[i];
        prior += phi1[i]*v[i];
    }
    gnorm = norm(2*nvox, grad);

    /* Scale the Hessian, then add Levenberg-Marquardt damping to its first 2*m entries. */
    for (i = 0; i < 3*nvox; i++)
        hess[i] *= scale;
    for (i = 0; i < 2*nvox; i++)
        hess[i] += lmreg;

    fmg2(dm, hess, grad, rtype, param, cycles, nits, step, step + 2*nvox);

    for (i = 0; i < 2*nvox; i++)
        ov[i] = v[i] - step[i];

    ll[0] = match;
    ll[1] = prior*0.5;
    ll[2] = gnorm;
}