// Reduce the default and maximum blocksizes stored in blksz for datatype
// dt_bs so that each becomes a whole multiple of the default value stored
// in bmult for datatype dt_bm. If that multiple is zero, blksz is left
// unmodified.
void bli_blksz_reduce_dt_to( num_t dt_bm, blksz_t* bmult,
                             num_t dt_bs, blksz_t* blksz )
{
	dim_t def_val = bli_blksz_get_def( dt_bs, blksz );
	dim_t max_val = bli_blksz_get_max( dt_bs, blksz );

	// Only the "def" entry of bmult is consulted. Its "max" entry would
	// correspond to the packing dimension, which plays no role as a
	// blocksize multiple.
	const dim_t mult = bli_blksz_get_def( dt_bm, bmult );

	// A multiple of zero means there is nothing to round to.
	if ( mult != 0 )
	{
		// Discard the remainder so that each value is truncated to a
		// multiple of mult.
		def_val -= def_val % mult;
		max_val -= max_val % mult;

		// A value that truncated all the way to zero is bumped up to
		// exactly one multiple.
		if ( def_val == 0 ) def_val = mult;
		if ( max_val == 0 ) max_val = mult;

		// Write the adjusted values back into the blocksize object.
		bli_blksz_set_def( def_val, dt_bs, blksz );
		bli_blksz_set_max( max_val, dt_bs, blksz );
	}
}
// Determine the kc blocksize to use for trsm at iteration offset i of a
// dimension of total length dim, for an algorithm that moves "backward"
// (bottom to top, right to left, bottom-right to top-left).
dim_t bli_trsm_determine_kc_b( dim_t i, dim_t dim, obj_t* obj, blksz_t* bsize )
{
	// The execution datatype of obj selects which entries of bsize
	// are read.
	const num_t dt_exec = bli_obj_execution_datatype( *obj );

	dim_t kc_def = bli_blksz_get_def( dt_exec, bsize );
	dim_t kc_max = bli_blksz_get_max( dt_exec, bsize );

	// Both values are aligned to a multiple of MR. MR is used
	// unconditionally (never NR): even when the triangle is on the right,
	// that matrix is packed using MR, since only left-side trsm
	// micro-kernels are supported.
	const dim_t mult = bli_blksz_get_mr( dt_exec, bsize );

	kc_def = bli_align_dim_to_mult( kc_def, mult );
	kc_max = bli_align_dim_to_mult( kc_max, mult );

	return bli_determine_blocksize_b_sub( i, dim, kc_def, kc_max );
}
// Determine the kc blocksize to use for gemm-like operations at iteration
// offset i of a dimension of total length dim, for an algorithm that moves
// "backward" (bottom to top, right to left, bottom-right to top-left).
dim_t bli_gemm_determine_kc_b( dim_t i, dim_t dim, obj_t* a, obj_t* b, blksz_t* bsize )
{
	// The execution datatype of A selects which entries of bsize are read.
	const num_t dt_exec = bli_obj_execution_datatype( *a );

	dim_t kc_def = bli_blksz_get_def( dt_exec, bsize );
	dim_t kc_max = bli_blksz_get_max( dt_exec, bsize );

	// Pick the alignment unit: MR when the root of A is Hermitian or
	// symmetric, otherwise NR when the root of B is. When neither holds,
	// the blocksizes are passed through unchanged.
	dim_t unit     = 0;
	int   do_align = 1;

	if      ( bli_obj_root_is_herm_or_symm( *a ) ) unit = bli_blksz_get_mr( dt_exec, bsize );
	else if ( bli_obj_root_is_herm_or_symm( *b ) ) unit = bli_blksz_get_nr( dt_exec, bsize );
	else                                           do_align = 0;

	if ( do_align )
	{
		kc_def = bli_align_dim_to_mult( kc_def, unit );
		kc_max = bli_align_dim_to_mult( kc_max, unit );
	}

	return bli_determine_blocksize_b_sub( i, dim, kc_def, kc_max );
}
// Determine the blocksize to use at iteration offset i of a dimension of
// total length dim, using the blocksize identified by bszid within the
// context cntx. The computation itself is delegated to
// bli_determine_blocksize_b_sub().
dim_t bli_determine_blocksize_b( dim_t i, dim_t dim, obj_t* obj, bszid_t bszid, cntx_t* cntx )
{
	// The execution datatype of obj selects which entries of the
	// blocksize object are read.
	const num_t dt_exec = bli_obj_execution_datatype( *obj );

	// Look up the blocksize object for bszid in the context, then read
	// its default and maximum values for the execution datatype.
	blksz_t*    blk     = bli_cntx_get_blksz( bszid, cntx );
	const dim_t blk_def = bli_blksz_get_def( dt_exec, blk );
	const dim_t blk_max = bli_blksz_get_max( dt_exec, blk );

	return bli_determine_blocksize_b_sub( i, dim, blk_def, blk_max );
}
// Benchmark driver for gemm.
//
// For each problem size p from P_BEGIN to P_END in steps of P_INC, this
// program creates operands A (m x k), B (k x n) and C (m x n) of datatype
// DT, runs C := beta*C + alpha*A*B n_repeats times, and prints one
// matlab-style line containing the dimensions, the best time observed, and
// the corresponding GFLOPS rate.
//
// Build-time configuration (all supplied as preprocessor macros):
//   DT, IND                 - datatype and induced-method id being tested.
//   P_BEGIN, P_END, P_INC   - problem size range.
//   ADJ_MK, ADJ_KN, ADJ_MN  - optional; select which pair of dimensions is
//                             scaled to p/2 (the remaining one is p).
//   BLIS                    - if defined, bli_gemm() is timed; otherwise
//                             the Fortran-77 BLAS ?gemm_ routine matching
//                             the datatype is timed.
//   PRINT                   - if defined, operands are printed and the
//                             program exits after the first gemm call.
//   THR_STR, STR            - strings embedded in the output variable name.
//
// argc and argv are not used. Returns 0.
int main( int argc, char** argv )
{
	obj_t a, b, c;
	obj_t c_save;
	obj_t alpha, beta;
	dim_t m, n, k;
	dim_t p;
	dim_t p_begin, p_end, p_inc;
	// A negative *_input value -d means "use p / d" for that dimension;
	// a non-negative value is used as the dimension directly (see the
	// top of the main loop below).
	int   m_input, n_input, k_input;
	num_t dt, dt_real;
	char  dt_ch;
	int   r, n_repeats;
	trans_t  transa;
	trans_t  transb;
	f77_char f77_transa;
	f77_char f77_transb;

	double dtime;
	double dtime_save;
	double gflops;

	// Blocksize object for kc; defined outside of this function.
	extern blksz_t* gemm_kc;

	bli_init();

	//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );

	// Number of timed repetitions per problem size.
	n_repeats = 3;

	dt      = DT;
	dt_real = bli_datatype_proj_to_real( DT );

	p_begin = P_BEGIN;
	p_end   = P_END;
	p_inc   = P_INC;

	m_input = -1;
	n_input = -1;
	k_input = -1;

	// Extract the kc blocksize for the requested datatype and its
	// real analogue.
	dim_t kc      = bli_blksz_get_def( dt,      gemm_kc );
	dim_t kc_real = bli_blksz_get_def( dt_real, gemm_kc );

	// Assign the k dimension depending on which implementation is
	// being tested. Note that the BLIS_NAT case handles the real
	// domain cases as well as native complex.
	if      ( IND == BLIS_NAT  ) k_input = kc;
	else if ( IND == BLIS_3M1  ) k_input = kc_real / 3;
	else if ( IND == BLIS_4M1A ) k_input = kc_real / 2;
	else                         k_input = kc_real;

	// Adjust the relative dimensions, if requested. Note that each of
	// these cases overwrites the kc-based k_input chosen above.
#if   (defined ADJ_MK)
	m_input = -2; k_input = -2; n_input = -1;
#elif (defined ADJ_KN)
	k_input = -2; n_input = -2; m_input = -1;
#elif (defined ADJ_MN)
	m_input = -2; n_input = -2; k_input = -1;
#endif

	// Choose the char corresponding to the requested datatype.
	if      ( bli_is_float( dt ) )    dt_ch = 's';
	else if ( bli_is_double( dt ) )   dt_ch = 'd';
	else if ( bli_is_scomplex( dt ) ) dt_ch = 'c';
	else                              dt_ch = 'z';

	transa = BLIS_NO_TRANSPOSE;
	transb = BLIS_NO_TRANSPOSE;

	// Obtain the Fortran-77 character equivalents of the trans
	// parameters; these are what the ?gemm_ calls below receive.
	bli_param_map_blis_to_netlib_trans( transa, &f77_transa );
	bli_param_map_blis_to_netlib_trans( transb, &f77_transb );

	// Begin with initializing the last entry to zero so that
	// matlab allocates space for the entire array once up-front.
	// The empty-bodied loop below only advances p to the last problem
	// size that the main loop will visit.
	for ( p = p_begin; p + p_inc <= p_end; p += p_inc ) ;
#ifdef BLIS
	printf( "data_%s_%cgemm_%s_blis", THR_STR, dt_ch, STR );
#else
	printf( "data_%s_%cgemm_%s",      THR_STR, dt_ch, STR );
#endif
	printf( "( %2lu, 1:5 ) = [ %4lu %4lu %4lu %10.3e %6.3f ];\n",
	        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
	        ( unsigned long )0,
	        ( unsigned long )0,
	        ( unsigned long )0, 0.0, 0.0 );

	// Main loop over problem sizes.
	for ( p = p_begin; p <= p_end; p += p_inc )
	{
		// Resolve each dimension: negative inputs scale with p,
		// non-negative inputs are fixed.
		if ( m_input < 0 ) m = p / ( dim_t )abs(m_input);
		else               m =     ( dim_t )    m_input;
		if ( n_input < 0 ) n = p / ( dim_t )abs(n_input);
		else               n =     ( dim_t )    n_input;
		if ( k_input < 0 ) k = p / ( dim_t )abs(k_input);
		else               k =     ( dim_t )    k_input;

		// alpha and beta are 1x1 objects; A is m x k, B is k x n, and
		// C is m x n.
		bli_obj_create( dt, 1, 1, 0, 0, &alpha );
		bli_obj_create( dt, 1, 1, 0, 0, &beta );

		bli_obj_create( dt, m, k, 0, 0, &a );
		bli_obj_create( dt, k, n, 0, 0, &b );
		bli_obj_create( dt, m, n, 0, 0, &c );
		//bli_obj_create( dt, m, k, 2, 2*m, &a );
		//bli_obj_create( dt, k, n, 2, 2*k, &b );
		//bli_obj_create( dt, m, n, 2, 2*m, &c );
		bli_obj_create( dt, m, n, 0, 0, &c_save );

		bli_randm( &a );
		bli_randm( &b );
		bli_randm( &c );

		bli_obj_set_conjtrans( transa, a );
		bli_obj_set_conjtrans( transb, b );

		// NOTE(review): presumably sets alpha = 2.0 and beta = -1.0 with
		// zero imaginary parts -- confirm against bli_setsc().
		bli_setsc(  (2.0/1.0), 0.0, &alpha );
		bli_setsc( -(1.0/1.0), 0.0, &beta );

		// Save the initial contents of C so that every repetition can
		// start from the same data.
		bli_copym( &c, &c_save );

#ifdef BLIS
		// Leave only the induced method under test enabled for this
		// datatype.
		bli_ind_disable_all_dt( dt );
		bli_ind_enable_dt( IND, dt );
#endif

		// Start from a large sentinel time so that the first measurement
		// replaces it.
		dtime_save = 1.0e9;

		for ( r = 0; r < n_repeats; ++r )
		{
			// Restore C before timing.
			bli_copym( &c_save, &c );

			dtime = bli_clock();

#ifdef PRINT
			bli_printm( "a", &a, "%4.1f", "" );
			bli_printm( "b", &b, "%4.1f", "" );
			bli_printm( "c", &c, "%4.1f", "" );
#endif

#ifdef BLIS

			bli_gemm( &alpha,
			          &a,
			          &b,
			          &beta,
			          &c );

#else

			// Reference path: unpack dimensions, leading dimensions and
			// buffers from the objects and call the ?gemm_ routine that
			// matches the datatype.
			if ( bli_is_float( dt ) )
			{
				f77_int  mm     = bli_obj_length( c );
				f77_int  kk     = bli_obj_width_after_trans( a );
				f77_int  nn     = bli_obj_width( c );
				f77_int  lda    = bli_obj_col_stride( a );
				f77_int  ldb    = bli_obj_col_stride( b );
				f77_int  ldc    = bli_obj_col_stride( c );
				float*   alphap = bli_obj_buffer( alpha );
				float*   ap     = bli_obj_buffer( a );
				float*   bp     = bli_obj_buffer( b );
				float*   betap  = bli_obj_buffer( beta );
				float*   cp     = bli_obj_buffer( c );

				sgemm_( &f77_transa,
				        &f77_transb,
				        &mm,
				        &nn,
				        &kk,
				        alphap,
				        ap, &lda,
				        bp, &ldb,
				        betap,
				        cp, &ldc );
			}
			else if ( bli_is_double( dt ) )
			{
				f77_int  mm     = bli_obj_length( c );
				f77_int  kk     = bli_obj_width_after_trans( a );
				f77_int  nn     = bli_obj_width( c );
				f77_int  lda    = bli_obj_col_stride( a );
				f77_int  ldb    = bli_obj_col_stride( b );
				f77_int  ldc    = bli_obj_col_stride( c );
				double*  alphap = bli_obj_buffer( alpha );
				double*  ap     = bli_obj_buffer( a );
				double*  bp     = bli_obj_buffer( b );
				double*  betap  = bli_obj_buffer( beta );
				double*  cp     = bli_obj_buffer( c );

				dgemm_( &f77_transa,
				        &f77_transb,
				        &mm,
				        &nn,
				        &kk,
				        alphap,
				        ap, &lda,
				        bp, &ldb,
				        betap,
				        cp, &ldc );
			}
			else if ( bli_is_scomplex( dt ) )
			{
				f77_int  mm     = bli_obj_length( c );
				f77_int  kk     = bli_obj_width_after_trans( a );
				f77_int  nn     = bli_obj_width( c );
				f77_int  lda    = bli_obj_col_stride( a );
				f77_int  ldb    = bli_obj_col_stride( b );
				f77_int  ldc    = bli_obj_col_stride( c );
				scomplex* alphap = bli_obj_buffer( alpha );
				scomplex* ap     = bli_obj_buffer( a );
				scomplex* bp     = bli_obj_buffer( b );
				scomplex* betap  = bli_obj_buffer( beta );
				scomplex* cp     = bli_obj_buffer( c );

				cgemm_( &f77_transa,
				        &f77_transb,
				        &mm,
				        &nn,
				        &kk,
				        alphap,
				        ap, &lda,
				        bp, &ldb,
				        betap,
				        cp, &ldc );
			}
			else if ( bli_is_dcomplex( dt ) )
			{
				f77_int  mm     = bli_obj_length( c );
				f77_int  kk     = bli_obj_width_after_trans( a );
				f77_int  nn     = bli_obj_width( c );
				f77_int  lda    = bli_obj_col_stride( a );
				f77_int  ldb    = bli_obj_col_stride( b );
				f77_int  ldc    = bli_obj_col_stride( c );
				dcomplex* alphap = bli_obj_buffer( alpha );
				dcomplex* ap     = bli_obj_buffer( a );
				dcomplex* bp     = bli_obj_buffer( b );
				dcomplex* betap  = bli_obj_buffer( beta );
				dcomplex* cp     = bli_obj_buffer( c );

				zgemm_( &f77_transa,
				//zgemm3m_( &f77_transa,
				        &f77_transb,
				        &mm,
				        &nn,
				        &kk,
				        alphap,
				        ap, &lda,
				        bp, &ldb,
				        betap,
				        cp, &ldc );
			}
#endif

#ifdef PRINT
			// Debug mode: show the result of the first gemm call and stop.
			bli_printm( "c after", &c, "%4.1f", "" );
			exit(1);
#endif

			// NOTE(review): presumably yields the smaller of dtime_save
			// and the time elapsed since dtime, so dtime_save tracks the
			// best repetition -- confirm against bli_clock_min_diff().
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		// Rate computed as 2*m*k*n operations over the best time, in
		// units of 1e9 per second; complex datatypes are scaled by 4.
		gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 );

		if ( bli_is_complex( dt ) ) gflops *= 4.0;

		// Emit one matlab-style row: [ m k n time gflops ].
#ifdef BLIS
		printf( "data_%s_%cgemm_%s_blis", THR_STR, dt_ch, STR );
#else
		printf( "data_%s_%cgemm_%s",      THR_STR, dt_ch, STR );
#endif
		printf( "( %2lu, 1:5 ) = [ %4lu %4lu %4lu %10.3e %6.3f ];\n",
		        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
		        ( unsigned long )m,
		        ( unsigned long )k,
		        ( unsigned long )n, dtime_save, gflops );

		// Release every object created for this problem size.
		bli_obj_free( &alpha );
		bli_obj_free( &beta );

		bli_obj_free( &a );
		bli_obj_free( &b );
		bli_obj_free( &c );
		bli_obj_free( &c_save );
	}

	bli_finalize();

	return 0;
}