コード例 #1
0
ファイル: cgstrs.c プロジェクト: DarkOfTheMoon/HONEI
/*
 * cgstrs -- solve a system with the single-precision complex factors L and U,
 * overwriting the dense right-hand-side matrix B with the solution.
 *
 *   trans   NOTRANS: the forward solve with L and the back solve with U are
 *           done inline, one supernode at a time.
 *           TRANS / CONJ: each right-hand-side column is handed to sp_ctrsv()
 *           (first with U, then with L).
 *   L       supernodal lower factor (Stype SLU_SC, Dtype SLU_C, Mtype SLU_TRLU).
 *   U       upper factor (Stype SLU_NC, Dtype SLU_C, Mtype SLU_TRU).
 *   perm_c  column permutation, applied to B after (NOTRANS) or before
 *           (TRANS/CONJ) the triangular solves.
 *   perm_r  row permutation, applied to B before (NOTRANS) or after
 *           (TRANS/CONJ) the triangular solves.
 *   B       dense matrix (Stype SLU_DN, Dtype SLU_C, Mtype SLU_GE) with
 *           B->ncol right-hand sides and leading dimension Bstore->lda;
 *           modified in place.
 *   stat    stat->ops[SOLVE] is overwritten with the flop estimate for
 *           NOTRANS; for TRANS/CONJ it is reset to 0 here and then passed to
 *           sp_ctrsv().
 *   info    0 on success; -1, -2, -3 or -6 when trans, L, U or B respectively
 *           fails the checks below (xerbla_ is called and nothing is solved).
 *
 * NOTE(review): the L_* and U_* accessors are macros defined outside this
 * excerpt; they presumably read the locals Lstore / Ustore by name -- confirm
 * before renaming any local.
 */
void
cgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
        int *perm_c, int *perm_r, SuperMatrix *B,
        SuperLUStat_t *stat, int *info)
{

#ifdef _CRAY
    _fcd ftcs1, ftcs2, ftcs3, ftcs4;
#endif
    int      incx = 1, incy = 1;
#ifdef USE_VENDOR_BLAS
    complex   alpha = {1.0, 0.0}, beta = {1.0, 0.0};
    complex   *work_col;
#endif
    complex   temp_comp;
    DNformat *Bstore;
    complex   *Bmat;
    SCformat *Lstore;
    NCformat *Ustore;
    complex   *Lval, *Uval;
    int      fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
    int      i, j, k, iptr, jcol, n, ldb, nrhs;
    complex   *work, *rhs_work, *soln;
    flops_t  solve_ops;
    void cprint_soln();

    /* Test input parameters; the first failing check decides *info. */
    *info = 0;
    Bstore = B->Store;
    ldb = Bstore->lda;
    nrhs = B->ncol;
    if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
    else if ( L->nrow != L->ncol || L->nrow < 0 ||
              L->Stype != SLU_SC || L->Dtype != SLU_C || L->Mtype != SLU_TRLU )
        *info = -2;
    else if ( U->nrow != U->ncol || U->nrow < 0 ||
              U->Stype != SLU_NC || U->Dtype != SLU_C || U->Mtype != SLU_TRU )
        *info = -3;
    else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
              B->Stype != SLU_DN || B->Dtype != SLU_C || B->Mtype != SLU_GE )
        *info = -6;
    if ( *info ) {
        /* xerbla_ receives the (positive) position of the bad argument. */
        i = -(*info);
        xerbla_("cgstrs", &i);
        return;
    }

    /*
     * work: zero-initialised scratch of n*nrhs entries that receives the
     *       off-diagonal update of each supernode; it is re-zeroed after
     *       every scatter so it can be accumulated into again.
     * soln: one column of scratch used while permuting B.
     */
    n = L->nrow;
    work = complexCalloc(n * nrhs);
    if ( !work ) ABORT("Malloc fails for local work[].");
    soln = complexMalloc(n);
    if ( !soln ) ABORT("Malloc fails for local soln[].");

    Bmat = Bstore->nzval;
    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    if ( trans == NOTRANS ) {
        /* Permute right hand sides to form Pr*B (column by column, via soln). */
        for (i = 0; i < nrhs; i++) {
            rhs_work = &Bmat[i*ldb];
            for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

        /* Forward solve PLy=Pb, visiting supernodes in increasing order. */
        for (k = 0; k <= Lstore->nsuper; k++) {
            fsupc = L_FST_SUPC(k);                  /* first column of supernode k */
            istart = L_SUB_START(fsupc);
            nsupr = L_SUB_START(fsupc+1) - istart;  /* stored rows in the supernode */
            nsupc = L_FST_SUPC(k+1) - fsupc;        /* columns in the supernode */
            nrow = nsupr - nsupc;                   /* rows below the diagonal block */

            solve_ops += 4 * nsupc * (nsupc - 1) * nrhs;
            solve_ops += 8 * nrow * nsupc * nrhs;

            if ( nsupc == 1 ) {
                /*
                 * Single-column supernode: no diagonal-block solve is done;
                 * only subtract rhs[fsupc] * L(irow, fsupc) from the rows
                 * below the diagonal entry.
                 */
                for (j = 0; j < nrhs; j++) {
                    rhs_work = &Bmat[j*ldb];
                    luptr = L_NZ_START(fsupc);
                    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
                        irow = L_SUB(iptr);
                        ++luptr;
                        cc_mult(&temp_comp, &rhs_work[fsupc], &Lval[luptr]);
                        c_sub(&rhs_work[irow], &rhs_work[irow], &temp_comp);
                    }
                }
            } else {
                luptr = L_NZ_START(fsupc);
#ifdef USE_VENDOR_BLAS
                /*
                 * Vendor BLAS path: one triangular solve and one matrix
                 * product cover all right-hand sides at once; the product is
                 * accumulated into work (beta is 1) with leading dimension n.
                 */
#ifdef _CRAY
                ftcs1 = _cptofcd("L", strlen("L"));
                ftcs2 = _cptofcd("N", strlen("N"));
                ftcs3 = _cptofcd("U", strlen("U"));
                CTRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

                CGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                        &beta, &work[0], &n );
#else
                ctrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);

                cgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
                        &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
                        &beta, &work[0], &n );
#endif
                /* Scatter the update into the rows listed for this supernode. */
                for (j = 0; j < nrhs; j++) {
                    rhs_work = &Bmat[j*ldb];
                    work_col = &work[j*n];
                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; i++) {
                        irow = L_SUB(iptr);
                        c_sub(&rhs_work[irow], &rhs_work[irow], &work_col[i]);
                        /* Reset so the next supernode starts from zero. */
                        work_col[i].r = 0.0;
                        work_col[i].i = 0.0;
                        iptr++;
                    }
                }
#else
                /*
                 * Portable path: handle one right-hand side at a time with
                 * the clsolve / cmatvec kernels, reusing work[0..nrow).
                 */
                for (j = 0; j < nrhs; j++) {
                    rhs_work = &Bmat[j*ldb];
                    clsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
                    cmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
                            &rhs_work[fsupc], &work[0] );

                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; i++) {
                        irow = L_SUB(iptr);
                        c_sub(&rhs_work[irow], &rhs_work[irow], &work[i]);
                        /* Reset so the next use of work starts from zero. */
                        work[i].r = 0.;
                        work[i].i = 0.;
                        iptr++;
                    }
                }
#endif
            } /* else ... */
        } /* for L-solve */

#ifdef DEBUG
        printf("After L-solve: y=\n");
        cprint_soln(n, nrhs, Bmat);
#endif

        /*
         * Back solve Ux=y, visiting supernodes in decreasing order.
         * The diagonal block of U is read from Lval (same luptr / nsupr as
         * the L supernode); the entries above it come from Uval.
         */
        for (k = Lstore->nsuper; k >= 0; k--) {
            fsupc = L_FST_SUPC(k);
            istart = L_SUB_START(fsupc);
            nsupr = L_SUB_START(fsupc+1) - istart;
            nsupc = L_FST_SUPC(k+1) - fsupc;
            luptr = L_NZ_START(fsupc);

            solve_ops += 4 * nsupc * (nsupc + 1) * nrhs;

            if ( nsupc == 1 ) {
                /* Single column: divide by the diagonal entry for every rhs. */
                rhs_work = &Bmat[0];
                for (j = 0; j < nrhs; j++) {
                    c_div(&rhs_work[fsupc], &rhs_work[fsupc], &Lval[luptr]);
                    rhs_work += ldb;
                }
            } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                ftcs1 = _cptofcd("L", strlen("L"));
                ftcs2 = _cptofcd("U", strlen("U"));
                ftcs3 = _cptofcd("N", strlen("N"));
                CTRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#else
                ctrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
                       &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
#endif
#else
                for (j = 0; j < nrhs; j++)
                    cusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
#endif
            }

            /*
             * Propagate the freshly solved entries of this supernode to the
             * rows referenced by the U columns fsupc .. fsupc+nsupc-1.
             */
            for (j = 0; j < nrhs; ++j) {
                rhs_work = &Bmat[j*ldb];
                for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
                    solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
                        irow = U_SUB(i);
                        cc_mult(&temp_comp, &rhs_work[jcol], &Uval[i]);
                        c_sub(&rhs_work[irow], &rhs_work[irow], &temp_comp);
                    }
                }
            }

        } /* for U-solve */

#ifdef DEBUG
        printf("After U-solve: x=\n");
        cprint_soln(n, nrhs, Bmat);
#endif

        /* Compute the final solution X := Pc*X. */
        for (i = 0; i < nrhs; i++) {
            rhs_work = &Bmat[i*ldb];
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

        /* Overwrites (does not accumulate into) the SOLVE counter. */
        stat->ops[SOLVE] = solve_ops;

    } else { /* Solve A'*X=B or CONJ(A)*X=B */
        /* Permute right hand sides to form Pc'*B. */
        for (i = 0; i < nrhs; i++) {
            rhs_work = &Bmat[i*ldb];
            for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

        /* Reset the counter before handing stat to sp_ctrsv() below. */
        stat->ops[SOLVE] = 0;
        if (trans == TRANS) {
            for (k = 0; k < nrhs; ++k) {
                /* Multiply by inv(U'). */
                sp_ctrsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);

                /* Multiply by inv(L'). */
                sp_ctrsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
            }
         } else { /* trans == CONJ */
            for (k = 0; k < nrhs; ++k) {
                /* Multiply by conj(inv(U')). */
                sp_ctrsv("U", "C", "N", L, U, &Bmat[k*ldb], stat, info);

                /* Multiply by conj(inv(L')). */
                sp_ctrsv("L", "C", "U", L, U, &Bmat[k*ldb], stat, info);
            }
         }
        /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
        for (i = 0; i < nrhs; i++) {
            rhs_work = &Bmat[i*ldb];
            for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
            for (k = 0; k < n; k++) rhs_work[k] = soln[k];
        }

    }

    SUPERLU_FREE(work);
    SUPERLU_FREE(soln);
}
コード例 #2
0
ファイル: csp_blas2.c プロジェクト: BranYang/scipy
/*! \brief Solves one of the systems of equations A*x = b,  A'*x = b,  or  A^H*x = b
 * 
 * <pre>
 *   Purpose
 *   =======
 *
 *   sp_ctrsv() solves one of the systems of equations   
 *       A*x = b,   A'*x = b,   or   A^H*x = b,
 *   where b and x are n element vectors and A is the sparse triangular
 *   factor L or U selected by uplo.
 *   No test for singularity or near-singularity is included in this   
 *   routine. Such tests must be performed before calling this routine.   
 *
 *   Parameters   
 *   ==========   
 *
 *   uplo   - (input) char*
 *            On entry, uplo specifies which factor is used:   
 *                uplo = "U"   solve with the upper triangular factor U.   
 *                uplo = "L"   solve with the lower triangular factor L.   
 *            Only the upper-case letter is recognised; the first character
 *            is compared with strncmp().
 *
 *   trans  - (input) char*
 *             On entry, trans specifies the equations to be solved as   
 *             follows (upper-case only):   
 *                trans = "N"   A*x = b.   
 *                trans = "T"   A'*x = b.
 *                trans = "C"   A^H*x = b.   
 *
 *   diag   - (input) char*
 *             Must be "U" or "N" (upper-case only). The value is validated
 *             but otherwise unused: the dense diagonal-block solves always
 *             treat L as unit triangular and U as non-unit triangular.
 *	     
 *   L       - (input) SuperMatrix*
 *	       The factor L from the factorization Pr*A*Pc=L*U. Use
 *             compressed row subscripts storage for supernodes,
 *             i.e., L has types: Stype = SC, Dtype = SLU_C, Mtype = TRLU.
 *
 *   U       - (input) SuperMatrix*
 *	        The factor U from the factorization Pr*A*Pc=L*U.
 *	        U has types: Stype = NC, Dtype = SLU_C, Mtype = TRU.
 *    
 *   x       - (input/output) complex*
 *             Before entry, the array x must contain the n   
 *             element right-hand side vector b. On exit, x is overwritten 
 *             with the solution vector.
 *
 *   stat    - (input/output) SuperLUStat_t*
 *             stat->ops[SOLVE] is incremented by the flop estimate of the
 *             solve. It is left untouched on an argument error or on the
 *             quick return for an empty matrix.
 *
 *   info    - (output) int*
 *             If *info = -i, the i-th argument had an illegal value.
 *
 *   Return value: always 0; errors are reported through *info.
 * </pre>
 */
int
sp_ctrsv(char *uplo, char *trans, char *diag, SuperMatrix *L, 
         SuperMatrix *U, complex *x, SuperLUStat_t *stat, int *info)
{
#ifdef _CRAY
    _fcd ftcs1 = _cptofcd("L", strlen("L")),
         ftcs2 = _cptofcd("N", strlen("N")),
         ftcs3 = _cptofcd("U", strlen("U"));
#endif
    SCformat *Lstore;
    NCformat *Ustore;
    complex   *Lval, *Uval;
    int incx = 1, incy = 1;
    complex temp;   /* conjugated matrix entry (trans = "C" only) */
    complex prod;   /* scratch for one x-entry times one matrix entry */
    complex alpha = {1.0, 0.0}, beta = {1.0, 0.0};
    /*
     * Kept const and never used as an output operand: it is the value that
     * work[] is reset to after each scatter, so it must really stay zero.
     */
    const complex comp_zero = {0.0, 0.0};
    int nrow;
    int fsupc, nsupr, nsupc, luptr, istart, irow;
    int i, k, iptr, jcol;
    complex *work;
    flops_t solve_ops;

    /* Test the input parameters */
    *info = 0;
    if ( strncmp(uplo,"L", 1)!=0 && strncmp(uplo, "U", 1)!=0 ) *info = -1;
    else if ( strncmp(trans, "N", 1)!=0 && strncmp(trans, "T", 1)!=0 && 
              strncmp(trans, "C", 1)!=0) *info = -2;
    else if ( strncmp(diag, "U", 1)!=0 && strncmp(diag, "N", 1)!=0 )
         *info = -3;
    else if ( L->nrow != L->ncol || L->nrow < 0 ) *info = -4;
    else if ( U->nrow != U->ncol || U->nrow < 0 ) *info = -5;
    if ( *info ) {
        i = -(*info);
        input_error("sp_ctrsv", &i);
        return 0;
    }

    /*
     * Quick return for an empty factor. This is decided before work[] is
     * allocated so that the early exit cannot leak it.
     */
    if ( strncmp(uplo, "L", 1)==0 ) {
        if ( L->nrow == 0 ) return 0;
    } else {
        if ( U->nrow == 0 ) return 0;
    }

    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    /* Zero-initialised accumulator for the off-diagonal supernode update. */
    if ( !(work = complexCalloc(L->nrow)) )
        ABORT("Malloc fails for work in sp_ctrsv().");
    
    if ( strncmp(trans, "N", 1)==0 ) {	/* Form x := inv(A)*x. */
        
        if ( strncmp(uplo, "L", 1)==0 ) {
            /* Form x := inv(L)*x */
            
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);
                nrow = nsupr - nsupc;

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc - 1) + 10 * nsupc;
                solve_ops += 8 * nrow * nsupc;

                if ( nsupc == 1 ) {
                    /* Single column: update the rows below the diagonal. */
                    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); ++iptr) {
                        irow = L_SUB(iptr);
                        ++luptr;
                        cc_mult(&prod, &x[fsupc], &Lval[luptr]);
                        c_sub(&x[irow], &x[irow], &prod);
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                        &x[fsupc], &incx);
                
                    CGEMV(ftcs2, &nrow, &nsupc, &alpha, &Lval[luptr+nsupc], 
                        &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#else
                    ctrsv_("L", "N", "U", &nsupc, &Lval[luptr], &nsupr,
                        &x[fsupc], &incx);
                
                    cgemv_("N", &nrow, &nsupc, &alpha, &Lval[luptr+nsupc], 
                        &nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#endif
#else
                    clsolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc]);
                
                    cmatvec ( nsupr, nsupr-nsupc, nsupc, &Lval[luptr+nsupc],
                             &x[fsupc], &work[0] );
#endif		
                
                    iptr = istart + nsupc;
                    for (i = 0; i < nrow; ++i, ++iptr) {
                        irow = L_SUB(iptr);
                        c_sub(&x[irow], &x[irow], &work[i]); /* Scatter */
                        /* Reset: the update is accumulated into work. */
                        work[i] = comp_zero;
                    }
                }
            } /* for k ... */
            
        } else {
            /* Form x := inv(U)*x */
            
            for (k = Lstore->nsuper; k >= 0; k--) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);
                
                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
                    for (i = U_NZ_START(fsupc); i < U_NZ_START(fsupc+1); ++i) {
                        irow = U_SUB(i);
                        cc_mult(&prod, &x[fsupc], &Uval[i]);
                        c_sub(&x[irow], &x[irow], &prod);
                    }
                } else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
                    CTRSV(ftcs3, ftcs2, ftcs2, &nsupc, &Lval[luptr], &nsupr,
                       &x[fsupc], &incx);
#else
                    ctrsv_("U", "N", "N", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
#else		
                    cusolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc] );
#endif		

                    for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                        solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                        for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); 
                                i++) {
                            irow = U_SUB(i);
                            cc_mult(&prod, &x[jcol], &Uval[i]);
                            c_sub(&x[irow], &x[irow], &prod);
                        }
                    }
                }
            } /* for k ... */
            
        }
    } else if ( strncmp(trans, "T", 1)==0 ) { /* Form x := inv(A')*x */
        
        if ( strncmp(uplo, "L", 1)==0 ) {
            /* Form x := inv(L')*x */
            
            for (k = Lstore->nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 8 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc; 
                                i < L_NZ_START(jcol+1); i++) {
                        irow = L_SUB(iptr);
                        cc_mult(&prod, &x[irow], &Lval[i]);
                        c_sub(&x[jcol], &x[jcol], &prod);
                        iptr++;
                    }
                }
                
                if ( nsupc > 1 ) {
                    solve_ops += 4 * nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                        &x[fsupc], &incx);
#else
                    ctrsv_("L", "T", "U", &nsupc, &Lval[luptr], &nsupr,
                        &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := inv(U')*x */
            
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++) {
                        irow = U_SUB(i);
                        cc_mult(&prod, &x[irow], &Uval[i]);
                        c_sub(&x[jcol], &x[jcol], &prod);
                    }
                }

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;

                if ( nsupc == 1 ) {
                    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    CTRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                            &x[fsupc], &incx);
#else
                    ctrsv_("U", "T", "N", &nsupc, &Lval[luptr], &nsupr,
                            &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    } else { /* Form x := conj(inv(A'))*x */
        
        if ( strncmp(uplo, "L", 1)==0 ) {
            /* Form x := conj(inv(L'))*x */
            
            for (k = Lstore->nsuper; k >= 0; --k) {
                fsupc = L_FST_SUPC(k);
                istart = L_SUB_START(fsupc);
                nsupr = L_SUB_START(fsupc+1) - istart;
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                solve_ops += 8 * (nsupr - nsupc) * nsupc;

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    iptr = istart + nsupc;
                    for (i = L_NZ_START(jcol) + nsupc; 
                                i < L_NZ_START(jcol+1); i++) {
                        irow = L_SUB(iptr);
                        cc_conj(&temp, &Lval[i]);
                        cc_mult(&prod, &x[irow], &temp);
                        c_sub(&x[jcol], &x[jcol], &prod);
                        iptr++;
                    }
                }
                
                if ( nsupc > 1 ) {
                    solve_ops += 4 * nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd(trans, strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
                    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                        &x[fsupc], &incx);
#else
                    ctrsv_("L", trans, "U", &nsupc, &Lval[luptr], &nsupr,
                           &x[fsupc], &incx);
#endif
                }
            }
        } else {
            /* Form x := conj(inv(U'))*x */
            
            for (k = 0; k <= Lstore->nsuper; k++) {
                fsupc = L_FST_SUPC(k);
                nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
                nsupc = L_FST_SUPC(k+1) - fsupc;
                luptr = L_NZ_START(fsupc);

                for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
                    solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
                    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++) {
                        irow = U_SUB(i);
                        cc_conj(&temp, &Uval[i]);
                        cc_mult(&prod, &x[irow], &temp);
                        c_sub(&x[jcol], &x[jcol], &prod);
                    }
                }

                /* 1 c_div costs 10 flops */
                solve_ops += 4 * nsupc * (nsupc + 1) + 10 * nsupc;
 
                if ( nsupc == 1 ) {
                    cc_conj(&temp, &Lval[luptr]);
                    c_div(&x[fsupc], &x[fsupc], &temp);
                } else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd(trans, strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
                    CTRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
                            &x[fsupc], &incx);
#else
                    ctrsv_("U", trans, "N", &nsupc, &Lval[luptr], &nsupr,
                               &x[fsupc], &incx);
#endif
                }
            } /* for k ... */
        }
    }

    stat->ops[SOLVE] += solve_ops;
    SUPERLU_FREE(work);
    return 0;
}
コード例 #3
0
int
sp_ctrsv(char *uplo, char *trans, char *diag, SuperMatrix *L, 
	 SuperMatrix *U, complex *x, int *info)
{
/*
 *   Purpose
 *   =======
 *
 *   sp_ctrsv() solves one of the systems of equations   
 *       A*x = b,   or   A'*x = b,
 *   where b and x are n element vectors and A is the sparse triangular
 *   factor L or U selected by uplo.
 *   No test for singularity or near-singularity is included in this   
 *   routine. Such tests must be performed before calling this routine.   
 *
 *   Parameters   
 *   ==========   
 *
 *   uplo   - (input) char*
 *            On entry, uplo specifies which factor is used:   
 *                uplo = 'U'   solve with the upper triangular factor U.   
 *                uplo = 'L'   solve with the lower triangular factor L.   
 *            Letters are compared with lsame_(); NOTE(review): lsame_ is
 *            presumably case-insensitive as in LAPACK -- confirm.
 *
 *   trans  - (input) char*
 *             On entry, trans specifies the equations to be solved as   
 *             follows:   
 *                trans = 'N'   A*x = b.   
 *                trans = 'T'   A'*x = b.   
 *             'C' is NOT accepted by this version: the argument check
 *             below reports it as an illegal value (*info = -2).
 *
 *   diag   - (input) char*
 *             Must be 'U' or 'N'. The value is validated but otherwise
 *             unused: the dense diagonal-block solves always treat L as
 *             unit triangular and U as non-unit triangular.
 *	     
 *   L       - (input) SuperMatrix*
 *	       The factor L from the factorization Pr*A*Pc=L*U. Use
 *             compressed row subscripts storage for supernodes,
 *             i.e., L has types: Stype = SC, Dtype = SLU_C, Mtype = TRLU.
 *
 *   U       - (input) SuperMatrix*
 *	        The factor U from the factorization Pr*A*Pc=L*U.
 *	        U has types: Stype = NC, Dtype = SLU_C, Mtype = TRU.
 *    
 *   x       - (input/output) complex*
 *             Before entry, the array x must contain the n   
 *             element right-hand side vector b. On exit, x is overwritten 
 *             with the solution vector.
 *
 *   info    - (output) int*
 *             If *info = -i, the i-th argument had an illegal value.
 *
 *   Side effect: the global SuperLUStat.ops[SOLVE] is incremented by the
 *   flop estimate of the solve (not on an argument error or on the quick
 *   return for an empty matrix).
 *
 *   Return value: always 0; errors are reported through *info.
 *
 */
#ifdef _CRAY
    _fcd ftcs1 = _cptofcd("L", strlen("L")),
	 ftcs2 = _cptofcd("N", strlen("N")),
	 ftcs3 = _cptofcd("U", strlen("U"));
#endif
    SCformat *Lstore;
    NCformat *Ustore;
    complex   *Lval, *Uval;
    int incx = 1, incy = 1;
    complex alpha = {1.0, 0.0}, beta = {1.0, 0.0};
    complex prod;   /* scratch for one x-entry times one matrix entry */
    /*
     * Kept const and never used as an output operand: it is the value that
     * work[] is reset to after each scatter, so it must really stay zero.
     */
    const complex comp_zero = {0.0, 0.0};
    int nrow;
    int fsupc, nsupr, nsupc, luptr, istart, irow;
    int i, k, iptr, jcol;
    complex *work;
    flops_t solve_ops;
    extern SuperLUStat_t SuperLUStat;

    /* Test the input parameters */
    *info = 0;
    if ( !lsame_(uplo,"L") && !lsame_(uplo, "U") ) *info = -1;
    else if ( !lsame_(trans, "N") && !lsame_(trans, "T") ) *info = -2;
    else if ( !lsame_(diag, "U") && !lsame_(diag, "N") ) *info = -3;
    else if ( L->nrow != L->ncol || L->nrow < 0 ) *info = -4;
    else if ( U->nrow != U->ncol || U->nrow < 0 ) *info = -5;
    if ( *info ) {
	i = -(*info);
	xerbla_("sp_ctrsv", &i);
	return 0;
    }

    /*
     * Quick return for an empty factor. This is decided before work[] is
     * allocated so that the early exit cannot leak it.
     */
    if ( lsame_(uplo, "L") ) {
	if ( L->nrow == 0 ) return 0;
    } else {
	if ( U->nrow == 0 ) return 0;
    }

    Lstore = L->Store;
    Lval = Lstore->nzval;
    Ustore = U->Store;
    Uval = Ustore->nzval;
    solve_ops = 0;

    /* Zero-initialised accumulator for the off-diagonal supernode update. */
    if ( !(work = complexCalloc(L->nrow)) )
	ABORT("Malloc fails for work in sp_ctrsv().");
    
    if ( lsame_(trans, "N") ) {	/* Form x := inv(A)*x. */
	
	if ( lsame_(uplo, "L") ) {
	    /* Form x := inv(L)*x */
	    
	    for (k = 0; k <= Lstore->nsuper; k++) {
		fsupc = L_FST_SUPC(k);
		istart = L_SUB_START(fsupc);
		nsupr = L_SUB_START(fsupc+1) - istart;
		nsupc = L_FST_SUPC(k+1) - fsupc;
		luptr = L_NZ_START(fsupc);
		nrow = nsupr - nsupc;

	        solve_ops += 4 * nsupc * (nsupc - 1);
	        solve_ops += 8 * nrow * nsupc;

		if ( nsupc == 1 ) {
		    /* Single column: update the rows below the diagonal. */
		    for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); ++iptr) {
			irow = L_SUB(iptr);
			++luptr;
			cc_mult(&prod, &x[fsupc], &Lval[luptr]);
			c_sub(&x[irow], &x[irow], &prod);
		    }
		} else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
		    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
		       	&x[fsupc], &incx);
		
		    CGEMV(ftcs2, &nrow, &nsupc, &alpha, &Lval[luptr+nsupc], 
		       	&nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#else
		    ctrsv_("L", "N", "U", &nsupc, &Lval[luptr], &nsupr,
		       	&x[fsupc], &incx);
		
		    cgemv_("N", &nrow, &nsupc, &alpha, &Lval[luptr+nsupc], 
		       	&nsupr, &x[fsupc], &incx, &beta, &work[0], &incy);
#endif
#else
		    clsolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc]);
		
		    cmatvec ( nsupr, nsupr-nsupc, nsupc, &Lval[luptr+nsupc],
			&x[fsupc], &work[0] );
#endif		
		
		    iptr = istart + nsupc;
		    for (i = 0; i < nrow; ++i, ++iptr) {
			irow = L_SUB(iptr);
			c_sub(&x[irow], &x[irow], &work[i]); /* Scatter */
			/* Reset: the update is accumulated into work. */
			work[i] = comp_zero;
		    }
	 	}
	    } /* for k ... */
	    
	} else {
	    /* Form x := inv(U)*x */
	    
	    for (k = Lstore->nsuper; k >= 0; k--) {
	    	fsupc = L_FST_SUPC(k);
	    	nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
	    	nsupc = L_FST_SUPC(k+1) - fsupc;
	    	luptr = L_NZ_START(fsupc);
		
    	        solve_ops += 4 * nsupc * (nsupc + 1);

		if ( nsupc == 1 ) {
		    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
		    for (i = U_NZ_START(fsupc); i < U_NZ_START(fsupc+1); ++i) {
			irow = U_SUB(i);
			cc_mult(&prod, &x[fsupc], &Uval[i]);
			c_sub(&x[irow], &x[irow], &prod);
		    }
		} else {
#ifdef USE_VENDOR_BLAS
#ifdef _CRAY
		    CTRSV(ftcs3, ftcs2, ftcs2, &nsupc, &Lval[luptr], &nsupr,
		       &x[fsupc], &incx);
#else
		    ctrsv_("U", "N", "N", &nsupc, &Lval[luptr], &nsupr,
		       &x[fsupc], &incx);
#endif
#else		
		    cusolve ( nsupr, nsupc, &Lval[luptr], &x[fsupc] );
#endif		

		    for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
		        solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
		    	for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); 
				i++) {
			    irow = U_SUB(i);
			    cc_mult(&prod, &x[jcol], &Uval[i]);
			    c_sub(&x[irow], &x[irow], &prod);
		    	}
                    }
		}
	    } /* for k ... */
	    
	}
    } else { /* Form x := inv(A')*x */
	
	if ( lsame_(uplo, "L") ) {
	    /* Form x := inv(L')*x */
	    
	    for (k = Lstore->nsuper; k >= 0; --k) {
	    	fsupc = L_FST_SUPC(k);
	    	istart = L_SUB_START(fsupc);
	    	nsupr = L_SUB_START(fsupc+1) - istart;
	    	nsupc = L_FST_SUPC(k+1) - fsupc;
	    	luptr = L_NZ_START(fsupc);

		solve_ops += 8 * (nsupr - nsupc) * nsupc;

		for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
		    iptr = istart + nsupc;
		    for (i = L_NZ_START(jcol) + nsupc; 
				i < L_NZ_START(jcol+1); i++) {
			irow = L_SUB(iptr);
			cc_mult(&prod, &x[irow], &Lval[i]);
		    	c_sub(&x[jcol], &x[jcol], &prod);
			iptr++;
		    }
		}
		
		if ( nsupc > 1 ) {
		    solve_ops += 4 * nsupc * (nsupc - 1);
#ifdef _CRAY
                    ftcs1 = _cptofcd("L", strlen("L"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("U", strlen("U"));
		    CTRSV(ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
			&x[fsupc], &incx);
#else
		    ctrsv_("L", "T", "U", &nsupc, &Lval[luptr], &nsupr,
			&x[fsupc], &incx);
#endif
		}
	    }
	} else {
	    /* Form x := inv(U')*x */
	    
	    for (k = 0; k <= Lstore->nsuper; k++) {
	    	fsupc = L_FST_SUPC(k);
	    	nsupr = L_SUB_START(fsupc+1) - L_SUB_START(fsupc);
	    	nsupc = L_FST_SUPC(k+1) - fsupc;
	    	luptr = L_NZ_START(fsupc);

		for (jcol = fsupc; jcol < L_FST_SUPC(k+1); jcol++) {
		    solve_ops += 8*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
		    for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++) {
			irow = U_SUB(i);
			cc_mult(&prod, &x[irow], &Uval[i]);
		    	c_sub(&x[jcol], &x[jcol], &prod);
		    }
		}

		solve_ops += 4 * nsupc * (nsupc + 1);

		if ( nsupc == 1 ) {
		    c_div(&x[fsupc], &x[fsupc], &Lval[luptr]);
		} else {
#ifdef _CRAY
                    ftcs1 = _cptofcd("U", strlen("U"));
                    ftcs2 = _cptofcd("T", strlen("T"));
                    ftcs3 = _cptofcd("N", strlen("N"));
		    CTRSV( ftcs1, ftcs2, ftcs3, &nsupc, &Lval[luptr], &nsupr,
			    &x[fsupc], &incx);
#else
		    ctrsv_("U", "T", "N", &nsupc, &Lval[luptr], &nsupr,
			    &x[fsupc], &incx);
#endif
		}
	    } /* for k ... */
	}
    }

    SuperLUStat.ops[SOLVE] += solve_ops;
    SUPERLU_FREE(work);
    return 0;
}