/*
 * Purpose
 * =======
 *
 * dGetDiagU copies the main diagonal of the factor U of an LU
 * factorization into a caller-supplied array.
 *
 * Arguments
 * =========
 *
 * L      (input) SuperMatrix*
 *        The factor L from the factorization Pr*A*Pc=L*U as computed by
 *        dgstrf(). Uses compressed row subscripts storage for supernodes,
 *        i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
 *
 * diagU  (output) double*, dimension (n)
 *        Receives the main diagonal of U; entry j is written for every
 *        column j covered by a supernode of L.
 *
 * Note
 * ====
 * The diagonal blocks of both L and U are kept in the L data structure,
 * which is why U itself is not an argument.
 */
void dGetDiagU(SuperMatrix *L, double *diagU)
{
    SCformat *Lstore = L->Store;
    double   *Lval = Lstore->nzval;
    int_t    last_super = Lstore->nsuper;
    int_t    k, col, first_col, next_col, diag_step, pos;

    for (k = 0; k <= last_super; ++k) {
        first_col = L_FST_SUPC(k);
        next_col  = L_FST_SUPC(k+1);

        /* One full supernode column (its row count) plus one entry moves
           from a diagonal element to the next diagonal element. */
        diag_step = L_SUB_START(first_col+1) - L_SUB_START(first_col) + 1;

        pos = L_NZ_START(first_col);
        for (col = first_col; col < next_col; ++col) {
            diagU[col] = Lval[pos];
            pos += diag_step;
        }
    }
}
/*
 * Purpose
 * =======
 *
 * Compute the reciprocal pivot growth factor of the leading ncols columns
 * of the matrix, using the formula:
 *     min_j ( max_i(abs(A_ij)) / max_i(abs(U_ij)) )
 * Magnitudes are measured with slu_c_abs1(). A column whose U part is
 * entirely zero contributes the ratio 1.
 *
 * Arguments
 * =========
 *
 * ncols    (input) int
 *          The number of leading columns of A, L and U to examine.
 *
 * A        (input) SuperMatrix*
 *          Original matrix A, permuted by columns, of dimension
 *          (A->nrow, A->ncol). The type of A can be:
 *          Stype = NC; Dtype = SLU_C; Mtype = GE.
 *
 * perm_c   (input) int*
 *          Column permutation; it is inverted internally to map a column
 *          of the factors back to the matching column of A.
 *
 * L        (input) SuperMatrix*
 *          The factor L from the factorization Pr*A=L*U; uses compressed row
 *          subscripts storage for supernodes, i.e., L has type:
 *          Stype = SC; Dtype = SLU_C; Mtype = TRLU. Only read here.
 *
 * U        (input) SuperMatrix*
 *          The factor U from the factorization Pr*A*Pc=L*U. Uses column-wise
 *          storage scheme, i.e., U has types: Stype = NC;
 *          Dtype = SLU_C; Mtype = TRU. Only read here.
 *
 * Returns the reciprocal pivot growth as a float.
 */
float cPivotGrowth(int ncols, SuperMatrix *A, int *perm_c, SuperMatrix *L, SuperMatrix *U)
{
    extern double slamch_(char *);
    NCformat *Astore;
    SCformat *Lstore;
    NCformat *Ustore;
    complex  *Aval, *Lval, *Uval;
    complex  *sn_col;      /* top of column j inside its supernode block */
    int      *col_of;      /* inverse of perm_c */
    int      first_col, next_col, lda, orig_col;
    int      k, j, p;
    float    growth, amax, umax, smlnum;

    /* Get machine constants; start from 1/smlnum so any real ratio wins. */
    smlnum = slamch_("S");
    growth = 1. / smlnum;

    Astore = A->Store;
    Lstore = L->Store;
    Ustore = U->Store;
    Aval = Astore->nzval;
    Lval = Lstore->nzval;
    Uval = Ustore->nzval;

    col_of = (int *) SUPERLU_MALLOC(A->ncol*sizeof(int));
    for (j = 0; j < A->ncol; ++j)
        col_of[perm_c[j]] = j;

    for (k = 0; k <= Lstore->nsuper; ++k) {
        first_col = L_FST_SUPC(k);
        next_col  = L_FST_SUPC(k+1);
        lda       = L_SUB_START(first_col+1) - L_SUB_START(first_col);
        sn_col    = &Lval[L_NZ_START(first_col)];

        for (j = first_col; j < next_col && j < ncols; ++j, sn_col += lda) {
            /* Largest entry of the matching column of A */
            orig_col = col_of[j];
            amax = 0.;
            for (p = Astore->colptr[orig_col]; p < Astore->colptr[orig_col+1]; ++p)
                amax = SUPERLU_MAX( amax, slu_c_abs1( &Aval[p]) );

            /* Largest entry of column j of U: the part stored in U ... */
            umax = 0.;
            for (p = Ustore->colptr[j]; p < Ustore->colptr[j+1]; p++)
                umax = SUPERLU_MAX( umax, slu_c_abs1( &Uval[p]) );

            /* ... and the part stored in the supernode, which grows by
               one entry for each column further into the supernode. */
            for (p = 0; p <= j - first_col; ++p)
                umax = SUPERLU_MAX( umax, slu_c_abs1( &sn_col[p]) );

            if ( umax == 0. ) {
                growth = SUPERLU_MIN( growth, 1.);
            } else {
                growth = SUPERLU_MIN( growth, amax / umax );
            }
        }

        /* Stop once the leading ncols columns have been covered. */
        if ( j >= ncols ) break;
    }

    SUPERLU_FREE(col_of);
    return (growth);
}
/*
    Purpose
    =======

    DGST01 reconstructs a matrix A from its L*U factorization and
    computes the residual
       norm(L*U - A) / ( N * norm(A) * EPS ),
    where EPS is the machine epsilon.

    Arguments
    ==========

    M       (input) INT
            The number of rows of the matrix A.  M >= 0.

    N       (input) INT
            The number of columns of the matrix A.  N >= 0.

    A       (input) SuperMatrix *, dimension (A->nrow, A->ncol)
            The original M x N matrix A.

    L       (input) SuperMatrix *, dimension (L->nrow, L->ncol)
            The factor matrix L.

    U       (input) SuperMatrix *, dimension (U->nrow, U->ncol)
            The factor matrix U.

    perm_c  (input) INT array, dimension (N)
            The column permutation from DGSTRF.

    perm_r  (input) INT array, dimension (M)
            The pivot indices from DGSTRF.

    RESID   (output) DOUBLE*
            norm(L*U - A) / ( N * norm(A) * EPS ).
            Set to 0 when M <= 0 or N <= 0. When norm(A) <= 0 it is left
            at 0 if the raw residual is 0 and set to 1/EPS otherwise.

    The function always returns 0.

    =====================================================================
*/
int dgst01(int m, int n, SuperMatrix *A, SuperMatrix *L, SuperMatrix *U, int *perm_c, int *perm_r, double *resid)
{
    /* Function prototypes */
    extern double dlangs(char *, SuperMatrix *);

    /* Local variables */
    SCformat *Lstore;
    NCformat *Astore, *Ustore;
    double   *Aval, *Lval, *Uval;
    double   *work;
    int      *colbeg, *colend;
    int      col, p, q, row, lsub_pos, lval_pos, sn, sn_first, depth;
    double   uval, eps, anorm, col_norm, max_norm;

    /* Quick exit if M = 0 or N = 0. */
    if (m <= 0 || n <= 0) {
        *resid = 0.f;
        return 0;
    }

    work = (double *)doubleCalloc(m);

    Astore = A->Store;
    Aval = Astore->nzval;
    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;

    /* Locate each column of A by its position after the column permutation. */
    colbeg = intMalloc(n);
    colend = intMalloc(n);
    for (p = 0; p < n; p++) {
        colbeg[perm_c[p]] = Astore->colptr[p];
        colend[perm_c[p]] = Astore->colptr[p+1];
    }

    /* Determine EPS and the norm of A. */
    eps = dmach("Epsilon");
    anorm = dlangs("1", A);
    max_norm = 0.;

    /* Compute the product L*U, one column at a time; work accumulates
       -(L*U)[col] and is reset to zero at the end of each column. */
    for (col = 0; col < n; ++col) {

        /* The U part outside the rectangular supernode */
        for (p = U_NZ_START(col); p < U_NZ_START(col+1); ++p) {
            row = U_SUB(p);
            uval = Uval[p];
            sn = Lstore->col_to_sup[row];
            sn_first = L_FST_SUPC(sn);
            depth = row - sn_first + 1;
            lsub_pos = L_SUB_START(sn_first) + depth;
            work[L_SUB(lsub_pos-1)] -= uval;   /* L_ii = 1 */
            for (q = L_NZ_START(row) + depth; q < L_NZ_START(row+1); ++q, ++lsub_pos)
                work[L_SUB(lsub_pos)] -= Lval[q] * uval;
        }

        /* The U part inside the rectangular supernode */
        sn = Lstore->col_to_sup[col];
        sn_first = L_FST_SUPC(sn);
        lval_pos = L_NZ_START(col);
        for (p = sn_first; p <= col; ++p) {
            uval = Lval[lval_pos++];
            depth = p - sn_first + 1;
            lsub_pos = L_SUB_START(sn_first) + depth;
            work[L_SUB(lsub_pos-1)] -= uval;   /* L_ii = 1 */
            for (q = L_NZ_START(p)+depth; q < L_NZ_START(p+1); ++q, ++lsub_pos)
                work[L_SUB(lsub_pos)] -= Lval[q] * uval;
        }

        /* Now compute A[k] - (L*U)[k] (Both matrices may be permuted.) */
        for (p = colbeg[col]; p < colend[col]; ++p) {
            row = Astore->rowind[p];
            work[perm_r[row]] += Aval[p];
        }

        /* Now compute the 1-norm of the column vector work, clearing it
           for the next column as we go. */
        col_norm = 0.;
        for (p = 0; p < m; ++p) {
            col_norm += fabs(work[p]);
            work[p] = 0.0;
        }
        max_norm = SUPERLU_MAX(col_norm, max_norm);
    }

    *resid = max_norm;

    if (anorm <= 0.f) {
        if (*resid != 0.f)
            *resid = 1.f / eps;
    } else {
        *resid = *resid / (float) n / anorm / eps;
    }

    SUPERLU_FREE(work);
    SUPERLU_FREE(colbeg);
    SUPERLU_FREE(colend);
    return 0;

/*  End of DGST01 */
} /* dgst01_ */
/*
 * Purpose
 * =======
 *
 *   sp_dtrsv_dist() solves one of the systems of equations
 *       A*x = b,   or   A'*x = b,
 *   where b and x are n element vectors and A is a sparse unit, or
 *   non-unit, upper or lower triangular matrix.
 *   No test for singularity or near-singularity is included in this
 *   routine. Such tests must be performed before calling this routine.
 *
 * Parameters
 * ==========
 *
 *   uplo   - (input) char*
 *            On entry, uplo specifies whether the matrix is an upper or
 *            lower triangular matrix as follows:
 *                uplo = 'U' or 'u'   A is an upper triangular matrix.
 *                uplo = 'L' or 'l'   A is a lower triangular matrix.
 *
 *   trans  - (input) char*
 *            On entry, trans specifies the equations to be solved as
 *            follows:
 *                trans = 'N' or 'n'   A*x = b.
 *                trans = 'T' or 't'   A'*x = b.
 *                trans = 'C' or 'c'   A'*x = b.
 *            For this real routine 'C' is treated exactly like 'T'.
 *
 *   diag   - (input) char*
 *            On entry, diag specifies whether or not A is unit
 *            triangular as follows:
 *                diag = 'U' or 'u'   A is assumed to be unit triangular.
 *                diag = 'N' or 'n'   A is not assumed to be unit
 *                                    triangular.
 *            NOTE: diag is validated but not consulted afterwards; the
 *            L solves below always use a unit diagonal and the U solves
 *            always use a non-unit diagonal.
 *
 *   L      - (input) SuperMatrix*
 *            The factor L from the factorization Pr*A*Pc=L*U. Use
 *            compressed row subscripts storage for supernodes,
 *            i.e., L has types: Stype = SC, Dtype = D, Mtype = TRLU.
 *
 *   U      - (input) SuperMatrix*
 *            The factor U from the factorization Pr*A*Pc=L*U.
 *            U has types: Stype = NC, Dtype = D, Mtype = TRU.
 *
 *   x      - (input/output) double*
 *            Before entry, the incremented array X must contain the n
 *            element right-hand side vector b. On exit, X is overwritten
 *            with the solution vector x.
 *
 *   info   - (output) int*
 *            = 0: successful exit.
 *            If *info = -i, the i-th argument had an illegal value.
 *
 * The function always returns 0; errors are reported through *info and
 * xerbla_(). The operation count of the solve is added to
 * SuperLUStat.ops[SOLVE].
 */
int sp_dtrsv_dist(char *uplo, char *trans, char *diag, SuperMatrix *L, SuperMatrix *U, double *x, int *info)
{
#ifdef _CRAY
    _fcd ftcs1, ftcs2, ftcs3;
#endif
    SCformat *Lstore;
    NCformat *Ustore;
    double   *Lval, *Uval;
    int incx = 1, incy = 1;
    double alpha = 1.0, beta = 1.0;
    int nrow;
    int fsupc, nsupr, nsupc, luptr, istart, irow;
    int i, k, iptr, jcol;
    double *work;
    flops_t solve_ops;
    extern SuperLUStat_t SuperLUStat;

    /* Test the input parameters */
    *info = 0;
    if ( !lsame_(uplo,"L") && !lsame_(uplo, "U") ) *info = -1;
    else if ( !lsame_(trans, "N") && !lsame_(trans, "T") &&
              !lsame_(trans, "C") ) *info = -2;
    else if ( !lsame_(diag, "U") && !lsame_(diag, "N") ) *info = -3;
    else if ( L->nrow != L->ncol || L->nrow < 0 ) *info = -4;
    else if ( U->nrow != U->ncol || U->nrow < 0 ) *info = -5;
    if ( *info ) {
        i = -(*info);
        xerbla_("sp_dtrsv_dist", &i);
        return 0;
    }

    /* Quick return for an empty triangular factor. This is done before
       the work buffer is allocated so that nothing is leaked and a
       zero-size allocation cannot trigger the malloc-failure abort. */
    if ( lsame_(uplo, "L") ) {
        if ( L->nrow == 0 ) return 0;
    } else {
        if ( U->nrow == 0 ) return 0;
    }

    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    if ( !(work = doubleCalloc_dist(L->nrow)) )
        ABORT("Malloc fails for work in sp_dtrsv_dist().");

    if ( lsame_(trans, "N") ) {	/* Form x := inv(A)*x. */

        if ( lsame_(uplo, "L") ) {
            /* Form x := inv(L)*x */

            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);
                nrow = nsupr - nsupc;

                solve_ops += nsupc * (nsupc - 1);
                solve_ops += 2 * nrow * nsupc;

                if ( nsupc == 1 ) {
                    /* Single-column supernode: plain column update. */
                    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); ++iptr) {
                        irow = L_SUB(iptr);
                        ++luptr;
                        x[irow] -= x[fsupc] * Lval[luptr];
                    }
                } else {
                    /* Dense triangular solve on the diagonal block, then a
                       dense matrix-vector product for the rows below it. */
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("N", strlen("N"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    STRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);

                    SGEMV(ftcs2, &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                          &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#else
                    dtrsv_("L", "N", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);

                    dgemv_("N", &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                           &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#endif /* _CRAY */
#else
                    dlsolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc]);

                    dmatvec ( nsupr, nsupr-nsupc, nsupc, &Lval[luptr+nsupc],
                              &x[fsupc], &work[0] );
#endif

                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; ++i, ++iptr) {
                        irow = L_SUB(iptr);
                        x[irow] -= work[i];	/* Scatter */
                        work[i] = 0.0;      /* leave work clean for the next supernode */
                    }
                }
            } /* for k ... */

        } else {
            /* Form x := inv(U)*x */

            for (k = Lstore->nsuper; k >= 0; k--) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += nsupc * (nsupc + 1);

                if ( nsupc == 1 ) {
                    x[fsupc] /= Lval[luptr];
                    for (i = U_NZ_START(fsupc); i < U_NZ_START(fsupc+1); ++i) {
                        irow = U_SUB(i);
                        x[irow] -= x[fsupc] * Uval[i];
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("N", strlen("N"));
                    STRSV(ftcs1, ftcs2, ftcs2, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    dtrsv_("U", "N", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
#else
                    dusolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc] );
#endif

                    for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                        solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                        for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1);
                             i++) {
                            irow = U_SUB(i);
                            x[irow] -= x[jcol] * Uval[i];
                        }
                    }
                }
            } /* for k ... */

        }
    } else { /* Form x := inv(A')*x */

        if ( lsame_(uplo, "L") ) {
            /* Form x := inv(L')*x */

            for (k = Lstore->nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 2 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc;
                         i < L_NZ_START(jcol+1); i++) {
                        irow = L_SUB(iptr);
                        x[jcol] -= x[irow] * Lval[i];
                        iptr++;
                    }
                }

                if ( nsupc > 1 ) {
                    solve_ops += nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    STRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    dtrsv_("L", "T", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := inv(U')*x */

            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++) {
                        irow = U_SUB(i);
                        x[jcol] -= x[irow] * Uval[i];
                    }
                }

                solve_ops += nsupc * (nsupc + 1);

                if ( nsupc == 1 ) {
                    x[fsupc] /= Lval[luptr];
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    STRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#else
                    dtrsv_("U", "T", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    }

    SuperLUStat.ops[SOLVE] += solve_ops;
    SUPERLU_FREE(work);
    return 0;
}
/*
 * Purpose
 * =======
 *
 * Compute the reciprocal pivot growth factor of the leading ncols columns
 * of the matrix:
 *     min_j ( max_i(abs(A_ij)) / max_i(abs(U_ij)) )
 * A column whose U part is entirely zero contributes the ratio 1.
 *
 * Arguments
 * =========
 *
 * ncols    (input) int
 *          The number of leading columns to examine.
 *
 * A        (input) SuperMatrix*
 *          Matrix whose columns are compared against U; accessed through
 *          an NCformat store with float values.
 *
 * perm_c   (input) int*
 *          Column permutation; it is inverted internally to map a column
 *          of the factors back to the matching column of A.
 *
 * L        (input) SuperMatrix*
 *          Supernodal factor (SCformat store, float values). Only the
 *          upper triangle of each supernode's leading block is read here.
 *
 * U        (input) SuperMatrix*
 *          Factor U (NCformat store, float values).
 *
 * Returns the reciprocal pivot growth as a float.
 */
float sPivotGrowth(int ncols, SuperMatrix *A, int *perm_c, SuperMatrix *L, SuperMatrix *U)
{
    extern double slamch_(char *);
    NCformat *Astore;
    SCformat *Lstore;
    NCformat *Ustore;
    float    *Aval, *Lval, *Uval;
    float    *sn_col;      /* top of column j inside its supernode block */
    int      *col_of;      /* inverse of perm_c */
    int      first_col, next_col, lda, orig_col;
    int      k, j, p;
    float    growth, amax, umax, smlnum;

    /* Get machine constants; start from 1/smlnum so any real ratio wins. */
    smlnum = slamch_("S");
    growth = 1. / smlnum;

    Astore = A->Store;
    Lstore = L->Store;
    Ustore = U->Store;
    Aval = Astore->nzval;
    Lval = Lstore->nzval;
    Uval = Ustore->nzval;

    col_of = (int *) SUPERLU_MALLOC(A->ncol*sizeof(int));
    for (j = 0; j < A->ncol; ++j)
        col_of[perm_c[j]] = j;

    for (k = 0; k <= Lstore->nsuper; ++k) {
        first_col = L_FST_SUPC(k);
        next_col  = L_FST_SUPC(k+1);
        lda       = L_SUB_START(first_col+1) - L_SUB_START(first_col);
        sn_col    = &Lval[L_NZ_START(first_col)];

        for (j = first_col; j < next_col && j < ncols; ++j, sn_col += lda) {
            /* Largest entry of the matching column of A */
            orig_col = col_of[j];
            amax = 0.;
            for (p = Astore->colptr[orig_col]; p < Astore->colptr[orig_col+1]; ++p)
                amax = SUPERLU_MAX( amax, fabs(Aval[p]) );

            /* Largest entry of column j of U: the part stored in U ... */
            umax = 0.;
            for (p = Ustore->colptr[j]; p < Ustore->colptr[j+1]; p++)
                umax = SUPERLU_MAX( umax, fabs(Uval[p]) );

            /* ... and the part stored in the supernode, which grows by
               one entry for each column further into the supernode. */
            for (p = 0; p <= j - first_col; ++p)
                umax = SUPERLU_MAX( umax, fabs(sn_col[p]) );

            if ( umax == 0. ) {
                growth = SUPERLU_MIN( growth, 1.);
            } else {
                growth = SUPERLU_MIN( growth, amax / umax );
            }
        }

        /* Stop once the leading ncols columns have been covered. */
        if ( j >= ncols ) break;
    }

    SUPERLU_FREE(col_of);
    return (growth);
}
/*
 * Purpose
 * =======
 *
 * SGSTRS solves a system of linear equations A*X=B or A'*X=B
 * with A sparse and B dense, using the LU factorization computed by
 * SGSTRF.
 *
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
 *
 * Arguments
 * =========
 *
 * trans   (input) trans_t
 *         Specifies the form of the system of equations:
 *          = NOTRANS: A * X = B     (No transpose)
 *          = TRANS:   A'* X = B     (Transpose)
 *          = CONJ:    A**H * X = B  (Conjugate transpose)
 *         TRANS and CONJ take the same code path in this routine.
 *
 * L       (input) SuperMatrix*
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
 *         sgstrf(). Use compressed row subscripts storage for supernodes,
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_S, Mtype = SLU_TRLU.
 *
 * U       (input) SuperMatrix*
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
 *         sgstrf(). Use column-wise storage scheme, i.e., U has types:
 *         Stype = SLU_NC, Dtype = SLU_S, Mtype = SLU_TRU.
 *
 * perm_c  (input) int*, dimension (L->ncol)
 *         Column permutation vector, which defines the
 *         permutation matrix Pc; perm_c[i] = j means column i of A is
 *         in position j in A*Pc.
 *
 * perm_r  (input) int*, dimension (L->nrow)
 *         Row permutation vector, which defines the permutation matrix Pr;
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
 *
 * B       (input/output) SuperMatrix*
 *         B has types: Stype = SLU_DN, Dtype = SLU_S, Mtype = SLU_GE.
 *         On entry, the right hand side matrix.
 *         On exit, the solution matrix if info = 0;
 *
 * stat    (output) SuperLUStat_t*
 *         Record the statistics on runtime and floating-point operation count.
 *         See util.h for the definition of 'SuperLUStat_t'.
 *         For NOTRANS, ops[SOLVE] is overwritten with the count of this
 *         solve; otherwise it is reset to 0 and then handed to sp_strsv().
 *
 * info    (output) int*
 *         = 0: successful exit
 *         < 0: if info = -i, the i-th argument had an illegal value
 */
void sgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U, int *perm_c, int *perm_r, SuperMatrix *B, SuperLUStat_t *stat, int *info)
{
#ifdef _CRAY
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
#endif
    int      incx = 1, incy = 1;
#ifdef USE_VENDOR_BLAS
    float    alpha = 1.0, beta = 1.0;
    float    *scratch_col;
#endif
    DNformat *Bstore;
    float    *Bmat;
    SCformat *Lstore;
    NCformat *Ustore;
    float    *Lval, *Uval;
    int      fsupc, nrow, nsupr, nsupc, luptr, istart;
    int      sn, rhs, r, p, iptr, jcol, n, ldb, nrhs, argno;
    float    *scratch, *bcol, *perm_buf;
    flops_t  solve_ops;
    void sprint_soln();

    /* Test input parameters ... */
    *info = 0;
    Bstore = B->Store;
    ldb = Bstore->lda;
    nrhs = B->ncol;
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) {
        *info = -1;
    } else if ( L->nrow != L->ncol || L->nrow < 0 ||
                L->Stype != SLU_SC || L->Dtype != SLU_S || L->Mtype != SLU_TRLU ) {
        *info = -2;
    } else if ( U->nrow != U->ncol || U->nrow < 0 ||
                U->Stype != SLU_NC || U->Dtype != SLU_S || U->Mtype != SLU_TRU ) {
        *info = -3;
    } else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
                B->Stype != SLU_DN || B->Dtype != SLU_S || B->Mtype != SLU_GE ) {
        *info = -6;
    }
    if ( *info ) {
        argno = -(*info);
        xerbla_("sgstrs", &argno);
        return;
    }

    n = L->nrow;
    scratch = floatCalloc(n * nrhs);
    if ( !scratch ) ABORT("Malloc fails for local work[].");
    perm_buf = floatMalloc(n);
    if ( !perm_buf ) ABORT("Malloc fails for local soln[].");

    Bmat = Bstore->nzval;
    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    if ( trans == NOTRANS ) {
        /* Permute right hand sides to form Pr*B */
        for (rhs = 0; rhs < nrhs; rhs++) {
            bcol = &Bmat[rhs*ldb];
            for (r = 0; r < n; r++) perm_buf[perm_r[r]] = bcol[r];
            for (r = 0; r < n; r++) bcol[r] = perm_buf[r];
        }

        /* Forward solve PLy=Pb. */
        for (sn = 0; sn <= Lstore->nsuper; sn++) {
            fsupc = L_FST_SUPC(sn);
            istart = L_SUB_START(fsupc);
            nsupr = L_SUB_START(fsupc+1) - istart;
            nsupc = L_FST_SUPC(sn+1) - fsupc;
            nrow = nsupr - nsupc;

            solve_ops += nsupc * (nsupc - 1) * nrhs;
            solve_ops += 2 * nrow * nsupc * nrhs;

            if ( nsupc == 1 ) {
                /* Single-column supernode: update the rows below the
                   diagonal directly, one right-hand side at a time. */
                for (rhs = 0; rhs < nrhs; rhs++) {
                    bcol = &Bmat[rhs*ldb];
                    luptr = L_NZ_START(fsupc) + 1;
                    for (iptr = istart+1; iptr < L_SUB_START(fsupc+1); iptr++, luptr++)
                        bcol[L_SUB(iptr)] -= bcol[fsupc] * Lval[luptr];
                }
            } else {
                luptr = L_NZ_START(fsupc);
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                ftcs1 = _cptofcd("L", strlen("L"));
                ftcs2 = _cptofcd("N", strlen("N"));
                ftcs3 = _cptofcd("U", strlen("U"));
                STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

                SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
                       &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                       &beta, &scratch[0], &n );
#else
                strsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

                sgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                        &beta, &scratch[0], &n );
#endif
                for (rhs = 0; rhs < nrhs; rhs++) {
                    bcol = &Bmat[rhs*ldb];
                    scratch_col = &scratch[rhs*n];
                    iptr = istart + nsupc;
                    for (p = 0; p < nrow; p++, iptr++) {
                        bcol[L_SUB(iptr)] -= scratch_col[p]; /* Scatter */
                        scratch_col[p] = 0.0;
                    }
                }
#else
                for (rhs = 0; rhs < nrhs; rhs++) {
                    bcol = &Bmat[rhs*ldb];
                    slsolve (nsupr, nsupc, &Lval[luptr], &bcol[fsupc]);
                    smatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
                             &bcol[fsupc], &scratch[0] );

                    iptr = istart + nsupc;
                    for (p = 0; p < nrow; p++, iptr++) {
                        bcol[L_SUB(iptr)] -= scratch[p];
                        scratch[p] = 0.0;
                    }
                }
#endif
            } /* else ... */
        } /* for L-solve */

#ifdef DEBUG
        printf("After L-solve: y=\n");
        sprint_soln(n, nrhs, Bmat);
#endif

        /*
         * Back solve Ux=y.
         */
        for (sn = Lstore->nsuper; sn >= 0; sn--) {
            fsupc = L_FST_SUPC(sn);
            istart = L_SUB_START(fsupc);
            nsupr = L_SUB_START(fsupc+1) - istart;
            nsupc = L_FST_SUPC(sn+1) - fsupc;
            luptr = L_NZ_START(fsupc);

            solve_ops += nsupc * (nsupc + 1) * nrhs;

            if ( nsupc == 1 ) {
                /* Divide by the single diagonal entry in every column of B */
                for (rhs = 0; rhs < nrhs; rhs++)
                    Bmat[fsupc + rhs*ldb] /= Lval[luptr];
            } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                ftcs1 = _cptofcd("L", strlen("L"));
                ftcs2 = _cptofcd("U", strlen("U"));
                ftcs3 = _cptofcd("N", strlen("N"));
                STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#else
                strsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#endif
#else
                for (rhs = 0; rhs < nrhs; rhs++)
                    susolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+rhs*ldb] );
#endif
            }

            /* Push the solved entries up through the columns of U. */
            for (rhs = 0; rhs < nrhs; ++rhs) {
                bcol = &Bmat[rhs*ldb];
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
                    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (p = U_NZ_START(jcol); p < U_NZ_START(jcol+1); p++ )
                        bcol[U_SUB(p)] -= bcol[jcol] * Uval[p];
                }
            }

        } /* for U-solve */

#ifdef DEBUG
        printf("After U-solve: x=\n");
        sprint_soln(n, nrhs, Bmat);
#endif

        /* Compute the final solution X := Pc*X. */
        for (rhs = 0; rhs < nrhs; rhs++) {
            bcol = &Bmat[rhs*ldb];
            for (r = 0; r < n; r++) perm_buf[r] = bcol[perm_c[r]];
            for (r = 0; r < n; r++) bcol[r] = perm_buf[r];
        }

        stat->ops[SOLVE] = solve_ops;

    } else { /* Solve A'*X=B or CONJ(A)*X=B */
        /* Permute right hand sides to form Pc'*B. */
        for (rhs = 0; rhs < nrhs; rhs++) {
            bcol = &Bmat[rhs*ldb];
            for (r = 0; r < n; r++) perm_buf[perm_c[r]] = bcol[r];
            for (r = 0; r < n; r++) bcol[r] = perm_buf[r];
        }

        stat->ops[SOLVE] = 0;
        for (rhs = 0; rhs < nrhs; ++rhs) {
            bcol = &Bmat[rhs*ldb];

            /* Multiply by inv(U'). */
            sp_strsv("U", "T", "N", L, U, bcol, stat, info);

            /* Multiply by inv(L'). */
            sp_strsv("L", "T", "U", L, U, bcol, stat, info);
        }

        /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
        for (rhs = 0; rhs < nrhs; rhs++) {
            bcol = &Bmat[rhs*ldb];
            for (r = 0; r < n; r++) perm_buf[r] = bcol[perm_r[r]];
            for (r = 0; r < n; r++) bcol[r] = perm_buf[r];
        }

    }

    SUPERLU_FREE(scratch);
    SUPERLU_FREE(perm_buf);
}
/*
 * -- SuperLU MT routine (version 1.0) --
 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
 * and Lawrence Berkeley National Lab.
 * August 15, 1997
 *
 * Purpose
 * =======
 *
 * dgstrs() solves a system of linear equations A*X=B or A'*X=B
 * with A sparse and B dense, using the LU factorization computed by
 * pdgstrf().
 *
 * Arguments
 * =========
 *
 * trans   (input) Specifies the form of the system of equations:
 *          = NOTRANS: A * X = B  (No transpose)
 *          = TRANS:   A'* X = B  (Transpose)
 *
 * L       (input) SuperMatrix*
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
 *         pdgstrf(). Use compressed row subscripts storage for supernodes,
 *         i.e., L has types: Stype = SCP, Dtype = _D, Mtype = TRLU.
 *
 * U       (input) SuperMatrix*
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
 *         pdgstrf(). Use column-wise storage scheme, i.e., U has types:
 *         Stype = NCP, Dtype = _D, Mtype = TRU.
 *
 * perm_r  (input) int*
 *         Row permutation vector of size L->nrow, which defines the
 *         permutation matrix Pr; perm_r[i] = j means row i of A is in
 *         position j in Pr*A.
 *
 * perm_c  (input) int*, dimension A->ncol
 *         Column permutation vector, which defines the
 *         permutation matrix Pc; perm_c[i] = j means column i of A is
 *         in position j in A*Pc.
 *
 * B       (input/output) SuperMatrix*
 *         B has types: Stype = DN, Dtype = _D, Mtype = GE.
 *         On entry, the right hand side matrix.
 *         On exit, the solution matrix if info = 0;
 *
 * Gstat   (output) Gstat_t*
 *         Record all the statistics about the triangular solves;
 *         See Gstat_t structure defined in util.h.
 *         ops[TRISOLVE] is overwritten with the operation count
 *         accumulated in this routine (the transposed path accumulates
 *         nothing here, so it stores 0).
 *
 * info    (output) Diagnostics
 *          = 0: successful exit
 *          < 0: if info = -i, the i-th argument had an illegal value:
 *               -1 trans, -2 L, -3 U, -6 B (leading dimension too small).
 */
void dgstrs(trans_t trans, SuperMatrix *L, SuperMatrix *U, int *perm_r, int *perm_c, SuperMatrix *B, Gstat_t *Gstat, int *info)
{
#if ( MACH==CRAY_PVP )
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
#endif
#ifdef USE_VENDOR_BLAS
    int      incx = 1, incy = 1;
    double   alpha = 1.0, beta = 1.0;
#endif
    register int j, k, jcol, iptr, luptr, ksupno, istart, irow, bptr;
    register int fsupc, nsuper;
    int      i, n, nsupc, nsupr, nrow, nrhs, ldb;
    DNformat *Bstore;
    SCPformat *Lstore;
    NCPformat *Ustore;
    double   *Lval, *Uval, *Bmat;
    double   *work, *work_col, *rhs_work, *soln;
    flops_t  solve_ops;
    void dprint_soln();

    /* Test input parameters. The reported index is the 1-based position
       of the offending argument in this function's parameter list. */
    *info = 0;
    Bstore = B->Store;
    ldb = Bstore->lda;
    nrhs = B->ncol;
    if ( trans != NOTRANS && trans != TRANS ) *info = -1;
    else if ( L->nrow != L->ncol || L->nrow < 0 ) *info = -2;
    else if ( U->nrow != U->ncol || U->nrow < 0 ) *info = -3;
    else if ( ldb < MAX(0, L->nrow) ) *info = -6;
    if ( *info ) {
        i = -(*info);
        xerbla_("dgstrs", &i);
        return;
    }

    n = L->nrow;
    work = doubleCalloc(n * nrhs);
    if ( !work ) ABORT("Malloc fails for local work[].");
    soln = doubleMalloc(n);
    if ( !soln ) ABORT("Malloc fails for local soln[].");

    Bmat = Bstore->nzval;
    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    nsuper = Lstore->nsuper;
    solve_ops = 0;

    if ( trans == NOTRANS ) {
        /* Permute right hand sides to form Pr*B */
        for (i = 0, bptr = 0; i < nrhs; i++, bptr += ldb) {
            rhs_work = &Bmat[bptr];
            for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

        /* Forward solve PLy=Pb. */
        for (ksupno = 0; ksupno <= nsuper; ++ksupno) {
            fsupc = L_FST_SUPC(ksupno);
            istart = L_SUB_START(fsupc);
            nsupr = L_SUB_END(fsupc) - istart;
            nsupc = L_LAST_SUPC(ksupno) - fsupc;
            nrow = nsupr - nsupc;

            solve_ops += nsupc * (nsupc - 1) * nrhs;
            solve_ops += 2 * nrow * nsupc * nrhs;

            if ( nsupc == 1 ) {
                /* Single-column supernode: direct column update. */
                for (j = 0, bptr = 0; j < nrhs; j++, bptr += ldb) {
                    rhs_work = &Bmat[bptr];
                    luptr = L_NZ_START(fsupc);
                    for (iptr=istart+1; iptr < L_SUB_END(fsupc); iptr++){
                        irow = L_SUB(iptr);
                        ++luptr;
                        rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
                    }
                }
            } else {
                luptr = L_NZ_START(fsupc);
#ifdef USE_VENDOR_BLAS
#if ( MACH==CRAY_PVP )
                ftcs1 = _cptofcd("L", strlen("L"));
                ftcs2 = _cptofcd("N", strlen("N"));
                ftcs3 = _cptofcd("U", strlen("U"));
                STRSM(ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
                      &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

                SGEMM(ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
                      &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                      &beta, &work[0], &n );
#else
                dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

                dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                        &beta, &work[0], &n );
#endif
                for (j = 0, bptr = 0; j < nrhs; j++, bptr += ldb) {
                    rhs_work = &Bmat[bptr];
                    work_col = &work[j*n];
                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; i++) {
                        irow = L_SUB(iptr);
                        rhs_work[irow] -= work_col[i]; /* Scatter */
                        work_col[i] = 0.0;
                        iptr++;
                    }
                }
#else
                for (j = 0, bptr = 0; j < nrhs; j++, bptr += ldb) {
                    rhs_work = &Bmat[bptr];
                    dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
                    dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
                             &rhs_work[fsupc], &work[0] );

                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; i++) {
                        irow = L_SUB(iptr);
                        rhs_work[irow] -= work[i];
                        work[i] = 0.0;
                        iptr++;
                    }
                }
#endif
            } /* if-else: nsupc == 1 ... */
        } /* for L-solve */

#if ( DEBUGlevel>=2 )
        printf("After L-solve: y=\n");
        dprint_soln(n, nrhs, Bmat);
#endif

        /*
         * Back solve Ux=y.
         */
        for (ksupno = nsuper; ksupno >= 0; --ksupno) {
            fsupc = L_FST_SUPC(ksupno);
            istart = L_SUB_START(fsupc);
            nsupr = L_SUB_END(fsupc) - istart;
            nsupc = L_LAST_SUPC(ksupno) - fsupc;
            luptr = L_NZ_START(fsupc);

            /* dense triangular matrix */
            solve_ops += nsupc * (nsupc + 1) * nrhs;

            if ( nsupc == 1 ) {
                rhs_work = &Bmat[0];
                for (j = 0; j < nrhs; j++) {
                    rhs_work[fsupc] /= Lval[luptr];
                    rhs_work += ldb;
                }
            } else {
#ifdef USE_VENDOR_BLAS
#if ( MACH==CRAY_PVP )
                ftcs1 = _cptofcd("L", strlen("L"));
                ftcs2 = _cptofcd("U", strlen("U"));
                ftcs3 = _cptofcd("N", strlen("N"));
                STRSM(ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
                      &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#else
                dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#endif
#else
                for (j = 0, bptr = fsupc; j < nrhs; j++, bptr += ldb) {
                    dusolve (nsupr, nsupc, &Lval[luptr], &Bmat[bptr]);
                }
#endif
            }

            /* matrix-vector update */
            for (j = 0, bptr = 0; j < nrhs; ++j, bptr += ldb) {
                rhs_work = &Bmat[bptr];
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
                    solve_ops += 2*(U_NZ_END(jcol) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_END(jcol); i++ ){
                        irow = U_SUB(i);
                        rhs_work[irow] -= rhs_work[jcol] * Uval[i];
                    }
                }
            }

        } /* for U-solve */

#if ( DEBUGlevel>=2 )
        printf("After U-solve: x=\n");
        dprint_soln(n, nrhs, Bmat);
#endif

        /* Compute the final solution X <= Pc*X. */
        for (i = 0, bptr = 0; i < nrhs; i++, bptr += ldb) {
            rhs_work = &Bmat[bptr];
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

    } else { /* Solve A'*X=B */
        /* Permute right hand sides to form Pc'*B. */
        for (i = 0, bptr = 0; i < nrhs; i++, bptr += ldb) {
            rhs_work = &Bmat[bptr];
            for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

        for (k = 0; k < nrhs; ++k) {
            /* Multiply by inv(U'). */
            sp_dtrsv("U", "T", "N", L, U, &Bmat[k*ldb], info);

            /* Multiply by inv(L'). */
            sp_dtrsv("L", "T", "U", L, U, &Bmat[k*ldb], info);
        }

        /* Compute the final solution X <= Pr'*X (=inv(Pr)*X) */
        for (i = 0, bptr = 0; i < nrhs; i++, bptr += ldb) {
            rhs_work = &Bmat[bptr];
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

    } /* if-else trans */

    Gstat->ops[TRISOLVE] = solve_ops;
    SUPERLU_FREE(work);
    SUPERLU_FREE(soln);
}
/*! \brief Solves one of the systems of equations A*x = b, A'*x = b, or A^H*x = b
 *
 * <pre>
 * Purpose
 * =======
 *
 *   sp_ctrsv() solves one of the systems of equations
 *       A*x = b,   A'*x = b,   or   A^H*x = b,
 *   where b and x are n element vectors and A is the sparse lower (L) or
 *   upper (U) triangular factor of an LU factorization.
 *   No test for singularity or near-singularity is included in this
 *   routine. Such tests must be performed before calling this routine.
 *
 * Parameters
 * ==========
 *
 *   uplo   - (input) char*
 *            Selects the factor to solve with. Only the first character
 *            is examined and the comparison is case-sensitive:
 *               uplo = "U"   solve with the upper triangular factor.
 *               uplo = "L"   solve with the lower triangular factor.
 *
 *   trans  - (input) char*
 *            Selects the system to solve (first character, case-sensitive):
 *               trans = "N"   A*x = b.
 *               trans = "T"   A'*x = b.
 *               trans = "C"   A^H*x = b.
 *
 *   diag   - (input) char*
 *            Must start with "U" or "N". It is only validated: the
 *            L solves below always treat the diagonal as unit and the
 *            U solves always divide by the stored diagonal.
 *
 *   L      - (input) SuperMatrix*
 *            The factor L from the factorization Pr*A*Pc=L*U. Use
 *            compressed row subscripts storage for supernodes,
 *            i.e., L has types: Stype = SC, Dtype = SLU_C, Mtype = TRLU.
 *            The diagonal blocks used by the U solves are read from L's
 *            value array as well.
 *
 *   U      - (input) SuperMatrix*
 *            The factor U from the factorization Pr*A*Pc=L*U.
 *            U has types: Stype = NC, Dtype = SLU_C, Mtype = TRU.
 *
 *   x      - (input/output) complex*
 *            Before entry, x must contain the n element right-hand side
 *            vector b. On exit, x is overwritten with the solution.
 *
 *   stat   - (input/output) SuperLUStat_t*
 *            stat->ops[SOLVE] is incremented by the flop estimate of
 *            this solve.
 *
 *   info   - (output) int*
 *            Set to 0 on entry to the checks;
 *            if *info = -i, the i-th argument had an illegal value.
 *
 *   Returns 0 in all cases; errors are reported through *info.
 * </pre>
 */
int
sp_ctrsv(char *uplo, char *trans, char *diag, SuperMatrix *L,
         SuperMatrix *U, complex *x, SuperLUStat_t *stat, int *info)
{
#ifdef _CRAY
    _fcd ftcs1 = _cptofcd("L", strlen("L")),
         ftcs2 = _cptofcd("N", strlen("N")),
         ftcs3 = _cptofcd("U", strlen("U"));
#endif
    SCformat *Lstore;
    NCformat *Ustore;
    complex  *Lval, *Uval;
    int      incx = 1, incy = 1;
    complex  temp;  /* conjugated matrix entry (trans = "C" only) */
    complex  prod;  /* scratch for a single complex product */
    complex  alpha = {1.0, 0.0}, beta = {1.0, 0.0};
    complex  comp_zero = {0.0, 0.0};  /* must stay zero: used to reset work[] */
    int      nrow;
    int      fsupc, nsupr, nsupc, luptr, istart, irow;
    int      i, k, iptr, jcol;
    complex  *work;
    flops_t  solve_ops;

    /* Test the input parameters */
    *info = 0;
    if ( strncmp(uplo, "L", 1)!=0 && strncmp(uplo, "U", 1)!=0 ) *info = -1;
    else if ( strncmp(trans, "N", 1)!=0 && strncmp(trans, "T", 1)!=0 &&
              strncmp(trans, "C", 1)!=0 ) *info = -2;
    else if ( strncmp(diag, "U", 1)!=0 && strncmp(diag, "N", 1)!=0 )
        *info = -3;
    else if ( L->nrow != L->ncol || L->nrow < 0 ) *info = -4;
    else if ( U->nrow != U->ncol || U->nrow < 0 ) *info = -5;
    if ( *info ) {
        i = -(*info);
        input_error("sp_ctrsv", &i);
        return 0;
    }

    /* Quick return for an empty system. This is done before work[] is
       allocated so that the early exit does not leak it. */
    if ( strncmp(uplo, "L", 1)==0 ) {
        if ( L->nrow == 0 ) return 0;
    } else {
        if ( U->nrow == 0 ) return 0;
    }

    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    if ( !(work = complexCalloc(L->nrow)) )
        ABORT("Malloc fails for work in sp_ctrsv().");

    if ( strncmp(trans, "N", 1)==0 ) {  /* Form x := inv(A)*x. */

        if ( strncmp(uplo, "L", 1)==0 ) {
            /* Form x := inv(L)*x */
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);
                nrow = nsupr - nsupc;

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc - 1) + 10 * nsupc;
                solve_ops += 8 * nrow * nsupc;

                if ( nsupc == 1 ) {
                    for (iptr = istart+1; iptr < L_SUB_START(fsupc+1); ++iptr) {
                        irow = L_SUB(iptr);
                        ++luptr;
                        cc_mult(&prod, &x[fsupc], &Lval[luptr]);
                        c_sub(&x[irow], &x[irow], &prod);
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);

                    CGEMV(ftcs2, &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                          &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#else
                    ctrsv_("L", "N", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);

                    cgemv_("N", &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                           &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#endif
#else
                    clsolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc]);

                    cmatvec ( nsupr, nsupr-nsupc, nsupc, &Lval[luptr+nsupc],
                              &x[fsupc], &work[0] );
#endif

                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; ++i, ++iptr) {
                        irow = L_SUB(iptr);
                        c_sub(&x[irow], &x[irow], &work[i]); /* Scatter */
                        /* The update above is accumulated into work[]
                           (beta = 1), so it must be cleared for the next
                           supernode. */
                        work[i] = comp_zero;
                    }
                }
            } /* for k ... */

        } else {
            /* Form x := inv(U)*x */
            for (k = Lstore->nsuper; k >= 0; k--) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
                    for (i = U_NZ_START(fsupc); i < U_NZ_START(fsupc+1); ++i) {
                        irow = U_SUB(i);
                        cc_mult(&prod, &x[fsupc], &Uval[i]);
                        c_sub(&x[irow], &x[irow], &prod);
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    CTRSV(ftcs3, ftcs2, ftcs2, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    ctrsv_("U", "N", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
#else
                    cusolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc] );
#endif

                    for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                        solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                        for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1);
                             i++) {
                            irow = U_SUB(i);
                            cc_mult(&prod, &x[jcol], &Uval[i]);
                            c_sub(&x[irow], &x[irow], &prod);
                        }
                    }
                }
            } /* for k ... */

        }
    } else if ( strncmp(trans, "T", 1)==0 ) { /* Form x := inv(A')*x */

        if ( strncmp(uplo, "L", 1)==0 ) {
            /* Form x := inv(L')*x */
            for (k = Lstore->nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 8 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc;
                         i < L_NZ_START(jcol+1); i++) {
                        irow = L_SUB(iptr);
                        cc_mult(&prod, &x[irow], &Lval[i]);
                        c_sub(&x[jcol], &x[jcol], &prod);
                        iptr++;
                    }
                }

                if ( nsupc > 1 ) {
                    solve_ops += 4 * nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    ctrsv_("L", "T", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := inv(U')*x */
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++) {
                        irow = U_SUB(i);
                        cc_mult(&prod, &x[irow], &Uval[i]);
                        c_sub(&x[jcol], &x[jcol], &prod);
                    }
                }

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    CTRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#else
                    ctrsv_("U", "T", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    } else { /* Form x := conj(inv(A'))*x */

        if ( strncmp(uplo, "L", 1)==0 ) {
            /* Form x := conj(inv(L'))*x */
            for (k = Lstore->nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 8 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc;
                         i < L_NZ_START(jcol+1); i++) {
                        irow = L_SUB(iptr);
                        cc_conj(&temp, &Lval[i]);
                        cc_mult(&prod, &x[irow], &temp);
                        c_sub(&x[jcol], &x[jcol], &prod);
                        iptr++;
                    }
                }

                if ( nsupc > 1 ) {
                    solve_ops += 4 * nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd(trans, strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    ctrsv_("L", trans, "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := conj(inv(U'))*x */
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++) {
                        irow = U_SUB(i);
                        cc_conj(&temp, &Uval[i]);
                        cc_mult(&prod, &x[irow], &temp);
                        c_sub(&x[jcol], &x[jcol], &prod);
                    }
                }

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    cc_conj(&temp, &Lval[luptr]);
                    c_div(&x[fsupc], &x[fsupc], &temp);
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd(trans, strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    CTRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#else
                    ctrsv_("U", trans, "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    }

    stat->ops[SOLVE] += solve_ops;
    SUPERLU_FREE(work);
    return 0;
}
/*! \brief Performs only the U-solve of a system factored by DGSTRF.
 *
 * <pre>
 * Purpose
 * =======
 *
 * dgstrsU only performs the U-solve using the LU factorization computed
 * by DGSTRF.
 *   - trans = NOTRANS: back-solves U*X = B in place, then applies the
 *     column permutation, X := Pc*X.
 *   - otherwise: forms Pc'*B in place, then solves U'*X = B one
 *     right-hand side at a time through sp_dtrsv().
 *
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
 *
 * Arguments
 * =========
 *
 * trans   (input) trans_t
 *          Specifies the form of the system of equations:
 *          = NOTRANS: A * X = B  (No transpose)
 *          = TRANS:   A'* X = B  (Transpose)
 *          = CONJ:    A**H * X = B  (Conjugate transpose)
 *
 * L       (input) SuperMatrix*
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
 *         dgstrf(). Use compressed row subscripts storage for supernodes,
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
 *         The diagonal blocks of U are read from L's value array.
 *
 * U       (input) SuperMatrix*
 *         The factor U from the factorization Pr*A*Pc=L*U as computed by
 *         dgstrf(). Use column-wise storage scheme, i.e., U has types:
 *         Stype = SLU_NC, Dtype = SLU_D, Mtype = SLU_TRU.
 *
 * perm_c  (input) int*, dimension (L->ncol)
 *         Column permutation vector, which defines the
 *         permutation matrix Pc; perm_c[i] = j means column i of A is
 *         in position j in A*Pc.
 *
 * perm_r  (input) int*, dimension (L->nrow)
 *         Row permutation vector, which defines the permutation matrix Pr;
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
 *         Not referenced by this routine.
 *
 * B       (input/output) SuperMatrix*
 *         B has types: Stype = SLU_DN, Dtype = SLU_D, Mtype = SLU_GE.
 *         On entry, the right hand side matrix.
 *         On exit, the solution matrix if info = 0;
 *
 * stat     (output) SuperLUStat_t*
 *          Record the statistics on runtime and floating-point operation
 *          count. On return stat->ops[SOLVE] holds the flop estimate of
 *          this call only.
 *          See util.h for the definition of 'SuperLUStat_t'.
 *
 * info    (output) int*
 *         = 0: successful exit
 *         < 0: if info = -i, the i-th argument had an illegal value
 * </pre>
 */
void
dgstrsU(trans_t trans, SuperMatrix *L, SuperMatrix *U,
        int *perm_c, int *perm_r, SuperMatrix *B,
        SuperLUStat_t *stat, int *info)
{
#ifdef _CRAY
    _fcd ftcs1, ftcs2, ftcs3;
#endif
#ifdef USE_VENDOR_BLAS
    double   alpha = 1.0;
#endif
    DNformat *Bstore;
    double   *Bmat;
    SCformat *Lstore;
    NCformat *Ustore;
    double   *Lval, *Uval;
    int      fsupc, nsupr, nsupc, luptr, istart, irow;
    int      i, j, k, jcol, n, ldb, nrhs;
    double   *rhs_work, *soln;
    flops_t  solve_ops;
    void dprint_soln();

    /* Test input parameters ... */
    *info = 0;
    Bstore = B->Store;
    ldb = Bstore->lda;
    nrhs = B->ncol;
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
              L->Stype != SLU_SC || L->Dtype != SLU_D || L->Mtype != SLU_TRLU )
        *info = -2;
    else if ( U->nrow != U->ncol || U->nrow < 0 ||
              U->Stype != SLU_NC || U->Dtype != SLU_D || U->Mtype != SLU_TRU )
        *info = -3;
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
              B->Stype != SLU_DN || B->Dtype != SLU_D || B->Mtype != SLU_GE )
        *info = -6;
    if ( *info ) {
        i = -(*info);
        /* Report under this routine's own name so the message points at
           the right caller. */
        xerbla_("dgstrsU", &i);
        return;
    }

    n = L->nrow;
    soln = doubleMalloc(n);
    if ( !soln ) ABORT("Malloc fails for local soln[].");

    Bmat = Bstore->nzval;
    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    if ( trans == NOTRANS ) {
        /*
         * Back solve Ux=y.
         */
        for (k = Lstore->nsuper; k >= 0; k--) {
            fsupc = L_FST_SUPC(k);
            istart = L_SUB_START(fsupc);
            nsupr = L_SUB_START(fsupc+1) - istart;
            nsupc = L_FST_SUPC(k+1) - fsupc;
            luptr = L_NZ_START(fsupc);

            solve_ops += nsupc * (nsupc + 1) * nrhs;

            if ( nsupc == 1 ) {
                /* Singleton supernode: divide by the diagonal entry in
                   every right-hand side column. */
                rhs_work = &Bmat[0];
                for (j = 0; j < nrhs; j++) {
                    rhs_work[fsupc] /= Lval[luptr];
                    rhs_work += ldb;
                }
            } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                ftcs1 = _cptofcd("L", strlen("L"));
                ftcs2 = _cptofcd("U", strlen("U"));
                ftcs3 = _cptofcd("N", strlen("N"));
                STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#else
                dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#endif
#else
                for (j = 0; j < nrhs; j++)
                    dusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
#endif
            }

            /* Subtract the contribution of the solved columns from the
               rows referenced by the column-wise part of U. */
            for (j = 0; j < nrhs; ++j) {
                rhs_work = &Bmat[j*ldb];
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
                    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
                        irow = U_SUB(i);
                        rhs_work[irow] -= rhs_work[jcol] * Uval[i];
                    }
                }
            }

        } /* for U-solve */

#ifdef DEBUG
        printf("After U-solve: x=\n");
        dprint_soln(n, nrhs, Bmat);
#endif

        /* Compute the final solution X := Pc*X. */
        for (i = 0; i < nrhs; i++) {
            rhs_work = &Bmat[i*ldb];
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

        stat->ops[SOLVE] = solve_ops;

    } else { /* Solve U'x = b */
        /* Permute right hand sides to form Pc'*B. */
        for (i = 0; i < nrhs; i++) {
            rhs_work = &Bmat[i*ldb];
            for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

        /* sp_dtrsv() accumulates into ops[SOLVE]; start from zero so the
           counter reflects this call only, as in the NOTRANS branch. */
        stat->ops[SOLVE] = 0;

        for (k = 0; k < nrhs; ++k) {
            /* Multiply by inv(U'). */
            sp_dtrsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
        }

    }

    SUPERLU_FREE(soln);
}
/*
 * zgstrs solves A*X = B, A'*X = B or A^H*X = B for double-complex
 * systems, overwriting B with the solution, given the supernodal factor
 * L, the column-wise factor U and the permutation vectors perm_r and
 * perm_c.
 *
 *   trans == NOTRANS : B := Pr*B, forward solve with L, back solve with
 *                      U, then X := Pc*X. stat->ops[SOLVE] is set to the
 *                      flop estimate.
 *   trans == TRANS   : B := Pc'*B, then per column sp_ztrsv with U' and
 *                      L', then X := Pr'*X.
 *   trans == CONJ    : as TRANS but with the "C" option of sp_ztrsv.
 *
 * On an illegal argument *info is set to the negated argument position,
 * input_error() is called and the routine returns without touching B.
 */
void
zgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
        int *perm_c, int *perm_r, SuperMatrix *B,
        SuperLUStat_t *stat, int *info)
{
#ifdef _CRAY
    _fcd ftcs1, ftcs2, ftcs3;
#endif
#ifdef USE_VENDOR_BLAS
    doublecomplex alpha = {1.0, 0.0}, beta = {1.0, 0.0};
    doublecomplex *scratch_col;
#endif
    doublecomplex prod;              /* one complex product             */
    DNformat      *Bstore;
    doublecomplex *rhs;              /* dense right-hand side values    */
    SCformat      *Lstore;
    NCformat      *Ustore;
    doublecomplex *Lval, *Uval;
    int           first_col, below, sn_rows, sn_cols, nz, sub0, row;
    int           p, r, sn, sub, c, order, ldb, nrhs;
    doublecomplex *scratch, *col, *perm_buf;
    flops_t       solve_ops;
    void zprint_soln();

    /* Argument checks. */
    *info = 0;
    Bstore = B->Store;
    ldb = Bstore->lda;
    nrhs = B->ncol;
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) {
        *info = -1;
    } else if ( L->nrow != L->ncol || L->nrow < 0 ||
                L->Stype != SLU_SC || L->Dtype != SLU_Z ||
                L->Mtype != SLU_TRLU ) {
        *info = -2;
    } else if ( U->nrow != U->ncol || U->nrow < 0 ||
                U->Stype != SLU_NC || U->Dtype != SLU_Z ||
                U->Mtype != SLU_TRU ) {
        *info = -3;
    } else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
                B->Stype != SLU_DN || B->Dtype != SLU_Z ||
                B->Mtype != SLU_GE ) {
        *info = -6;
    }
    if ( *info ) {
        p = -(*info);
        input_error("zgstrs", &p);
        return;
    }

    order = L->nrow;
    scratch = doublecomplexCalloc(order * nrhs);
    if ( !scratch ) ABORT("Malloc fails for local work[].");
    perm_buf = doublecomplexMalloc(order);
    if ( !perm_buf ) ABORT("Malloc fails for local soln[].");

    rhs = Bstore->nzval;
    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    if ( trans != NOTRANS ) {
        /* ---- Transposed / conjugate-transposed systems ---- */

        /* Scatter each column through perm_c: B := Pc'*B. */
        for (r = 0; r < nrhs; r++) {
            col = &rhs[r*ldb];
            for (p = 0; p < order; p++) perm_buf[perm_c[p]] = col[p];
            for (p = 0; p < order; p++) col[p] = perm_buf[p];
        }

        stat->ops[SOLVE] = 0;
        for (r = 0; r < nrhs; ++r) {
            col = &rhs[r*ldb];
            if ( trans == TRANS ) {
                /* inv(U') followed by inv(L') */
                sp_ztrsv("U", "T", "N", L, U, col, stat, info);
                sp_ztrsv("L", "T", "U", L, U, col, stat, info);
            } else { /* trans == CONJ */
                /* conj(inv(U')) followed by conj(inv(L')) */
                sp_ztrsv("U", "C", "N", L, U, col, stat, info);
                sp_ztrsv("L", "C", "U", L, U, col, stat, info);
            }
        }

        /* Gather each column through perm_r: X := Pr'*X. */
        for (r = 0; r < nrhs; r++) {
            col = &rhs[r*ldb];
            for (p = 0; p < order; p++) perm_buf[p] = col[perm_r[p]];
            for (p = 0; p < order; p++) col[p] = perm_buf[p];
        }

        SUPERLU_FREE(scratch);
        SUPERLU_FREE(perm_buf);
        return;
    }

    /* ---- trans == NOTRANS ---- */

    /* Scatter each column through perm_r: B := Pr*B. */
    for (r = 0; r < nrhs; r++) {
        col = &rhs[r*ldb];
        for (p = 0; p < order; p++) perm_buf[perm_r[p]] = col[p];
        for (p = 0; p < order; p++) col[p] = perm_buf[p];
    }

    /* Forward solve with L, one supernode at a time. */
    for (sn = 0; sn <= Lstore->nsuper; sn++) {
        first_col = L_FST_SUPC(sn);
        sub0 = L_SUB_START(first_col);
        sn_rows = L_SUB_START(first_col+1) - sub0;
        sn_cols = L_FST_SUPC(sn+1) - first_col;
        below = sn_rows - sn_cols;

        solve_ops += 4 * sn_cols * (sn_cols - 1) * nrhs;
        solve_ops += 8 * below * sn_cols * nrhs;

        if ( sn_cols == 1 ) {
            /* Single column: plain axpy on the sub-diagonal entries. */
            for (r = 0; r < nrhs; r++) {
                col = &rhs[r*ldb];
                nz = L_NZ_START(first_col) + 1;
                for (sub = sub0 + 1; sub < L_SUB_START(first_col+1);
                     sub++, nz++) {
                    row = L_SUB(sub);
                    zz_mult(&prod, &col[first_col], &Lval[nz]);
                    z_sub(&col[row], &col[row], &prod);
                }
            }
        } else {
            nz = L_NZ_START(first_col);
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
            ftcs1 = _cptofcd("L", strlen("L"));
            ftcs2 = _cptofcd("N", strlen("N"));
            ftcs3 = _cptofcd("U", strlen("U"));
            CTRSM( ftcs1, ftcs1, ftcs2, ftcs3, &sn_cols, &nrhs, &alpha,
                   &Lval[nz], &sn_rows, &rhs[first_col], &ldb);

            CGEMM( ftcs2, ftcs2, &below, &nrhs, &sn_cols, &alpha,
                   &Lval[nz+sn_cols], &sn_rows, &rhs[first_col], &ldb,
                   &beta, &scratch[0], &order );
#else
            ztrsm_("L", "L", "N", "U", &sn_cols, &nrhs, &alpha,
                   &Lval[nz], &sn_rows, &rhs[first_col], &ldb);

            zgemm_( "N", "N", &below, &nrhs, &sn_cols, &alpha,
                    &Lval[nz+sn_cols], &sn_rows, &rhs[first_col], &ldb,
                    &beta, &scratch[0], &order );
#endif
            /* Scatter the dense update into B and clear the scratch. */
            for (r = 0; r < nrhs; r++) {
                col = &rhs[r*ldb];
                scratch_col = &scratch[r*order];
                sub = sub0 + sn_cols;
                for (p = 0; p < below; p++, sub++) {
                    row = L_SUB(sub);
                    z_sub(&col[row], &col[row], &scratch_col[p]);
                    scratch_col[p].r = 0.0;
                    scratch_col[p].i = 0.0;
                }
            }
#else
            for (r = 0; r < nrhs; r++) {
                col = &rhs[r*ldb];
                zlsolve (sn_rows, sn_cols, &Lval[nz], &col[first_col]);
                zmatvec (sn_rows, below, sn_cols, &Lval[nz+sn_cols],
                         &col[first_col], &scratch[0] );

                /* Scatter the dense update into B and clear the scratch. */
                sub = sub0 + sn_cols;
                for (p = 0; p < below; p++, sub++) {
                    row = L_SUB(sub);
                    z_sub(&col[row], &col[row], &scratch[p]);
                    scratch[p].r = 0.;
                    scratch[p].i = 0.;
                }
            }
#endif
        }
    } /* forward solve */

#ifdef DEBUG
    printf("After L-solve: y=\n");
    zprint_soln(order, nrhs, rhs);
#endif

    /* Back solve with U, last supernode first. */
    for (sn = Lstore->nsuper; sn >= 0; sn--) {
        first_col = L_FST_SUPC(sn);
        sub0 = L_SUB_START(first_col);
        sn_rows = L_SUB_START(first_col+1) - sub0;
        sn_cols = L_FST_SUPC(sn+1) - first_col;
        nz = L_NZ_START(first_col);

        solve_ops += 4 * sn_cols * (sn_cols + 1) * nrhs;

        if ( sn_cols == 1 ) {
            /* Divide by the diagonal entry in every column of B. */
            for (r = 0; r < nrhs; r++) {
                col = &rhs[r*ldb];
                z_div(&col[first_col], &col[first_col], &Lval[nz]);
            }
        } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
            ftcs1 = _cptofcd("L", strlen("L"));
            ftcs2 = _cptofcd("U", strlen("U"));
            ftcs3 = _cptofcd("N", strlen("N"));
            CTRSM( ftcs1, ftcs2, ftcs3, ftcs3, &sn_cols, &nrhs, &alpha,
                   &Lval[nz], &sn_rows, &rhs[first_col], &ldb);
#else
            ztrsm_("L", "U", "N", "N", &sn_cols, &nrhs, &alpha,
                   &Lval[nz], &sn_rows, &rhs[first_col], &ldb);
#endif
#else
            for (r = 0; r < nrhs; r++) {
                zusolve ( sn_rows, sn_cols, &Lval[nz],
                          &rhs[first_col+r*ldb] );
            }
#endif
        }

        /* Propagate the solved entries through the column-wise part of U. */
        for (r = 0; r < nrhs; ++r) {
            col = &rhs[r*ldb];
            for (c = first_col; c < first_col + sn_cols; c++) {
                solve_ops += 8*(U_NZ_START(c+1) - U_NZ_START(c));
                for (p = U_NZ_START(c); p < U_NZ_START(c+1); p++) {
                    row = U_SUB(p);
                    zz_mult(&prod, &col[c], &Uval[p]);
                    z_sub(&col[row], &col[row], &prod);
                }
            }
        }
    } /* back solve */

#ifdef DEBUG
    printf("After U-solve: x=\n");
    zprint_soln(order, nrhs, rhs);
#endif

    /* Gather each column through perm_c: X := Pc*X. */
    for (r = 0; r < nrhs; r++) {
        col = &rhs[r*ldb];
        for (p = 0; p < order; p++) perm_buf[p] = col[perm_c[p]];
        for (p = 0; p < order; p++) col[p] = perm_buf[p];
    }

    stat->ops[SOLVE] = solve_ops;

    SUPERLU_FREE(scratch);
    SUPERLU_FREE(perm_buf);
}
/*! \brief Solves one of the systems of equations A*x = b, or A'*x = b
 *
 * <pre>
 * Purpose
 * =======
 *
 *   sp_strsv() solves one of the systems of equations
 *       A*x = b,   or   A'*x = b,
 *   where b and x are n element vectors and A is the sparse lower (L) or
 *   upper (U) triangular factor of an LU factorization.
 *   No test for singularity or near-singularity is included in this
 *   routine. Such tests must be performed before calling this routine.
 *
 * Parameters
 * ==========
 *
 *   uplo   - (input) char*
 *            On entry, uplo specifies whether the matrix is an upper or
 *             lower triangular matrix as follows:
 *                uplo = 'U' or 'u'   A is an upper triangular matrix.
 *                uplo = 'L' or 'l'   A is a lower triangular matrix.
 *
 *   trans  - (input) char*
 *             On entry, trans specifies the equations to be solved as
 *             follows:
 *                trans = 'N' or 'n'   A*x = b.
 *                trans = 'T' or 't'   A'*x = b.
 *                trans = 'C' or 'c'   A'*x = b.
 *
 *   diag   - (input) char*
 *             Must be 'U' or 'N'. It is only validated: the L solves
 *             below always treat the diagonal as unit and the U solves
 *             always divide by the stored diagonal.
 *
 *   L       - (input) SuperMatrix*
 *             The factor L from the factorization Pr*A*Pc=L*U. Use
 *             compressed row subscripts storage for supernodes,
 *             i.e., L has types: Stype = SC, Dtype = SLU_S, Mtype = TRLU.
 *             The diagonal blocks used by the U solves are read from L's
 *             value array as well.
 *
 *   U       - (input) SuperMatrix*
 *              The factor U from the factorization Pr*A*Pc=L*U.
 *              U has types: Stype = NC, Dtype = SLU_S, Mtype = TRU.
 *
 *   x       - (input/output) float*
 *             Before entry, x must contain the n element right-hand side
 *             vector b. On exit, x is overwritten with the solution.
 *
 *   stat    - (input/output) SuperLUStat_t*
 *             stat->ops[SOLVE] is incremented by the flop estimate of
 *             this solve.
 *
 *   info    - (output) int*
 *             If *info = -i, the i-th argument had an illegal value.
 *
 *   Returns 0 in all cases; errors are reported through *info.
 * </pre>
 */
int
sp_strsv(char *uplo, char *trans, char *diag, SuperMatrix *L,
         SuperMatrix *U, float *x, SuperLUStat_t *stat, int *info)
{
#ifdef _CRAY
    _fcd ftcs1 = _cptofcd("L", strlen("L")),
         ftcs2 = _cptofcd("N", strlen("N")),
         ftcs3 = _cptofcd("U", strlen("U"));
#endif
    SCformat *Lstore;
    NCformat *Ustore;
    float    *Lval, *Uval;
    int      incx = 1, incy = 1;
    float    alpha = 1.0, beta = 1.0;
    int      nrow;
    int      fsupc, nsupr, nsupc, luptr, istart, irow;
    int      i, k, iptr, jcol;
    float    *work;
    flops_t  solve_ops;

    /* Test the input parameters */
    *info = 0;
    if ( !lsame_(uplo, "L") && !lsame_(uplo, "U") ) *info = -1;
    else if ( !lsame_(trans, "N") && !lsame_(trans, "T") &&
              !lsame_(trans, "C") ) *info = -2;
    else if ( !lsame_(diag, "U") && !lsame_(diag, "N") ) *info = -3;
    else if ( L->nrow != L->ncol || L->nrow < 0 ) *info = -4;
    else if ( U->nrow != U->ncol || U->nrow < 0 ) *info = -5;
    if ( *info ) {
        i = -(*info);
        xerbla_("sp_strsv", &i);
        return 0;
    }

    /* Quick return for an empty system. This is done before work[] is
       allocated so that the early exit does not leak it. */
    if ( lsame_(uplo, "L") ) {
        if ( L->nrow == 0 ) return 0;
    } else {
        if ( U->nrow == 0 ) return 0;
    }

    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    if ( !(work = floatCalloc(L->nrow)) )
        ABORT("Malloc fails for work in sp_strsv().");

    if ( lsame_(trans, "N") ) {  /* Form x := inv(A)*x. */

        if ( lsame_(uplo, "L") ) {
            /* Form x := inv(L)*x */
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);
                nrow = nsupr - nsupc;

                solve_ops += nsupc * (nsupc - 1);
                solve_ops += 2 * nrow * nsupc;

                if ( nsupc == 1 ) {
                    for (iptr = istart+1; iptr < L_SUB_START(fsupc+1); ++iptr) {
                        irow = L_SUB(iptr);
                        ++luptr;
                        x[irow] -= x[fsupc] * Lval[luptr];
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    STRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);

                    SGEMV(ftcs2, &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                          &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#else
                    strsv_("L", "N", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);

                    sgemv_("N", &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                           &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#endif
#else
                    slsolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc]);

                    smatvec ( nsupr, nsupr-nsupc, nsupc, &Lval[luptr+nsupc],
                              &x[fsupc], &work[0] );
#endif

                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; ++i, ++iptr) {
                        irow = L_SUB(iptr);
                        x[irow] -= work[i];  /* Scatter */
                        work[i] = 0.0;       /* ready for the next supernode */
                    }
                }
            } /* for k ... */

        } else {
            /* Form x := inv(U)*x */
            for (k = Lstore->nsuper; k >= 0; k--) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += nsupc * (nsupc + 1);

                if ( nsupc == 1 ) {
                    x[fsupc] /= Lval[luptr];
                    for (i = U_NZ_START(fsupc); i < U_NZ_START(fsupc+1); ++i) {
                        irow = U_SUB(i);
                        x[irow] -= x[fsupc] * Uval[i];
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    STRSV(ftcs3, ftcs2, ftcs2, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    strsv_("U", "N", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
#else
                    susolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc] );
#endif

                    for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                        solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                        for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1);
                             i++) {
                            irow = U_SUB(i);
                            x[irow] -= x[jcol] * Uval[i];
                        }
                    }
                }
            } /* for k ... */

        }
    } else { /* Form x := inv(A')*x */

        if ( lsame_(uplo, "L") ) {
            /* Form x := inv(L')*x */
            for (k = Lstore->nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 2 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc;
                         i < L_NZ_START(jcol+1); i++) {
                        irow = L_SUB(iptr);
                        x[jcol] -= x[irow] * Lval[i];
                        iptr++;
                    }
                }

                if ( nsupc > 1 ) {
                    solve_ops += nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    STRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    strsv_("L", "T", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := inv(U')*x */
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++) {
                        irow = U_SUB(i);
                        x[jcol] -= x[irow] * Uval[i];
                    }
                }

                solve_ops += nsupc * (nsupc + 1);

                if ( nsupc == 1 ) {
                    x[fsupc] /= Lval[luptr];
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    STRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#else
                    strsv_("U", "T", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    }

    stat->ops[SOLVE] += solve_ops;
    SUPERLU_FREE(work);
    return 0;
}
int
sp_ctrsv(char *uplo, char *trans, char *diag, SuperMatrix *L,
         SuperMatrix *U, complex *x, int *info)
{
/*
 * Purpose
 * =======
 *
 *   sp_ctrsv() solves one of the systems of equations
 *       A*x = b,   A'*x = b,   or   A^H*x = b,
 *   where b and x are n element vectors and A is the sparse lower (L) or
 *   upper (U) triangular factor of an LU factorization.
 *   No test for singularity or near-singularity is included in this
 *   routine. Such tests must be performed before calling this routine.
 *
 * Parameters
 * ==========
 *
 *   uplo   - (input) char*
 *            On entry, uplo specifies whether the matrix is an upper or
 *             lower triangular matrix as follows:
 *                uplo = 'U' or 'u'   A is an upper triangular matrix.
 *                uplo = 'L' or 'l'   A is a lower triangular matrix.
 *
 *   trans  - (input) char*
 *             On entry, trans specifies the equations to be solved as
 *             follows:
 *                trans = 'N' or 'n'   A*x = b.
 *                trans = 'T' or 't'   A'*x = b.
 *                trans = 'C' or 'c'   A^H*x = b.
 *
 *   diag   - (input) char*
 *             Must be 'U' or 'N'. It is only validated: the L solves
 *             below always treat the diagonal as unit and the U solves
 *             always divide by the stored diagonal.
 *
 *   L       - (input) SuperMatrix*
 *             The factor L from the factorization Pr*A*Pc=L*U. Use
 *             compressed row subscripts storage for supernodes,
 *             i.e., L has types: Stype = SCP, Dtype = _C, Mtype = TRLU.
 *             The diagonal blocks used by the U solves are read from L's
 *             value array as well.
 *
 *   U       - (input) SuperMatrix*
 *              The factor U from the factorization Pr*A*Pc=L*U.
 *              U has types: Stype = NCP, Dtype = _C, Mtype = TRU.
 *
 *   x       - (input/output) complex*
 *             Before entry, x must contain the n element right-hand side
 *             vector b. On exit, x is overwritten with the solution.
 *
 *   info    - (output) int*
 *             If *info = -i, the i-th argument had an illegal value.
 *
 *   Returns 0 in all cases; errors are reported through *info.
 *
 *   Column and supernode extents are always taken from the explicit
 *   end macros (L_LAST_SUPC, L_SUB_END, L_NZ_END, U_NZ_END), in every
 *   branch.
 *
 *   NOTE(review): the declarations of ftcs1..3 are guarded by
 *   MACH==CRAY_PVP while the transposed branches test _CRAY; this is kept
 *   as found -- confirm both are defined together on Cray builds.
 */
#if ( MACH==CRAY_PVP )
    _fcd ftcs1, ftcs2, ftcs3;
#endif
    SCPformat *Lstore;
    NCPformat *Ustore;
    complex   *Lval, *Uval;
    int incx = 1, incy = 1;
    complex temp;  /* conjugated matrix entry (trans = 'C' only) */
    complex prod;  /* scratch for a single complex product */
    complex alpha = {1.0, 0.0}, beta = {1.0, 0.0};
    complex comp_zero = {0.0, 0.0};  /* must stay zero: used to reset work[] */
    register int fsupc, luptr, istart, irow, k, iptr, jcol, nsuper;
    int          nsupr, nsupc, nrow, i;
    complex *work;
    flops_t solve_ops;  /* flop estimate; not reported (no stat argument) */

    /* Test the input parameters */
    *info = 0;
    if ( !lsame_(uplo, "L") && !lsame_(uplo, "U") ) *info = -1;
    else if ( !lsame_(trans, "N") && !lsame_(trans, "T") &&
              !lsame_(trans, "C") ) *info = -2;
    else if ( !lsame_(diag, "U") && !lsame_(diag, "N") ) *info = -3;
    else if ( L->nrow != L->ncol || L->nrow < 0 ) *info = -4;
    else if ( U->nrow != U->ncol || U->nrow < 0 ) *info = -5;
    if ( *info ) {
        i = -(*info);
        xerbla_("sp_ctrsv", &i);
        return 0;
    }

    /* Quick return for an empty system. This is done before work[] is
       allocated so that the early exit does not leak it. */
    if ( lsame_(uplo, "L") ) {
        if ( L->nrow == 0 ) return 0;
    } else {
        if ( U->nrow == 0 ) return 0;
    }

    Lstore = (SCPformat*) L->Store;
    Lval = (complex*) Lstore->nzval;
    Ustore = (NCPformat*) U->Store;
    Uval = (complex*) Ustore->nzval;
    nsuper = Lstore->nsuper;
    solve_ops = 0;

    if ( !(work = complexCalloc(L->nrow)) )
        SUPERLU_ABORT("Malloc fails for work in sp_ctrsv().");

    if ( lsame_(trans, "N") ) {  /* Form x := inv(A)*x. */

        if ( lsame_(uplo, "L") ) {
            /* Form x := inv(L)*x */
            for (k = 0; k <= nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_END(fsupc) - istart;
                nsupc = L_LAST_SUPC(k) - fsupc;
                luptr = L_NZ_START(fsupc);
                nrow = nsupr - nsupc;

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc - 1) + 10 * nsupc;
                solve_ops += 8 * nrow * nsupc;

                if ( nsupc == 1 ) {
                    for (iptr = istart+1; iptr < L_SUB_END(fsupc); ++iptr) {
                        irow = L_SUB(iptr);
                        ++luptr;
                        cc_mult(&prod, &x[fsupc], &Lval[luptr]);
                        c_sub(&x[irow], &x[irow], &prod);
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#if ( MACH==CRAY_PVP )
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("N", strlen("N"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);

                    CGEMV(ftcs2, &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                          &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#else
                    ctrsv_("L", "N", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);

                    cgemv_("N", &nrow, &nsupc, &alpha, &Lval[luptr+nsupc],
                           &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#endif
#else
                    clsolve (nsupr, nsupc, &Lval[luptr], &x[fsupc]);

                    cmatvec (nsupr, nsupr-nsupc, nsupc, &Lval[luptr+nsupc],
                             &x[fsupc], &work[0] );
#endif

                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; ++i, ++iptr) {
                        irow = L_SUB(iptr);
                        c_sub(&x[irow], &x[irow], &work[i]); /* Scatter */
                        /* The update above is accumulated into work[]
                           (beta = 1), so it must be cleared for the next
                           supernode. */
                        work[i] = comp_zero;
                    }
                }
            } /* for k ... */

        } else {
            /* Form x := inv(U)*x */
            for (k = nsuper; k >= 0; k--) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_END(fsupc) - L_SUB_START(fsupc);
                nsupc = L_LAST_SUPC(k) - fsupc;
                luptr = L_NZ_START(fsupc);

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
                    for (i = U_NZ_START(fsupc); i < U_NZ_END(fsupc); ++i) {
                        irow = U_SUB(i);
                        cc_mult(&prod, &x[fsupc], &Uval[i]);
                        c_sub(&x[irow], &x[irow], &prod);
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#if ( MACH==CRAY_PVP )
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("N", strlen("N"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    ctrsv_("U", "N", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
#else
                    cusolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc] );
#endif

                    for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
                        solve_ops += 8*(U_NZ_END(jcol) - U_NZ_START(jcol));
                        for (i = U_NZ_START(jcol); i < U_NZ_END(jcol); i++) {
                            irow = U_SUB(i);
                            cc_mult(&prod, &x[jcol], &Uval[i]);
                            c_sub(&x[irow], &x[irow], &prod);
                        }
                    }
                }
            } /* for k ... */

        }
    } else if ( lsame_(trans, "T") ) { /* Form x := inv(A')*x */

        if ( lsame_(uplo, "L") ) {
            /* Form x := inv(L')*x */
            for (k = nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_END(fsupc) - istart;
                nsupc = L_LAST_SUPC(k) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 8 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_LAST_SUPC(k); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc;
                         i < L_NZ_END(jcol); i++) {
                        irow = L_SUB(iptr);
                        cc_mult(&prod, &x[irow], &Lval[i]);
                        c_sub(&x[jcol], &x[jcol], &prod);
                        iptr++;
                    }
                }

                if ( nsupc > 1 ) {
                    solve_ops += 4 * nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    ctrsv_("L", "T", "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := inv(U')*x */
            for (k = 0; k <= nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_END(fsupc) - L_SUB_START(fsupc);
                nsupc = L_LAST_SUPC(k) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
                    /* Count exactly the entries the loop below visits. */
                    solve_ops += 8*(U_NZ_END(jcol) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_END(jcol); i++) {
                        irow = U_SUB(i);
                        cc_mult(&prod, &x[irow], &Uval[i]);
                        c_sub(&x[jcol], &x[jcol], &prod);
                    }
                }

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    CTRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#else
                    ctrsv_("U", "T", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    } else { /* Form x := conj(inv(A'))*x */

        if ( lsame_(uplo, "L") ) {
            /* Form x := conj(inv(L'))*x */
            for (k = nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_END(fsupc) - istart;
                nsupc = L_LAST_SUPC(k) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 8 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_LAST_SUPC(k); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc;
                         i < L_NZ_END(jcol); i++) {
                        irow = L_SUB(iptr);
                        cc_conj(&temp, &Lval[i]);
                        cc_mult(&prod, &x[irow], &temp);
                        c_sub(&x[jcol], &x[jcol], &prod);
                        iptr++;
                    }
                }

                if ( nsupc > 1 ) {
                    solve_ops += 4 * nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd(trans, strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                          &x[fsupc], &incx);
#else
                    ctrsv_("L", trans, "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := conj(inv(U'))*x */
            for (k = 0; k <= nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_END(fsupc) - L_SUB_START(fsupc);
                nsupc = L_LAST_SUPC(k) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
                    solve_ops += 8*(U_NZ_END(jcol) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_END(jcol); i++) {
                        irow = U_SUB(i);
                        cc_conj(&temp, &Uval[i]);
                        cc_mult(&prod, &x[irow], &temp);
                        c_sub(&x[jcol], &x[jcol], &prod);
                    }
                }

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    cc_conj(&temp, &Lval[luptr]);
                    c_div(&x[fsupc], &x[fsupc], &temp);
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd(trans, strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    CTRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#else
                    ctrsv_("U", trans, "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    }

    SUPERLU_FREE(work);
    return 0;
}
void
dgstrsL(char *trans, SuperMatrix *L, int *perm_r, SuperMatrix *B, int *info)
{
/*
 * Purpose
 * =======
 *
 * DGSTRSL only performs the L-solve using the LU factorization computed
 * by DGSTRF: the right-hand sides are first permuted to Pr*B and then the
 * unit lower triangular system L*Y = Pr*B is solved, supernode by
 * supernode.  B is overwritten with Y.
 *
 * Only the non-transposed case (trans = 'N') is implemented.
 *
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
 *
 * Arguments
 * =========
 *
 * trans   (input) char*
 *         Specifies the form of the system of equations:
 *          = 'N':  A * X = B  (No transpose)
 *          = 'T':  A'* X = B  (Transpose)            -- not implemented
 *          = 'C':  A**H * X = B  (Conjugate transpose) -- not implemented
 *         For 'T' and 'C' a diagnostic is printed and the routine returns
 *         with info = -1 without touching B.
 *
 * L       (input) SuperMatrix*
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
 *         dgstrf(). Use compressed row subscripts storage for supernodes,
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
 *
 * perm_r  (input) int*, dimension (L->nrow)
 *         Row permutation vector, which defines the permutation matrix Pr;
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
 *
 * B       (input/output) SuperMatrix*
 *         B has types: Stype = SLU_DN, Dtype = SLU_D, Mtype = SLU_GE.
 *         On entry, the right hand side matrix.
 *         On exit, the result of the L-solve if info = 0.
 *
 * info    (output) int*
 *         = 0: successful exit
 *         < 0: if info = -i, the i-th argument had an illegal value
 *              (info = -1 is also returned for the unimplemented
 *              trans = 'T' / 'C' cases)
 *
 * Side effect: on success the operation count of the solve is stored in
 * SuperLUStat.ops[SOLVE].
 *
 */
#ifdef _CRAY
    _fcd ftcs1, ftcs2, ftcs3;
#endif
    double   alpha = 1.0, beta = 1.0;
    DNformat *Bstore;
    double   *Bmat;
    SCformat *Lstore;
    double   *Lval;
    int      nrow, notran;
    int      fsupc, nsupr, nsupc, luptr, istart, irow;
    int      i, j, k, iptr, n, ldb, nrhs;
    double   *work, *work_col, *rhs_work, *soln;
    flops_t  solve_ops;
    extern SuperLUStat_t SuperLUStat;
    void dprint_soln();

    /* Test input parameters ... */
    *info = 0;
    Bstore = B->Store;
    ldb = Bstore->lda;
    nrhs = B->ncol;
    notran = lsame_(trans, "N");
    if ( !notran && !lsame_(trans, "T") && !lsame_(trans, "C") )
        *info = -1;
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
              L->Stype != SLU_SC || L->Dtype != SLU_D ||
              L->Mtype != SLU_TRLU )
        *info = -2;
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
              B->Stype != SLU_DN || B->Dtype != SLU_D ||
              B->Mtype != SLU_GE )
        *info = -4;
    if ( *info ) {
        i = -(*info);
        xerbla_("dgstrsL", &i);
        return;
    }

    /*
     * The transposed solves are not implemented.  Report that through
     * info before any workspace is allocated, instead of terminating the
     * whole process (with a success status) from inside a library routine.
     */
    if ( !notran ) {
        printf("Transposed solve not implemented.\n");
        *info = -1;
        return;
    }

    n = L->nrow;
    work = doubleCalloc(n * nrhs);
    if ( !work ) ABORT("Malloc fails for local work[].");
    soln = doubleMalloc(n);
    if ( !soln ) ABORT("Malloc fails for local soln[].");

    Bmat = Bstore->nzval;
    Lstore = L->Store;   /* the L_* macros below read from Lstore */
    Lval = Lstore->nzval;
    solve_ops = 0;

    /* Permute right hand sides to form Pr*B */
    for (i = 0; i < nrhs; i++) {
        rhs_work = &Bmat[i*ldb];
        for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
        for (k = 0; k < n; k++) rhs_work[k] = soln[k];
    }

    /* Forward solve PLy=Pb. */
    for (k = 0; k <= Lstore->nsuper; k++) {
        fsupc = L_FST_SUPC(k);                    /* first column of supernode k */
        istart = L_SUB_START(fsupc);
        nsupr = L_SUB_START(fsupc+1) - istart;    /* row subscripts stored */
        nsupc = L_FST_SUPC(k+1) - fsupc;          /* columns in supernode */
        nrow = nsupr - nsupc;                     /* rows below the diagonal block */

        solve_ops += nsupc * (nsupc - 1) * nrhs;
        solve_ops += 2 * nrow * nsupc * nrhs;

        if ( nsupc == 1 ) {
            /* Single-column supernode: plain column update. */
            for (j = 0; j < nrhs; j++) {
                rhs_work = &Bmat[j*ldb];
                luptr = L_NZ_START(fsupc);
                for (iptr = istart+1; iptr < L_SUB_START(fsupc+1); iptr++) {
                    irow = L_SUB(iptr);
                    ++luptr;   /* skip the diagonal entry, then walk down */
                    rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
                }
            }
        } else {
            luptr = L_NZ_START(fsupc);
#ifdef USE_VENDOR_BLAS
            /* Triangular solve with the diagonal block, then a dense
               product with the rows below it, accumulated into work[]. */
#ifdef _CRAY
            ftcs1 = _cptofcd("L", strlen("L"));
            ftcs2 = _cptofcd("N", strlen("N"));
            ftcs3 = _cptofcd("U", strlen("U"));
            STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
                   &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

            SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
                   &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                   &beta, &work[0], &n );
#else
            dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
                   &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

            dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
                    &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                    &beta, &work[0], &n );
#endif
            for (j = 0; j < nrhs; j++) {
                rhs_work = &Bmat[j*ldb];
                work_col = &work[j*n];
                iptr = istart + nsupc;
                for (i = 0; i < nrow; i++) {
                    irow = L_SUB(iptr);
                    rhs_work[irow] -= work_col[i]; /* Scatter */
                    /* beta = 1 makes the product accumulate into work[],
                       so it must be zero again for the next supernode. */
                    work_col[i] = 0.0;
                    iptr++;
                }
            }
#else
            for (j = 0; j < nrhs; j++) {
                rhs_work = &Bmat[j*ldb];
                dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
                dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
                         &rhs_work[fsupc], &work[0] );

                iptr = istart + nsupc;
                for (i = 0; i < nrow; i++) {
                    irow = L_SUB(iptr);
                    rhs_work[irow] -= work[i];
                    work[i] = 0.0;
                    iptr++;
                }
            }
#endif
        } /* else ... */
    } /* for L-solve */

#ifdef DEBUG
    printf("After L-solve: y=\n");
    dprint_soln(n, nrhs, Bmat);
#endif

    SuperLUStat.ops[SOLVE] = solve_ops;

    SUPERLU_FREE(work);
    SUPERLU_FREE(soln);
}
/*! \brief L-solve only, using the factor L computed by dgstrf().
 *
 * <pre>
 * Purpose
 * =======
 *
 * dgstrsL only performs the L-solve using the LU factorization computed
 * by DGSTRF: the right-hand sides are permuted to Pr*B, then the unit
 * lower triangular system L*Y = Pr*B is solved one supernode at a time.
 * B is overwritten with Y.
 *
 * Only trans = 'N' is implemented.
 *
 * See supermatrix.h for the definition of 'SuperMatrix' structure.
 *
 * Arguments
 * =========
 *
 * trans   (input) char*
 *         Specifies the form of the system of equations:
 *          = 'N':  A * X = B  (No transpose)
 *          = 'T':  A'* X = B  (Transpose)              -- not implemented
 *          = 'C':  A**H * X = B  (Conjugate transpose) -- not implemented
 *         For 'T' and 'C' a diagnostic is printed and the routine returns
 *         with info = -1, leaving B unchanged.
 *
 * L       (input) SuperMatrix*
 *         The factor L from the factorization Pr*A*Pc=L*U as computed by
 *         dgstrf(). Use compressed row subscripts storage for supernodes,
 *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
 *
 * perm_r  (input) int*, dimension (L->nrow)
 *         Row permutation vector, which defines the permutation matrix Pr;
 *         perm_r[i] = j means row i of A is in position j in Pr*A.
 *
 * B       (input/output) SuperMatrix*
 *         B has types: Stype = SLU_DN, Dtype = SLU_D, Mtype = SLU_GE.
 *         On entry, the right hand side matrix.
 *         On exit, the result of the L-solve if info = 0.
 *
 * info    (output) int*
 *         = 0: successful exit
 *         < 0: if info = -i, the i-th argument had an illegal value
 *              (info = -1 is also returned when trans is 'T' or 'C',
 *              which are not implemented)
 *
 * On success the operation count is stored in SuperLUStat.ops[SOLVE].
 * </pre>
 */
void
dgstrsL(char *trans, SuperMatrix *L, int *perm_r, SuperMatrix *B, int *info)
{
#ifdef _CRAY
    _fcd ftcs1, ftcs2, ftcs3;
#endif
    double   alpha = 1.0, beta = 1.0;
    DNformat *Bstore;
    double   *Bmat;
    SCformat *Lstore;
    double   *Lval;
    int      nrow, notran;
    int      fsupc, nsupr, nsupc, luptr, istart, irow;
    int      i, j, k, iptr, n, ldb, nrhs;
    double   *work, *work_col, *rhs_work, *soln;
    flops_t  solve_ops;
    extern SuperLUStat_t SuperLUStat;
    void dprint_soln();

    /* Test input parameters ... */
    *info = 0;
    Bstore = B->Store;
    ldb = Bstore->lda;
    nrhs = B->ncol;
    notran = lsame_(trans, "N");
    if ( !notran && !lsame_(trans, "T") && !lsame_(trans, "C") ) {
        *info = -1;
    } else if ( L->nrow != L->ncol || L->nrow < 0 ||
                L->Stype != SLU_SC || L->Dtype != SLU_D ||
                L->Mtype != SLU_TRLU ) {
        *info = -2;
    } else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
                B->Stype != SLU_DN || B->Dtype != SLU_D ||
                B->Mtype != SLU_GE ) {
        *info = -4;
    }
    if ( *info ) {
        i = -(*info);
        xerbla_("dgstrsL", &i);
        return;
    }

    /*
     * 'T' and 'C' pass validation but have no implementation.  Fail through
     * info here, before allocating workspace, rather than calling exit(0)
     * from library code (which ended the process with a success status).
     */
    if ( !notran ) {
        printf("Transposed solve not implemented.\n");
        *info = -1;
        return;
    }

    n = L->nrow;
    work = doubleCalloc(n * nrhs);
    if ( !work ) ABORT("Malloc fails for local work[].");
    soln = doubleMalloc(n);
    if ( !soln ) ABORT("Malloc fails for local soln[].");

    Bmat = Bstore->nzval;
    Lstore = L->Store;   /* referenced by the L_* accessor macros */
    Lval = Lstore->nzval;
    solve_ops = 0;

    /* Permute right hand sides to form Pr*B */
    for (i = 0; i < nrhs; i++) {
        rhs_work = &Bmat[i*ldb];
        for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
        for (k = 0; k < n; k++) rhs_work[k] = soln[k];
    }

    /* Forward solve PLy=Pb. */
    for (k = 0; k <= Lstore->nsuper; k++) {
        fsupc = L_FST_SUPC(k);
        istart = L_SUB_START(fsupc);
        nsupr = L_SUB_START(fsupc+1) - istart;
        nsupc = L_FST_SUPC(k+1) - fsupc;
        nrow = nsupr - nsupc;   /* rows strictly below the diagonal block */

        solve_ops += nsupc * (nsupc - 1) * nrhs;
        solve_ops += 2 * nrow * nsupc * nrhs;

        if ( nsupc == 1 ) {
            /* One column only: update the rows below it directly. */
            for (j = 0; j < nrhs; j++) {
                rhs_work = &Bmat[j*ldb];
                luptr = L_NZ_START(fsupc);
                for (iptr = istart+1; iptr < L_SUB_START(fsupc+1); iptr++) {
                    irow = L_SUB(iptr);
                    ++luptr;   /* advance past the diagonal first */
                    rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
                }
            }
        } else {
            luptr = L_NZ_START(fsupc);
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
            ftcs1 = _cptofcd("L", strlen("L"));
            ftcs2 = _cptofcd("N", strlen("N"));
            ftcs3 = _cptofcd("U", strlen("U"));
            STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
                   &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

            SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
                   &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                   &beta, &work[0], &n );
#else
            /* Solve with the unit lower triangular diagonal block ... */
            dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
                   &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

            /* ... then multiply the off-diagonal rows into work[]. */
            dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
                    &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                    &beta, &work[0], &n );
#endif
            for (j = 0; j < nrhs; j++) {
                rhs_work = &Bmat[j*ldb];
                work_col = &work[j*n];
                iptr = istart + nsupc;
                for (i = 0; i < nrow; i++) {
                    irow = L_SUB(iptr);
                    rhs_work[irow] -= work_col[i]; /* Scatter */
                    /* Re-zero: with beta = 1 the next product adds to
                       whatever is left in work[]. */
                    work_col[i] = 0.0;
                    iptr++;
                }
            }
#else
            for (j = 0; j < nrhs; j++) {
                rhs_work = &Bmat[j*ldb];
                dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
                dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
                         &rhs_work[fsupc], &work[0] );

                iptr = istart + nsupc;
                for (i = 0; i < nrow; i++) {
                    irow = L_SUB(iptr);
                    rhs_work[irow] -= work[i];
                    work[i] = 0.0;
                    iptr++;
                }
            }
#endif
        } /* else ... */
    } /* for L-solve */

#ifdef DEBUG
    printf("After L-solve: y=\n");
    dprint_soln(n, nrhs, Bmat);
#endif

    SuperLUStat.ops[SOLVE] = solve_ops;

    SUPERLU_FREE(work);
    SUPERLU_FREE(soln);
}