예제 #1
0
/*!\rst
  Test matrix trace and ``tr(AB)``.

  First checks ``MatrixTrace`` on a random matrix whose diagonal is overwritten with known
  values. Then checks ``TraceOfGeneralMatrixMatrixMultiply`` against the trace of the
  explicitly formed product ``AB`` (via ``GeneralMatrixMatrixMultiply``), for several sizes.

  \return
    number of cases where trace functions fail
\endrst*/
OL_WARN_UNUSED_RESULT int TestMatrixTrace() noexcept {
  int total_errors = 0;

  UniformRandomGenerator uniform_generator(34187);

  // test MatrixTrace
  {
    const int size = 4;
    double matrix[Square(size)];
    // off-diagonal entries are random; only the diagonal should contribute to the trace
    BuildRandomVector(Square(size), -1.0, 1.0, &uniform_generator, matrix);
    // Replace the diagonal with known values. Both the writes and the expected sum come
    // from this single table, indexed with `size` (not a literal), so they cannot drift
    // apart if the matrix dimension changes.
    const double diagonal[size] = {1.5, -2.3, 0.0, 3.1};
    double expected = 0.0;
    for (int k = 0; k < size; ++k) {
      matrix[k*size + k] = diagonal[k];
      expected += diagonal[k];  // same left-to-right summation order as a hand-written sum
    }
    const double output = MatrixTrace(matrix, size);
    if (!CheckDoubleWithinRelative(output, expected, std::numeric_limits<double>::epsilon())) {
      ++total_errors;
    }
  }

  // test TraceOfGeneralMatrixMatrixMultiply
  // test by forming random matrices, evaluating the matrix product, and comparing traces of
  // the explicit and shortcut solutions
  {
    // loop over several consecutive sizes to make sure we hit all loop unroll paths
    for (int i = 10; i < 15; ++i) {
      std::vector<double> matrix_A(i*i);
      std::vector<double> matrix_B(i*i);
      std::vector<double> matrix_C(i*i);

      BuildRandomVector(Square(i), -1.0, 1.0, &uniform_generator, matrix_A.data());
      BuildRandomVector(Square(i), -1.0, 1.0, &uniform_generator, matrix_B.data());
      // C = 1.0 * A * B + 0.0 * C
      GeneralMatrixMatrixMultiply(matrix_A.data(), 'N', matrix_B.data(), 1.0, 0.0, i, i, i, matrix_C.data());
      const double trace_by_explicit_product = MatrixTrace(matrix_C.data(), i);
      const double trace_by_shortcut = TraceOfGeneralMatrixMatrixMultiply(matrix_A.data(), matrix_B.data(), i);

      // the two paths sum in different orders, so allow a small relative tolerance
      if (!CheckDoubleWithinRelative(trace_by_shortcut, trace_by_explicit_product, 2.0e-14)) {
        ++total_errors;
      }
    }
  }
  return total_errors;
}
예제 #2
0
int main(int argc, char *argv[])
{
	// Initialize libMesh
	libMesh::LibMeshInit init(argc, argv);
	libMesh::Parallel::Communicator& WorldComm = init.comm();

    libMesh::PetscMatrix<libMesh::Number> matrix_A(WorldComm);
    matrix_A.init(4,4,4,4);

    matrix_A.set(0,0,1.);
//    matrix_A.set(0,1,2.);
//    matrix_A.set(0,2,3.);
//    matrix_A.set(0,3,4.);

//    matrix_A.set(1,0,2.);
    matrix_A.set(1,1,5.);
//    matrix_A.set(1,2,3.);
//    matrix_A.set(1,3,7.);

//    matrix_A.set(2,0,3.);
//    matrix_A.set(2,1,3.);
    matrix_A.set(2,2,9.);
//    matrix_A.set(2,3,6.);

//    matrix_A.set(3,0,4.);
//    matrix_A.set(3,1,7.);
//    matrix_A.set(3,2,6.);
    matrix_A.set(3,3,1.);

    matrix_A.close();

    Mat dummy_inv_A;
    MatCreate(PETSC_COMM_WORLD,&dummy_inv_A);
    MatSetType(dummy_inv_A,MATMPIAIJ);
    MatSetSizes(dummy_inv_A,PETSC_DECIDE,PETSC_DECIDE,4,4);
    MatMPIAIJSetPreallocation(dummy_inv_A,2,NULL,0,NULL);
    MatSetUp(dummy_inv_A);

    // Dummy matrices
//    Mat dummy_A, dummy_inv_A;
//
//
    libMesh::PetscVector<libMesh::Number> vector_unity(WorldComm,4,4);
    libMesh::PetscVector<libMesh::Number> vector_dummy_answer(WorldComm,4,4);

	VecSet(vector_unity.vec(),1);
	vector_unity.close();

	VecSet(vector_dummy_answer.vec(),0);
	vector_dummy_answer.close();

    // Solver
//	libMesh::PetscLinearSolver<libMesh::Number> KSP_dummy_solver(WorldComm);
//	KSP_dummy_solver.init(&matrix_A);

//	KSPSetOperators(KSP_dummy_solver.ksp(),matrix_A.mat(),NULL);

	KSP ksp;
	PC pc;

	KSPCreate(PETSC_COMM_WORLD,&ksp);
	KSPSetOperators(ksp, matrix_A.mat(), matrix_A.mat());
	KSPGetPC(ksp,&pc);
	PCSetFromOptions(pc);
	PCType dummy_type;
	PCGetType(pc,&dummy_type);
	std::cout << std::endl << dummy_type << std::endl << std::endl;
//	PCSetType(pc,PCSPAI);
//	PCHYPRESetType(pc,"parasails");
	KSPSetUp(ksp);
	KSPSolve(ksp,vector_unity.vec(),vector_dummy_answer.vec());
	PCComputeExplicitOperator(pc,&dummy_inv_A);
//	KSPGetOperators(KSP_dummy_solver.ksp(),&dummy_A,&dummy_inv_A);

	libMesh::PetscMatrix<libMesh::Number> matrix_invA(dummy_inv_A,WorldComm);
	matrix_invA.close();
//
//	//    KSP_dummy_solver.solve(matrix_A,vector_dummy_answer,vector_unity,1E-5,10000);
//
//	vector_dummy_answer.print_matlab();
//

//	libMesh::PetscMatrix<libMesh::Number> product_mat(WorldComm);
	matrix_A.print_matlab();
	matrix_invA.print_matlab();
	vector_dummy_answer.print_matlab();
	return 0;
}