//======================================================= int EpetraExt_HypreIJMatrix::Multiply(bool TransA, const Epetra_MultiVector& X, Epetra_MultiVector& Y) const { //printf("Proc[%d], Row start: %d, Row End: %d\n", Comm().MyPID(), MyRowStart_, MyRowEnd_); bool SameVectors = false; int NumVectors = X.NumVectors(); if (NumVectors != Y.NumVectors()) return -1; // X and Y must have same number of vectors if(X.Pointers() == Y.Pointers()){ SameVectors = true; } for(int VecNum = 0; VecNum < NumVectors; VecNum++) { //Get values for current vector in multivector. double * x_values; double * y_values; EPETRA_CHK_ERR((*X(VecNum)).ExtractView(&x_values)); double *x_temp = x_local->data; double *y_temp = y_local->data; if(!SameVectors){ EPETRA_CHK_ERR((*Y(VecNum)).ExtractView(&y_values)); } else { y_values = new double[X.MyLength()]; } y_local->data = y_values; EPETRA_CHK_ERR(HYPRE_ParVectorSetConstantValues(par_y,0.0)); // Temporarily make a pointer to data in Hypre for end // Replace data in Hypre vectors with epetra values x_local->data = x_values; // Do actual computation. if(TransA) { // Use transpose of A in multiply EPETRA_CHK_ERR(HYPRE_ParCSRMatrixMatvecT(1.0, ParMatrix_, par_x, 1.0, par_y)); } else { EPETRA_CHK_ERR(HYPRE_ParCSRMatrixMatvec(1.0, ParMatrix_, par_x, 1.0, par_y)); } if(SameVectors){ int NumEntries = Y.MyLength(); std::vector<double> new_values; new_values.resize(NumEntries); std::vector<int> new_indices; new_indices.resize(NumEntries); for(int i = 0; i < NumEntries; i++){ new_values[i] = y_values[i]; new_indices[i] = i; } EPETRA_CHK_ERR((*Y(VecNum)).ReplaceMyValues(NumEntries, &new_values[0], &new_indices[0])); delete[] y_values; } x_local->data = x_temp; y_local->data = y_temp; } double flops = (double) NumVectors * (double) NumGlobalNonzeros(); UpdateFlops(flops); return 0; } //Multiply()
/*--------------------------------------------------------------------------
 * Fortran-callable wrapper around HYPRE_ParVectorSetConstantValues.
 *
 *   vector : Fortran handle, converted to an HYPRE_ParVector
 *   value  : constant passed through to HYPRE_ParVectorSetConstantValues
 *   ierr   : receives the status code returned by that call
 *--------------------------------------------------------------------------*/
void
hypre_F90_IFACE(hypre_parvectorsetconstantvalue, HYPRE_PARVECTORSETCONSTANTVALUE)
   ( hypre_F90_Obj     *vector,
     hypre_F90_Complex *value,
     hypre_F90_Int     *ierr )
{
   HYPRE_ParVector par_vector = hypre_F90_PassObj(HYPRE_ParVector, vector);
   hypre_F90_Int   status;

   status = (hypre_F90_Int)
      HYPRE_ParVectorSetConstantValues(par_vector, hypre_F90_PassComplex(value));

   *ierr = status;
}
//======================================================= int EpetraExt_HypreIJMatrix::Solve(bool Upper, bool transpose, bool UnitDiagonal, const Epetra_MultiVector & X, Epetra_MultiVector & Y) const { bool SameVectors = false; int NumVectors = X.NumVectors(); if (NumVectors != Y.NumVectors()) return -1; // X and Y must have same number of vectors if(X.Pointers() == Y.Pointers()){ SameVectors = true; } if(SolveOrPrec_ == Solver){ if(IsSolverSetup_[0] == false){ SetupSolver(); } } else { if(IsPrecondSetup_[0] == false){ SetupPrecond(); } } for(int VecNum = 0; VecNum < NumVectors; VecNum++) { //Get values for current vector in multivector. double * x_values; EPETRA_CHK_ERR((*X(VecNum)).ExtractView(&x_values)); double * y_values; if(!SameVectors){ EPETRA_CHK_ERR((*Y(VecNum)).ExtractView(&y_values)); } else { y_values = new double[X.MyLength()]; } // Temporarily make a pointer to data in Hypre for end double *x_temp = x_local->data; // Replace data in Hypre vectors with epetra values x_local->data = x_values; double *y_temp = y_local->data; y_local->data = y_values; EPETRA_CHK_ERR(HYPRE_ParVectorSetConstantValues(par_y, 0.0)); if(transpose && !TransposeSolve_){ // User requested a transpose solve, but the solver selected doesn't provide one EPETRA_CHK_ERR(-1); } if(SolveOrPrec_ == Solver){ // Use the solver methods EPETRA_CHK_ERR(SolverSolvePtr_(Solver_, ParMatrix_, par_x, par_y)); } else { // Apply the preconditioner EPETRA_CHK_ERR(PrecondSolvePtr_(Preconditioner_, ParMatrix_, par_x, par_y)); } if(SameVectors){ int NumEntries = Y.MyLength(); std::vector<double> new_values; new_values.resize(NumEntries); std::vector<int> new_indices; new_indices.resize(NumEntries); for(int i = 0; i < NumEntries; i++){ new_values[i] = y_values[i]; new_indices[i] = i; } EPETRA_CHK_ERR((*Y(VecNum)).ReplaceMyValues(NumEntries, &new_values[0], &new_indices[0])); delete[] y_values; } x_local->data = x_temp; y_local->data = y_temp; } double flops = (double) NumVectors * (double) 
NumGlobalNonzeros(); UpdateFlops(flops); return 0; } //Solve()