/**
 * Double-precision csrgemm wrapper.
 *
 * Passes every argument through, unchanged and in the same order, to
 * cusparseDcsrgemm and hands back the status that call returns.
 */
cusparseStatus_t __cusparseXcsrgemm__(
    cusparseHandle_t handle,
    cusparseOperation_t transA,
    cusparseOperation_t transB,
    int m,
    int n,
    int k,
    const cusparseMatDescr_t descrA,
    const int nnzA,
    const double *csrValA,
    const int *csrRowPtrA,
    const int *csrColIndA,
    const cusparseMatDescr_t descrB,
    const int nnzB,
    const double *csrValB,
    const int *csrRowPtrB,
    const int *csrColIndB,
    const cusparseMatDescr_t descrC,
    double *csrValC,
    const int *csrRowPtrC,
    int *csrColIndC )
{
    const cusparseStatus_t status = cusparseDcsrgemm(
        handle, transA, transB,
        m, n, k,
        descrA, nnzA, csrValA, csrRowPtrA, csrColIndA,
        descrB, nnzB, csrValB, csrRowPtrB, csrColIndB,
        descrC, csrValC, csrRowPtrC, csrColIndC );
    return status;
}
/**
    Computes the sparse matrix-matrix product AB = A * B using cuSPARSE.

    Both operands have to live on the device (Magma_DEV) and be stored as
    Magma_CSR or Magma_CSRCOO. The product is assembled in a temporary
    matrix and handed to AB through magma_d_mtransfer; the temporary is
    released before returning.

    The product has A.num_rows rows and B.num_cols columns.
    NOTE(review): the code does not verify A.num_cols == B.num_rows; the
    caller is assumed to pass conforming operands.

    Failures of the cuSPARSE setup calls are reported with printf only;
    the routine then continues, as it did before.

    @param[in]  A   left operand.
    @param[in]  B   right operand.
    @param[out] AB  receives the product.

    @return MAGMA_SUCCESS on the supported path, MAGMA_ERR_NOT_SUPPORTED
            when the operands are not device-resident CSR/CSRCOO matrices.
*/
extern "C" magma_int_t
magma_dcuspmm(
    magma_d_sparse_matrix A,
    magma_d_sparse_matrix B,
    magma_d_sparse_matrix *AB )
{
    const bool supported =
           A.memory_location == Magma_DEV
        && B.memory_location == Magma_DEV
        && ( A.storage_type == Magma_CSR || A.storage_type == Magma_CSRCOO )
        && ( B.storage_type == Magma_CSR || B.storage_type == Magma_CSRCOO );

    if ( ! supported ) {
        printf("error: CSRMM only supported on device and CSR format.\n");
        // Report the failure to the caller instead of pretending success.
        return MAGMA_ERR_NOT_SUPPORTED;
    }

    magma_d_sparse_matrix C;
    C.num_rows = A.num_rows;
    C.num_cols = B.num_cols;   // columns of a product come from the right operand
    C.storage_type = A.storage_type;
    C.memory_location = A.memory_location;

    // CUSPARSE context //
    // A cusparseStatus_t of 0 means success; every setup call is checked so
    // that a failure on A or B is not masked by a later success.
    cusparseHandle_t handle;
    if ( cusparseCreate( &handle ) != 0 )
        printf("error in Handle.\n");

    cusparseMatDescr_t descrA;
    cusparseMatDescr_t descrB;
    cusparseMatDescr_t descrC;

    bool failed = false;
    failed = ( cusparseCreateMatDescr( &descrA ) != 0 ) || failed;
    failed = ( cusparseCreateMatDescr( &descrB ) != 0 ) || failed;
    failed = ( cusparseCreateMatDescr( &descrC ) != 0 ) || failed;
    if ( failed )
        printf("error in MatrDescr.\n");

    failed = false;
    failed = ( cusparseSetMatType( descrA, CUSPARSE_MATRIX_TYPE_GENERAL ) != 0 ) || failed;
    failed = ( cusparseSetMatType( descrB, CUSPARSE_MATRIX_TYPE_GENERAL ) != 0 ) || failed;
    failed = ( cusparseSetMatType( descrC, CUSPARSE_MATRIX_TYPE_GENERAL ) != 0 ) || failed;
    if ( failed )
        printf("error in MatrType.\n");

    failed = false;
    failed = ( cusparseSetMatIndexBase( descrA, CUSPARSE_INDEX_BASE_ZERO ) != 0 ) || failed;
    failed = ( cusparseSetMatIndexBase( descrB, CUSPARSE_INDEX_BASE_ZERO ) != 0 ) || failed;
    failed = ( cusparseSetMatIndexBase( descrC, CUSPARSE_INDEX_BASE_ZERO ) != 0 ) || failed;
    if ( failed )
        printf("error in IndexBase.\n");

    // multiply A and B on the device

    // The nnz total is received in a host variable of the index type.
    // Writing it straight into C.nnz through a pointer cast would leave part
    // of C.nnz untouched whenever magma_int_t is wider than magma_index_t.
    magma_index_t nnzC = 0;
    cusparseSetPointerMode( handle, CUSPARSE_POINTER_MODE_HOST );

    magma_index_malloc( &C.row, ( A.num_rows + 1 ) );

    // Symbolic phase: fills C.row and counts the nonzeros of the product.
    // Dimensions are (m, n, k) = (rows of A, columns of B, columns of A).
    cusparseXcsrgemmNnz( handle,
                         CUSPARSE_OPERATION_NON_TRANSPOSE,
                         CUSPARSE_OPERATION_NON_TRANSPOSE,
                         A.num_rows, B.num_cols, A.num_cols,
                         descrA, A.nnz, A.row, A.col,
                         descrB, B.nnz, B.row, B.col,
                         descrC, C.row, &nnzC );
    C.nnz = nnzC;

    magma_index_malloc( &C.col, C.nnz );
    magma_dmalloc( &C.val, C.nnz );

    // Numeric phase: fills C.col and C.val.
    cusparseDcsrgemm( handle,
                      CUSPARSE_OPERATION_NON_TRANSPOSE,
                      CUSPARSE_OPERATION_NON_TRANSPOSE,
                      A.num_rows, B.num_cols, A.num_cols,
                      descrA, A.nnz, A.val, A.row, A.col,
                      descrB, B.nnz, B.val, B.row, B.col,
                      descrC, C.val, C.row, C.col );

    cusparseDestroyMatDescr( descrA );
    cusparseDestroyMatDescr( descrB );
    cusparseDestroyMatDescr( descrC );
    cusparseDestroy( handle );
    // end CUSPARSE context //

    magma_d_mtransfer( C, AB, Magma_DEV, Magma_DEV );
    magma_d_mfree( &C );

    return MAGMA_SUCCESS;
}
void cuSPARSE_apply( KernelHandle *handle, typename KernelHandle::row_lno_t m, typename KernelHandle::row_lno_t n, typename KernelHandle::row_lno_t k, in_row_index_view_type row_mapA, in_nonzero_index_view_type entriesA, in_nonzero_value_view_type valuesA, bool transposeA, in_row_index_view_type row_mapB, in_nonzero_index_view_type entriesB, in_nonzero_value_view_type valuesB, bool transposeB, typename in_row_index_view_type::non_const_type &row_mapC, typename in_nonzero_index_view_type::non_const_type &entriesC, typename in_nonzero_value_view_type::non_const_type &valuesC){ #ifdef KERNELS_HAVE_CUSPARSE typedef typename KernelHandle::row_lno_t idx; typedef in_row_index_view_type idx_array_type; typedef typename KernelHandle::nnz_scalar_t value_type; typedef typename in_row_index_view_type::device_type device1; typedef typename in_nonzero_index_view_type::device_type device2; typedef typename in_nonzero_value_view_type::device_type device3; std::cout << "RUNNING CUSParse" << std::endl; if (Kokkos::Impl::is_same<Kokkos::Cuda, device1 >::value){ std::cerr << "MEMORY IS NOT ALLOCATED IN GPU DEVICE for CUSPARSE" << std::endl; return; } if (Kokkos::Impl::is_same<Kokkos::Cuda, device2 >::value){ std::cerr << "MEMORY IS NOT ALLOCATED IN GPU DEVICE for CUSPARSE" << std::endl; return; } if (Kokkos::Impl::is_same<Kokkos::Cuda, device3 >::value){ std::cerr << "MEMORY IS NOT ALLOCATED IN GPU DEVICE for CUSPARSE" << std::endl; return; } if (Kokkos::Impl::is_same<idx, int>::value){ int *a_xadj = (int *)row_mapA.ptr_on_device(); int *b_xadj = (int *)row_mapB.ptr_on_device(); int *c_xadj = (int *)row_mapC.ptr_on_device(); int *a_adj = (int *)entriesA.ptr_on_device(); int *b_adj = (int *)entriesB.ptr_on_device(); int *c_adj = (int *)entriesC.ptr_on_device(); typename KernelHandle::SPGEMMcuSparseHandleType *h = handle->get_cuSparseHandle(); int nnzA = entriesA.dimension_0(); int nnzB = entriesB.dimension_0(); value_type *a_ew = valuesA.ptr_on_device(); value_type *b_ew = 
valuesB.ptr_on_device(); value_type *c_ew = valuesC.ptr_on_device(); if (Kokkos::Impl::is_same<value_type, float>::value){ cusparseScsrgemm( h->handle, h->transA, h->transB, m, n, k, h->a_descr, nnzA, (float *)a_ew, a_xadj, a_adj, h->b_descr, nnzB, (float *)b_ew, b_xadj, b_adj, h->c_descr, (float *)c_ew, c_xadj, c_adj); } else if (Kokkos::Impl::is_same<value_type, double>::value){ cusparseDcsrgemm( h->handle, h->transA, h->transB, m, n, k, h->a_descr, nnzA, (double *)a_ew, a_xadj, a_adj, h->b_descr, nnzB, (double *)b_ew, b_xadj, b_adj, h->c_descr, (double *)c_ew, c_xadj, c_adj); } else { std::cerr << "CUSPARSE requires float or double values. cuComplex and cuDoubleComplex are not implemented yet." << std::endl; return; } } else { std::cerr << "CUSPARSE requires integer values" << std::endl; return; } #else std::cerr << "CUSPARSE IS NOT DEFINED" << std::endl; return; #endif }