/*
 * FLA_Trsv: flat-object front end for the triangular solve with a single
 * vector (TRSV).
 *
 * The arguments are validated when the error-checking level is at least
 * FLA_MIN_ERROR_CHECKING. The work is then forwarded to whichever back end
 * was selected at configure time.
 *
 *   uplo, trans, diag - FLAME parameters describing which triangle of A is
 *                       referenced, the transposition applied to A, and the
 *                       kind of diagonal A has
 *   A                 - triangular coefficient matrix
 *   x                 - right-hand-side vector, updated in place by the
 *                       back end
 *
 * Returns the status reported by the selected back end.
 */
FLA_Error FLA_Trsv( FLA_Uplo uplo, FLA_Trans trans, FLA_Diag diag, FLA_Obj A, FLA_Obj x )
{
  if ( FLA_Check_error_level() >= FLA_MIN_ERROR_CHECKING )
    FLA_Trsv_check( uplo, trans, diag, A, x );

#ifdef FLA_ENABLE_BLAS2_FRONT_END_CNTL_TREES
  /* Go through the internal dispatcher. It is given the flat control tree
     fla_trsv_cntl_blas, whose only job is to reach the external wrapper. */
  return FLA_Trsv_internal( uplo, trans, diag, A, x, fla_trsv_cntl_blas );
#else
  /* Control trees are not enabled for BLAS2 front ends, so call the
     external wrapper directly. */
  return FLA_Trsv_external( uplo, trans, diag, A, x );
#endif
}
/*
 * FLASH_Trsv: FLASH front end for the triangular solve with a single vector.
 * NOTE(review): presumably A and x are FLASH (hierarchical) objects — confirm.
 *
 * The arguments are validated when the error-checking level is at least
 * FLA_MIN_ERROR_CHECKING. The solve then runs through FLA_Trsv_internal
 * using the flash_trsv_cntl control tree, with the SuperMatrix queue
 * switched off. The queue's previous enabled/disabled state is restored
 * before returning.
 *
 * Returns the status reported by FLA_Trsv_internal.
 */
FLA_Error FLASH_Trsv( FLA_Uplo uplo, FLA_Trans trans, FLA_Diag diag, FLA_Obj A, FLA_Obj x )
{
  FLA_Bool  queue_was_enabled;
  FLA_Error status;

  if ( FLA_Check_error_level() >= FLA_MIN_ERROR_CHECKING )
    FLA_Trsv_check( uplo, trans, diag, A, x );

  /* Record whether SuperMatrix is currently on, so it can be put back. */
  queue_was_enabled = FLASH_Queue_get_enabled();

  /* Keep SuperMatrix off while this operation executes. */
  FLASH_Queue_disable();

  status = FLA_Trsv_internal( uplo, trans, diag, A, x, flash_trsv_cntl );

  /* Turn SuperMatrix back on only if it was on when we were called. */
  if ( queue_was_enabled )
    FLASH_Queue_enable();

  return status;
}
/*
 * FLA_Trsv_external: wrapper that performs the triangular solve with a
 * single vector by calling the BLIS ?trsv kernel matching A's datatype.
 * x's buffer is passed to the kernel directly and is updated in place.
 *
 * Behaviour:
 * - Argument validation runs here only at FLA_FULL_ERROR_CHECKING.
 * - A matrix with a zero dimension is a no-op.
 * - Datatypes other than float, double, scomplex and dcomplex are ignored.
 *   NOTE(review): presumably FLA_Trsv_check rejects them when checking is
 *   enabled — confirm.
 * - Always returns FLA_SUCCESS.
 */
FLA_Error FLA_Trsv_external( FLA_Uplo uplo, FLA_Trans trans, FLA_Diag diag, FLA_Obj A, FLA_Obj x )
{
  FLA_Datatype dt;
  int          n;
  int          a_rs;
  int          a_cs;
  int          x_inc;
  uplo_t       bl_uplo;
  trans_t      bl_trans;
  diag_t       bl_diag;

  if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
    FLA_Trsv_check( uplo, trans, diag, A, x );

  /* Empty problem: nothing to solve. */
  if ( FLA_Obj_has_zero_dim( A ) )
    return FLA_SUCCESS;

  /* Problem size plus the storage layout of A (row/column strides) and x. */
  dt    = FLA_Obj_datatype( A );
  n     = FLA_Obj_length( A );
  a_rs  = FLA_Obj_row_stride( A );
  a_cs  = FLA_Obj_col_stride( A );
  x_inc = FLA_Obj_vector_inc( x );

  /* Translate the FLAME parameter values into their BLIS counterparts. */
  FLA_Param_map_flame_to_blis_uplo( uplo, &bl_uplo );
  FLA_Param_map_flame_to_blis_trans( trans, &bl_trans );
  FLA_Param_map_flame_to_blis_diag( diag, &bl_diag );

  if ( dt == FLA_FLOAT )
  {
    float *a_buf = ( float * ) FLA_FLOAT_PTR( A );
    float *x_buf = ( float * ) FLA_FLOAT_PTR( x );

    bli_strsv( bl_uplo, bl_trans, bl_diag, n,
               a_buf, a_rs, a_cs,
               x_buf, x_inc );
  }
  else if ( dt == FLA_DOUBLE )
  {
    double *a_buf = ( double * ) FLA_DOUBLE_PTR( A );
    double *x_buf = ( double * ) FLA_DOUBLE_PTR( x );

    bli_dtrsv( bl_uplo, bl_trans, bl_diag, n,
               a_buf, a_rs, a_cs,
               x_buf, x_inc );
  }
  else if ( dt == FLA_COMPLEX )
  {
    scomplex *a_buf = ( scomplex * ) FLA_COMPLEX_PTR( A );
    scomplex *x_buf = ( scomplex * ) FLA_COMPLEX_PTR( x );

    bli_ctrsv( bl_uplo, bl_trans, bl_diag, n,
               a_buf, a_rs, a_cs,
               x_buf, x_inc );
  }
  else if ( dt == FLA_DOUBLE_COMPLEX )
  {
    dcomplex *a_buf = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( A );
    dcomplex *x_buf = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( x );

    bli_ztrsv( bl_uplo, bl_trans, bl_diag, n,
               a_buf, a_rs, a_cs,
               x_buf, x_inc );
  }
  /* Any other datatype falls through without touching x. */

  return FLA_SUCCESS;
}
/*
 * FLA_Trsv_external_gpu: GPU counterpart of the external TRSV wrapper.
 *
 * The FLAME objects A and x supply only metadata (datatype and size) and
 * argument validation. The actual solve is done by the handle-less
 * (legacy-style) cuBLAS ?trsv routine on the device buffers A_gpu and
 * x_gpu; x_gpu is updated in place.
 *
 * NOTE(review): the leading dimension of A_gpu is taken to be the length of
 * A, and x_gpu is accessed with unit stride. The device copies are therefore
 * assumed to be packed column-major — confirm against the code that uploads
 * them.
 * NOTE(review): no cuBLAS status is inspected. FLA_SUCCESS is returned
 * unconditionally, and datatypes other than the four handled below are
 * ignored.
 */
FLA_Error FLA_Trsv_external_gpu( FLA_Uplo uplo, FLA_Trans trans, FLA_Diag diag, FLA_Obj A, void* A_gpu, FLA_Obj x, void* x_gpu )
{
  FLA_Datatype dt;
  int          n;
  int          lda;
  int          incx;
  char         uplo_c;
  char         trans_c;
  char         diag_c;

  if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
    FLA_Trsv_check( uplo, trans, diag, A, x );

  /* Empty problem: nothing to send to the GPU. */
  if ( FLA_Obj_has_zero_dim( A ) )
    return FLA_SUCCESS;

  dt   = FLA_Obj_datatype( A );
  n    = FLA_Obj_length( A );
  lda  = n;  /* packed device copy of A: leading dimension == length */
  incx = 1;  /* device copy of x is contiguous */

  /* cuBLAS takes Netlib-style character codes for these parameters. */
  FLA_Param_map_flame_to_netlib_uplo( uplo, &uplo_c );
  FLA_Param_map_flame_to_netlib_trans( trans, &trans_c );
  FLA_Param_map_flame_to_netlib_diag( diag, &diag_c );

  switch ( dt )
  {
    case FLA_FLOAT:
      cublasStrsv( uplo_c, trans_c, diag_c, n,
                   ( float * ) A_gpu, lda,
                   ( float * ) x_gpu, incx );
      break;

    case FLA_DOUBLE:
      cublasDtrsv( uplo_c, trans_c, diag_c, n,
                   ( double * ) A_gpu, lda,
                   ( double * ) x_gpu, incx );
      break;

    case FLA_COMPLEX:
      cublasCtrsv( uplo_c, trans_c, diag_c, n,
                   ( cuComplex * ) A_gpu, lda,
                   ( cuComplex * ) x_gpu, incx );
      break;

    case FLA_DOUBLE_COMPLEX:
      cublasZtrsv( uplo_c, trans_c, diag_c, n,
                   ( cuDoubleComplex * ) A_gpu, lda,
                   ( cuDoubleComplex * ) x_gpu, incx );
      break;

    default:
      /* Unhandled datatype: leave x_gpu untouched. */
      break;
  }

  return FLA_SUCCESS;
}