void dLDLTRemove (dReal **A, const int *p, dReal *L, dReal *d, int n1, int n2, int r, int nskip) { int i; dAASSERT(A && p && L && d && n1 > 0 && n2 > 0 && r >= 0 && r < n2 && n1 >= n2 && nskip >= n1); #ifndef dNODEBUG for (i=0; i<n2; i++) dIASSERT(p[i] >= 0 && p[i] < n1); #endif if (r==n2-1) { return; // deleting last row/col is easy } else if (r==0) { dReal *a = (dReal*) ALLOCA (n2 * sizeof(dReal)); for (i=0; i<n2; i++) a[i] = -GETA(p[i],p[0]); a[0] += REAL(1.0); dLDLTAddTL (L,d,a,n2,nskip); } else { dReal *t = (dReal*) ALLOCA (r * sizeof(dReal)); dReal *a = (dReal*) ALLOCA ((n2-r) * sizeof(dReal)); for (i=0; i<r; i++) t[i] = L[r*nskip+i] / d[i]; for (i=0; i<(n2-r); i++) a[i] = dDot(L+(r+i)*nskip,t,r) - GETA(p[r+i],p[r]); a[0] += REAL(1.0); dLDLTAddTL (L + r*nskip+r, d+r, a, n2-r, nskip); } // snip out row/column r from L and d dRemoveRowCol (L,n2,nskip,r); if (r < (n2-1)) memmove (d+r,d+r+1,(n2-r-1)*sizeof(dReal)); }
void _dLDLTRemove (dReal **A, const int *p, dReal *L, dReal *d, int n1, int n2, int r, int nskip, void *tmpbuf/*n2 + 2*nskip*/) { dAASSERT(A && p && L && d && n1 > 0 && n2 > 0 && r >= 0 && r < n2 && n1 >= n2 && nskip >= n1); #ifndef dNODEBUG for (int i=0; i<n2; ++i) dIASSERT(p[i] >= 0 && p[i] < n1); #endif if (r==n2-1) { return; // deleting last row/col is easy } else { size_t LDLTAddTL_size = _dEstimateLDLTAddTLTmpbufSize(nskip); dIASSERT(LDLTAddTL_size % sizeof(dReal) == 0); dReal *tmp = tmpbuf ? (dReal *)tmpbuf : (dReal*) ALLOCA (LDLTAddTL_size + n2 * sizeof(dReal)); if (r==0) { dReal *a = (dReal *)((char *)tmp + LDLTAddTL_size); const int p_0 = p[0]; for (int i=0; i<n2; ++i) { a[i] = -GETA(p[i],p_0); } a[0] += REAL(1.0); dLDLTAddTL (L,d,a,n2,nskip,tmp); } else { dReal *t = (dReal *)((char *)tmp + LDLTAddTL_size); { dReal *Lcurr = L + r*nskip; for (int i=0; i<r; ++Lcurr, ++i) { dIASSERT(d[i] != dReal(0.0)); t[i] = *Lcurr / d[i]; } } dReal *a = t + r; { dReal *Lcurr = L + r*nskip; const int *pp_r = p + r, p_r = *pp_r; const int n2_minus_r = n2-r; for (int i=0; i<n2_minus_r; Lcurr+=nskip,++i) { a[i] = dDot(Lcurr,t,r) - GETA(pp_r[i],p_r); } } a[0] += REAL(1.0); dLDLTAddTL (L + r*nskip+r, d+r, a, n2-r, nskip, tmp); } } // snip out row/column r from L and d dRemoveRowCol (L,n2,nskip,r); if (r < (n2-1)) memmove (d+r,d+r+1,(n2-r-1)*sizeof(dReal)); }
void testLDLTRemove() { int i,j,r,p[MSIZE]; dReal A[MSIZE4*MSIZE], L[MSIZE4*MSIZE], d[MSIZE], L2[MSIZE4*MSIZE], d2[MSIZE], DL2[MSIZE4*MSIZE], Atest1[MSIZE4*MSIZE], Atest2[MSIZE4*MSIZE], diff, maxdiff; HEADER; // make array of A row pointers dReal *Arows[MSIZE]; for (i=0; i<MSIZE; i++) Arows[i] = A+i*MSIZE4; // fill permutation vector for (i=0; i<MSIZE; i++) p[i]=i; dMakeRandomMatrix (A,MSIZE,MSIZE,1.0); dMultiply2 (L,A,A,MSIZE,MSIZE,MSIZE); memcpy (A,L,MSIZE4*MSIZE*sizeof(dReal)); dFactorLDLT (L,d,MSIZE,MSIZE4); maxdiff = 1e10; for (r=0; r<MSIZE; r++) { // get Atest1 = A with row/column r removed memcpy (Atest1,A,MSIZE4*MSIZE*sizeof(dReal)); dRemoveRowCol (Atest1,MSIZE,MSIZE4,r); // test that the row/column removal worked int bad = 0; for (i=0; i<MSIZE; i++) { for (j=0; j<MSIZE; j++) { if (i != r && j != r) { int ii = i; int jj = j; if (ii >= r) ii--; if (jj >= r) jj--; if (A[i*MSIZE4+j] != Atest1[ii*MSIZE4+jj]) bad = 1; } } } if (bad) printf ("\trow/col removal FAILED for row %d\n",r); // zero out last row/column of Atest1 for (i=0; i<MSIZE; i++) { Atest1[(MSIZE-1)*MSIZE4+i] = 0; Atest1[i*MSIZE4+MSIZE-1] = 0; } // get L2*D2*L2' = adjusted factorization to remove that row memcpy (L2,L,MSIZE4*MSIZE*sizeof(dReal)); memcpy (d2,d,MSIZE*sizeof(dReal)); dLDLTRemove (/*A*/ Arows,p,L2,d2,MSIZE,MSIZE,r,MSIZE4); // get Atest2 = L2*D2*L2' dClearUpperTriangle (L2,MSIZE); for (i=0; i<(MSIZE-1); i++) L2[i*MSIZE4+i] = 1.0; for (i=0; i<MSIZE; i++) L2[(MSIZE-1)*MSIZE4+i] = 0; d2[MSIZE-1] = 1; dSetZero (DL2,MSIZE4*MSIZE); for (i=0; i<(MSIZE-1); i++) { for (j=0; j<MSIZE-1; j++) DL2[i*MSIZE4+j] = L2[i*MSIZE4+j] / d2[j]; } dMultiply2 (Atest2,L2,DL2,MSIZE,MSIZE,MSIZE); diff = dMaxDifference(Atest1,Atest2,MSIZE,MSIZE); if (diff < maxdiff) maxdiff = diff; /* dPrintMatrix (Atest1,MSIZE,MSIZE); printf ("\n"); dPrintMatrix (Atest2,MSIZE,MSIZE); printf ("\n"); */ } printf ("\tmaximum difference = %.6e - %s\n",maxdiff, maxdiff > tol ? "FAILED" : "passed"); }