/*
 * Check the result of a gemmtrsm micro-kernel test and return a residual.
 *
 * Pre-conditions:
 *   - a1x, a11, bx1, c11_orig are randomized; a11 is triangular.
 *   - contents of b11 == contents of c11.
 *   - side == BLIS_LEFT (right-side micro-kernels are not supported).
 *
 * Under these conditions, we assume that the implementation for
 *
 *   B := inv(A11) * ( alpha * B11 - A1x * Bx1 )   (side = left)
 *
 * is functioning correctly if
 *
 *   fnorm( v - z )
 *
 * is negligible, where t is a random n-vector scaled by 1/n and
 *
 *   v = B11 * t
 *
 *   z = ( inv(A11) * ( alpha * B11_orig - A1x * Bx1 ) ) * t
 *     = inv(A11) * ( alpha * B11_orig * t - A1x * Bx1 * t )
 *     = inv(A11) * ( alpha * B11_orig * t - A1x * w )
 *
 * Parameters:
 *   side      Must be BLIS_LEFT; any other value is reported as
 *             BLIS_NOT_YET_IMPLEMENTED and *resid is not written.
 *   alpha     Scalar applied to B11_orig.
 *   a1x       Panel of A multiplying bx1.
 *   a11       Triangular block. It is MODIFIED: bli_invertd() is applied to
 *             its diagonal in place. Per the original intent, the diagonal
 *             presumably arrives inverted (as the micro-kernel expects) and
 *             is restored here for trsv -- confirm against the caller.
 *   bx1       Panel of B multiplied by a1x.
 *   b11       Block holding the micro-kernel's output.
 *   c11       Not referenced.
 *   c11_orig  Copy of B11 before the micro-kernel ran (B11_orig above).
 *   resid     Output: Frobenius norm of v - z.
 */
void libblis_test_gemmtrsm_ukr_check( side_t  side,
                                      obj_t*  alpha,
                                      obj_t*  a1x,
                                      obj_t*  a11,
                                      obj_t*  bx1,
                                      obj_t*  b11,
                                      obj_t*  c11,
                                      obj_t*  c11_orig,
                                      double* resid )
{
	num_t  dt      = bli_obj_datatype( *b11 );
	num_t  dt_real = bli_obj_datatype_proj_to_real( *b11 );

	dim_t  m       = bli_obj_length( *b11 );
	dim_t  n       = bli_obj_width( *b11 );
	dim_t  k       = bli_obj_width( *a1x );

	obj_t  kappa, norm;
	obj_t  t, v, w, z;

	double junk;

	// BLIS does not currently support right-side micro-kernels. Report it and
	// bail out before any work vector is used; none of t, v, w, z is created
	// for this case, so continuing would operate on uninitialized objects.
	if ( !bli_is_left( side ) )
	{
		bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED );
		return;
	}

	bli_obj_scalar_init_detached( dt,      &kappa );
	bli_obj_scalar_init_detached( dt_real, &norm );

	bli_obj_create( dt, n, 1, 0, 0, &t );
	bli_obj_create( dt, m, 1, 0, 0, &v );
	bli_obj_create( dt, k, 1, 0, 0, &w );
	bli_obj_create( dt, m, 1, 0, 0, &z );

	// t := random vector scaled by 1/n.
	bli_randv( &t );
	bli_setsc( 1.0/( double )n, 0.0, &kappa );
	bli_scalv( &kappa, &t );

	// v := B11 * t, using the micro-kernel's result.
	bli_gemv( &BLIS_ONE, b11, &t, &BLIS_ZERO, &v );

	// Restore the diagonal of a11 to its original, un-inverted state
	// (needed for trsv).
	bli_invertd( a11 );

	// w := Bx1 * t
	bli_gemv( &BLIS_ONE, bx1, &t, &BLIS_ZERO, &w );
	// z := alpha * B11_orig * t
	bli_gemv( alpha, c11_orig, &t, &BLIS_ZERO, &z );
	// z := z - A1x * w
	bli_gemv( &BLIS_MINUS_ONE, a1x, &w, &BLIS_ONE, &z );
	// z := inv(A11) * z
	bli_trsv( &BLIS_ONE, a11, &z );

	// v := v - z;  resid := fnorm( v )
	bli_subv( &z, &v );
	bli_fnormv( &v, &norm );
	bli_getsc( &norm, resid, &junk );

	bli_obj_free( &t );
	bli_obj_free( &v );
	bli_obj_free( &w );
	bli_obj_free( &z );
}
int main( int argc, char** argv ) { obj_t a, b, c; obj_t c_save; obj_t alpha, beta; dim_t m, n; dim_t p; dim_t p_begin, p_end, p_inc; int m_input, n_input; num_t dt_a, dt_b, dt_c; num_t dt_alpha, dt_beta; int r, n_repeats; side_t side; uplo_t uplo; double dtime; double dtime_save; double gflops; bli_init(); n_repeats = 3; if( argc < 7 ) { printf("Usage:\n"); printf("test_foo.x m n k p_begin p_inc p_end:\n"); exit; } int world_size, world_rank, provided; MPI_Init_thread( NULL, NULL, MPI_THREAD_FUNNELED, &provided ); MPI_Comm_size( MPI_COMM_WORLD, &world_size ); MPI_Comm_rank( MPI_COMM_WORLD, &world_rank ); m_input = strtol( argv[1], NULL, 10 ); n_input = strtol( argv[2], NULL, 10 ); p_begin = strtol( argv[4], NULL, 10 ); p_inc = strtol( argv[5], NULL, 10 ); p_end = strtol( argv[6], NULL, 10 ); #if 1 dt_a = BLIS_DOUBLE; dt_b = BLIS_DOUBLE; dt_c = BLIS_DOUBLE; dt_alpha = BLIS_DOUBLE; dt_beta = BLIS_DOUBLE; #else dt_a = dt_b = dt_c = dt_alpha = dt_beta = BLIS_FLOAT; //dt_a = dt_b = dt_c = dt_alpha = dt_beta = BLIS_SCOMPLEX; #endif side = BLIS_LEFT; //side = BLIS_RIGHT; uplo = BLIS_LOWER; //uplo = BLIS_UPPER; for ( p = p_begin + world_rank * p_inc; p <= p_end; p += p_inc * world_size ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); else n = ( dim_t ) n_input; bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha ); bli_obj_create( dt_beta, 1, 1, 0, 0, &beta ); if ( bli_is_left( side ) ) bli_obj_create( dt_a, m, m, 0, 0, &a ); else bli_obj_create( dt_a, n, n, 0, 0, &a ); bli_obj_create( dt_b, m, n, 0, 0, &b ); bli_obj_create( dt_c, m, n, 0, 0, &c ); bli_obj_create( dt_c, m, n, 0, 0, &c_save ); bli_obj_set_struc( BLIS_TRIANGULAR, &a ); bli_obj_set_uplo( uplo, &a ); //bli_obj_set_diag( BLIS_UNIT_DIAG, &a ); bli_randm( &a ); bli_randm( &c ); bli_randm( &b ); /* { obj_t a2; bli_obj_alias_to( &a, &a2 ); bli_obj_toggle_uplo( &a2 ); bli_obj_inc_diag_offset( 1, &a2 ); bli_setm( &BLIS_ZERO, &a2 ); 
bli_obj_inc_diag_offset( -2, &a2 ); bli_obj_toggle_uplo( &a2 ); bli_obj_set_diag( BLIS_NONUNIT_DIAG, &a2 ); bli_scalm( &BLIS_TWO, &a2 ); //bli_scalm( &BLIS_TWO, &a ); } */ bli_setsc( (2.0/1.0), 0.0, &alpha ); bli_setsc( (1.0/1.0), 0.0, &beta ); bli_copym( &c, &c_save ); dtime_save = 1.0e9; for ( r = 0; r < n_repeats; ++r ) { bli_copym( &c_save, &c ); dtime = bli_clock(); #ifdef PRINT /* obj_t ar, ai; bli_obj_alias_to( &a, &ar ); bli_obj_alias_to( &a, &ai ); bli_obj_set_dt( BLIS_DOUBLE, &ar ); ar.rs *= 2; ar.cs *= 2; bli_obj_set_dt( BLIS_DOUBLE, &ai ); ai.rs *= 2; ai.cs *= 2; ai.buffer = ( double* )ai.buffer + 1; bli_printm( "ar", &ar, "%4.1f", "" ); bli_printm( "ai", &ai, "%4.1f", "" ); */ bli_invertd( &a ); bli_printm( "a", &a, "%4.1f", "" ); bli_invertd( &a ); bli_printm( "c", &c, "%4.1f", "" ); #endif #ifdef BLIS //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); bli_trsm( side, //bli_trsm4m( side, //bli_trsm3m( side, &alpha, &a, &c ); #else if ( bli_is_real( dt_a ) ) { f77_char side = 'L'; f77_char uplo = 'L'; f77_char transa = 'N'; f77_char diag = 'N'; f77_int mm = bli_obj_length( &c ); f77_int nn = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); float * alphap = bli_obj_buffer( &alpha ); float * ap = bli_obj_buffer( &a ); float * cp = bli_obj_buffer( &c ); strsm_( &side, &uplo, &transa, &diag, &mm, &nn, alphap, ap, &lda, cp, &ldc ); } else // if ( bli_is_complex( dt_a ) ) { f77_char side = 'L'; f77_char uplo = 'L'; f77_char transa = 'N'; f77_char diag = 'N'; f77_int mm = bli_obj_length( &c ); f77_int nn = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); scomplex* alphap = bli_obj_buffer( &alpha ); scomplex* ap = bli_obj_buffer( &a ); scomplex* cp = bli_obj_buffer( &c ); ctrsm_( &side, //ztrsm_( &side, &uplo, &transa, &diag, &mm, &nn, alphap, ap, &lda, cp, &ldc ); } #endif #ifdef PRINT bli_printm( "c after", &c, "%4.1f", "" ); exit(1); #endif 
dtime_save = bli_clock_min_diff( dtime_save, dtime ); } if ( bli_is_left( side ) ) gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 ); else gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 ); if ( bli_is_complex( dt_a ) ) gflops *= 4.0; #ifdef BLIS printf( "data_trsm_blis" ); #else printf( "data_trsm_%s", BLAS ); #endif printf( "( %2lu, 1:4 ) = [ %4lu %4lu %10.3e %6.3f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, dtime_save, gflops ); bli_obj_free( &alpha ); bli_obj_free( &beta ); bli_obj_free( &a ); bli_obj_free( &b ); bli_obj_free( &c ); bli_obj_free( &c_save ); } bli_finalize(); return 0; }
int main( int argc, char** argv ) { obj_t a, c; obj_t c_save; obj_t alpha; dim_t m, n; dim_t p; dim_t p_begin, p_end, p_inc; int m_input, n_input; num_t dt; int r, n_repeats; side_t side; uplo_t uploa; trans_t transa; diag_t diaga; f77_char f77_side; f77_char f77_uploa; f77_char f77_transa; f77_char f77_diaga; double dtime; double dtime_save; double gflops; //bli_init(); //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); n_repeats = 3; #ifndef PRINT p_begin = 200; p_end = 2000; p_inc = 200; m_input = -1; n_input = -1; #else p_begin = 16; p_end = 16; p_inc = 1; m_input = 4; n_input = 4; #endif #if 1 //dt = BLIS_FLOAT; dt = BLIS_DOUBLE; #else //dt = BLIS_SCOMPLEX; dt = BLIS_DCOMPLEX; #endif side = BLIS_LEFT; //side = BLIS_RIGHT; uploa = BLIS_LOWER; //uploa = BLIS_UPPER; transa = BLIS_NO_TRANSPOSE; diaga = BLIS_NONUNIT_DIAG; bli_param_map_blis_to_netlib_side( side, &f77_side ); bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); for ( p = p_begin; p <= p_end; p += p_inc ) { if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); else m = ( dim_t ) m_input; if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); else n = ( dim_t ) n_input; bli_obj_create( dt, 1, 1, 0, 0, &alpha ); if ( bli_is_left( side ) ) bli_obj_create( dt, m, m, 0, 0, &a ); else bli_obj_create( dt, n, n, 0, 0, &a ); bli_obj_create( dt, m, n, 0, 0, &c ); bli_obj_create( dt, m, n, 0, 0, &c_save ); bli_randm( &a ); bli_randm( &c ); bli_obj_set_struc( BLIS_TRIANGULAR, &a ); bli_obj_set_uplo( uploa, &a ); bli_obj_set_conjtrans( transa, &a ); bli_obj_set_diag( diaga, &a ); // Randomize A, make it densely Hermitian, and zero the unstored // triangle to ensure the implementation reads only from the stored // region. 
bli_randm( &a ); bli_mkherm( &a ); bli_mktrim( &a ); bli_setsc( (2.0/1.0), 1.0, &alpha ); bli_copym( &c, &c_save ); dtime_save = DBL_MAX; for ( r = 0; r < n_repeats; ++r ) { bli_copym( &c_save, &c ); dtime = bli_clock(); #ifdef PRINT bli_invertd( &a ); bli_printm( "a", &a, "%4.1f", "" ); bli_invertd( &a ); bli_printm( "c", &c, "%4.1f", "" ); #endif #ifdef BLIS bli_trsm( side, &alpha, &a, &c ); #else if ( bli_is_float( dt ) ) { f77_int mm = bli_obj_length( &c ); f77_int nn = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); float* alphap = bli_obj_buffer( &alpha ); float* ap = bli_obj_buffer( &a ); float* cp = bli_obj_buffer( &c ); strsm_( &f77_side, &f77_uploa, &f77_transa, &f77_diaga, &mm, &nn, alphap, ap, &lda, cp, &ldc ); } else if ( bli_is_double( dt ) ) { f77_int mm = bli_obj_length( &c ); f77_int nn = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); double* alphap = bli_obj_buffer( &alpha ); double* ap = bli_obj_buffer( &a ); double* cp = bli_obj_buffer( &c ); dtrsm_( &f77_side, &f77_uploa, &f77_transa, &f77_diaga, &mm, &nn, alphap, ap, &lda, cp, &ldc ); } else if ( bli_is_scomplex( dt ) ) { f77_int mm = bli_obj_length( &c ); f77_int nn = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); scomplex* alphap = bli_obj_buffer( &alpha ); scomplex* ap = bli_obj_buffer( &a ); scomplex* cp = bli_obj_buffer( &c ); ctrsm_( &f77_side, &f77_uploa, &f77_transa, &f77_diaga, &mm, &nn, alphap, ap, &lda, cp, &ldc ); } else if ( bli_is_dcomplex( dt ) ) { f77_int mm = bli_obj_length( &c ); f77_int nn = bli_obj_width( &c ); f77_int lda = bli_obj_col_stride( &a ); f77_int ldc = bli_obj_col_stride( &c ); dcomplex* alphap = bli_obj_buffer( &alpha ); dcomplex* ap = bli_obj_buffer( &a ); dcomplex* cp = bli_obj_buffer( &c ); ztrsm_( &f77_side, &f77_uploa, &f77_transa, &f77_diaga, &mm, &nn, alphap, ap, &lda, cp, &ldc ); } #endif #ifdef 
PRINT bli_printm( "c after", &c, "%9.5f", "" ); exit(1); #endif dtime_save = bli_clock_min_diff( dtime_save, dtime ); } if ( bli_is_left( side ) ) gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 ); else gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 ); if ( bli_is_complex( dt ) ) gflops *= 4.0; #ifdef BLIS printf( "data_trsm_blis" ); #else printf( "data_trsm_%s", BLAS ); #endif printf( "( %2lu, 1:4 ) = [ %4lu %4lu %10.3e %6.3f ];\n", ( unsigned long )(p - p_begin + 1)/p_inc + 1, ( unsigned long )m, ( unsigned long )n, dtime_save, gflops ); bli_obj_free( &alpha ); bli_obj_free( &a ); bli_obj_free( &c ); bli_obj_free( &c_save ); } //bli_finalize(); return 0; }