예제 #1
0
// Object-level wrapper for the blocked unpackm variant 2. It reads the
// raw fields of the packed object P and of the destination object C and
// forwards them to the implementation that ftypes[] holds for C's
// datatype.
//
//   p    - packed source object
//   c    - destination object
//   cntl - control tree node (not referenced in this function)
void bli_unpackm_blk_var2( obj_t*     p,
                           obj_t*     c,
                           unpackm_t* cntl )
{
	// Packed source P: buffer, strides and panel geometry.
	void*     p_buf     = bli_obj_buffer_at_off( *p );
	inc_t     p_rs      = bli_obj_row_stride( *p );
	inc_t     p_cs      = bli_obj_col_stride( *p );
	dim_t     p_pd      = bli_obj_panel_dim( *p );
	inc_t     p_ps      = bli_obj_panel_stride( *p );

	// Destination C: buffer and strides.
	void*     c_buf     = bli_obj_buffer_at_off( *c );
	inc_t     c_rs      = bli_obj_row_stride( *c );
	inc_t     c_cs      = bli_obj_col_stride( *c );

	// Matrix and panel dimensions, all queried from C.
	dim_t     c_m       = bli_obj_length( *c );
	dim_t     c_n       = bli_obj_width( *c );
	dim_t     panel_m   = bli_obj_panel_length( *c );
	dim_t     panel_n   = bli_obj_panel_width( *c );

	// Parameters would ordinarily be read from the source operand. The
	// packm/unpackm framework, however, does not yet let us assume that
	// struc(P) == struc(C) at this point (densification may have marked
	// P as dense when the root is upper or lower), so the structural
	// fields are read from C instead of P.
	struc_t   c_struc   = bli_obj_struc( *c );
	doff_t    c_diagoff = bli_obj_diag_offset( *c );
	diag_t    c_diag    = bli_obj_diag( *c );
	uplo_t    c_uplo    = bli_obj_uplo( *c );

	// Likewise, the trans argument normally belongs to the source, but
	// the packed matrix is known not to be transposed; any transposition
	// exists because C was transposed when it was packed, so C is the
	// object queried. Only the trans status is taken (not the conjugation
	// status), since we probably do not want to un-conjugate a matrix
	// that was conjugated during packing.
	trans_t   c_trans   = bli_obj_onlytrans_status( *c );

	// Select the function pointer that matches C's datatype.
	num_t     dt        = bli_obj_datatype( *c );
	FUNCPTR_T unpack_fp = ftypes[dt];

	// Hand everything to the selected implementation.
	unpack_fp( c_struc,
	           c_diagoff,
	           c_diag,
	           c_uplo,
	           c_trans,
	           c_m,
	           c_n,
	           panel_m,
	           panel_n,
	           p_buf, p_rs, p_cs,
	                  p_pd, p_ps,
	           c_buf, c_rs, c_cs );
}
예제 #2
0
// Object-level wrapper for the unblocked packm variant 1. It gathers the
// fields of the source object C and the pack destination P, and, on the
// thread for which thread_am_ochief() holds, invokes the implementation
// stored in ftypes[] for C's datatype.
//
//   c      - source object
//   p      - destination (packed) object
//   thread - thread info consulted by thread_am_ochief()
void bli_packm_unb_var1( obj_t*   c,
                         obj_t*   p,
                         packm_thrinfo_t* thread )
{
	// Source C: buffer and strides.
	void*     c_buf     = bli_obj_buffer_at_off( *c );
	inc_t     c_rs      = bli_obj_row_stride( *c );
	inc_t     c_cs      = bli_obj_col_stride( *c );

	// Source C: structural properties and conj/trans status.
	struc_t   c_struc   = bli_obj_struc( *c );
	doff_t    c_diagoff = bli_obj_diag_offset( *c );
	diag_t    c_diag    = bli_obj_diag( *c );
	uplo_t    c_uplo    = bli_obj_uplo( *c );
	trans_t   c_trans   = bli_obj_conjtrans_status( *c );

	// Destination P: buffer and strides.
	void*     p_buf     = bli_obj_buffer_at_off( *p );
	inc_t     p_rs      = bli_obj_row_stride( *p );
	inc_t     p_cs      = bli_obj_col_stride( *p );

	// Destination P: dimensions and padded dimensions.
	dim_t     p_m       = bli_obj_length( *p );
	dim_t     p_n       = bli_obj_width( *p );
	dim_t     p_m_max   = bli_obj_padded_length( *p );
	dim_t     p_n_max   = bli_obj_padded_width( *p );

	num_t     dt        = bli_obj_datatype( *c );

	// The computational kernel is assumed to always apply the alpha
	// scalar of the higher-level operation, so kappa is BLIS_ONE and the
	// underlying packm implementation performs no scaling while packing.
	void*     kappa     = bli_obj_buffer_for_const( dt, BLIS_ONE );

	// Select the function pointer that matches the datatype.
	FUNCPTR_T pack_fp   = ftypes[dt];

	// Only the thread reported by thread_am_ochief() performs the call.
	if ( thread_am_ochief( thread ) )
	{
		pack_fp( c_struc,
		         c_diagoff,
		         c_diag,
		         c_uplo,
		         c_trans,
		         p_m,
		         p_n,
		         p_m_max,
		         p_n_max,
		         kappa,
		         c_buf, c_rs, c_cs,
		         p_buf, p_rs, p_cs );
	}
}
예제 #3
0
// Translates a row/column offset (offmn) within the packed object P into
// an offset, in units of elements, to the corresponding panel. How the
// translation is done depends on P's pack schema.
//
//   offmn - row or column offset, interpreted according to the schema
//   p     - packed object whose schema and strides are consulted
//
// Returns the computed panel offset. For the two panel schemas the
// function calls bli_abort() when offmn is not a multiple of the panel
// dimension. For an unrecognized schema it reports
// BLIS_NOT_YET_IMPLEMENTED via bli_check_error_code() and yields 0.
dim_t bli_packm_offset_to_panel_for( dim_t offmn, obj_t* p )
{
	dim_t panel_off;

	// "Packed rows": every row acts as its own row panel, so the row
	// offset is the panel offset and the row stride plays the role of
	// the panel stride.
	if ( bli_obj_pack_schema( *p ) == BLIS_PACKED_ROWS )
	{
		return offmn * bli_obj_row_stride( *p );
	}

	// "Packed columns": every column acts as its own column panel, so the
	// column offset is the panel offset and the column stride plays the
	// role of the panel stride.
	if ( bli_obj_pack_schema( *p ) == BLIS_PACKED_COLUMNS )
	{
		return offmn * bli_obj_col_stride( *p );
	}

	// "Packed row panels": the column stride equals the panel dimension
	// (length). Dividing offmn (a row offset) by it gives the panel
	// index, which is then scaled by the panel stride to obtain the
	// offset in elements.
	if ( bli_obj_pack_schema( *p ) == BLIS_PACKED_ROW_PANELS )
	{
		panel_off  = offmn / bli_obj_col_stride( *p );
		panel_off *= bli_obj_panel_stride( *p );

		// The offset must fall exactly on a panel boundary.
		if ( offmn % bli_obj_col_stride( *p ) > 0 ) bli_abort();

		return panel_off;
	}

	// "Packed column panels": the row stride equals the panel dimension
	// (width). Dividing offmn (a column offset) by it gives the panel
	// index, which is then scaled by the panel stride to obtain the
	// offset in elements.
	if ( bli_obj_pack_schema( *p ) == BLIS_PACKED_COL_PANELS )
	{
		panel_off  = offmn / bli_obj_row_stride( *p );
		panel_off *= bli_obj_panel_stride( *p );

		// The offset must fall exactly on a panel boundary.
		if ( offmn % bli_obj_row_stride( *p ) > 0 ) bli_abort();

		return panel_off;
	}

	// Any other schema is not handled here.
	panel_off = 0;
	bli_check_error_code( BLIS_NOT_YET_IMPLEMENTED );

	return panel_off;
}
예제 #4
0
// Object-level wrapper for the trsm "rl" kernel variant 2. It extracts
// the dimensions, buffers and strides of A, B and C, resolves the alpha
// scalar buffer, and calls the implementation stored in ftypes[] for
// C's execution datatype.
//
//   alpha - scalar object; resolved with bli_set_scalar_dt_buffer()
//   a,b,c - matrix operands
//   beta  - not referenced in this function
//   cntl  - control tree node (not referenced in this function)
void bli_trsm_rl_ker_var2( obj_t*  alpha,
                           obj_t*  a,
                           obj_t*  b,
                           obj_t*  beta,
                           obj_t*  c,
                           trsm_t* cntl )
{
	num_t     dt        = bli_obj_execution_datatype( *c );

	// Problem dimensions: m and n from C, k from the width of A.
	dim_t     dim_m     = bli_obj_length( *c );
	dim_t     dim_n     = bli_obj_width( *c );
	dim_t     dim_k     = bli_obj_width( *a );

	// Diagonal offset of B.
	doff_t    b_diagoff = bli_obj_diag_offset( *b );

	// Operand A.
	void*     a_buf     = bli_obj_buffer_at_off( *a );
	inc_t     a_rs      = bli_obj_row_stride( *a );
	inc_t     a_cs      = bli_obj_col_stride( *a );
	inc_t     a_ps      = bli_obj_panel_stride( *a );

	// Operand B.
	void*     b_buf     = bli_obj_buffer_at_off( *b );
	inc_t     b_rs      = bli_obj_row_stride( *b );
	inc_t     b_cs      = bli_obj_col_stride( *b );
	inc_t     b_ps      = bli_obj_panel_stride( *b );

	// Operand C.
	void*     c_buf     = bli_obj_buffer_at_off( *c );
	inc_t     c_rs      = bli_obj_row_stride( *c );
	inc_t     c_cs      = bli_obj_col_stride( *c );

	// Select the function pointer for the execution datatype.
	FUNCPTR_T ker_fp    = ftypes[dt];

	num_t     alpha_dt;
	void*     alpha_buf;

	// When alpha is a scalar constant, the execution datatype selects the
	// address of the matching constant value; otherwise alpha's own
	// datatype is used and its buffer is taken at the alpha offset.
	bli_set_scalar_dt_buffer( alpha, dt, alpha_dt, alpha_buf );

	// Hand everything to the selected implementation.
	ker_fp( b_diagoff,
	        dim_m,
	        dim_n,
	        dim_k,
	        alpha_buf,
	        a_buf, a_rs, a_cs, a_ps,
	        b_buf, b_rs, b_cs, b_ps,
	        c_buf, c_rs, c_cs );
}
예제 #5
0
// Object-level wrapper for the unblocked axpym variant 1. It extracts
// the fields of x and y, resolves the alpha scalar, and dispatches on
// the (alpha, x, y) datatype combination.
//
//   alpha - scalar object; resolved with bli_set_scalar_dt_buffer()
//   x     - source matrix object
//   y     - destination matrix object (also supplies m and n)
//   cntx  - context (not referenced in this function)
void bli_axpym_unb_var1( obj_t*  alpha,
                         obj_t*  x,
                         obj_t*  y,
                         cntx_t* cntx )
{
	num_t     x_dt      = bli_obj_datatype( *x );
	num_t     y_dt      = bli_obj_datatype( *y );

	// Dimensions are taken from y.
	dim_t     dim_m     = bli_obj_length( *y );
	dim_t     dim_n     = bli_obj_width( *y );

	// Structural properties and conj/trans status of x.
	doff_t    x_diagoff = bli_obj_diag_offset( *x );
	diag_t    x_diag    = bli_obj_diag( *x );
	uplo_t    x_uplo    = bli_obj_uplo( *x );
	trans_t   x_trans   = bli_obj_conjtrans_status( *x );

	// Buffer and strides of x.
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_rs      = bli_obj_row_stride( *x );
	inc_t     x_cs      = bli_obj_col_stride( *x );

	// Buffer and strides of y.
	void*     y_buf     = bli_obj_buffer_at_off( *y );
	inc_t     y_rs      = bli_obj_row_stride( *y );
	inc_t     y_cs      = bli_obj_col_stride( *y );

	num_t     alpha_dt;
	void*     alpha_buf;

	FUNCPTR_T axpym_fp;

	// When alpha is a scalar constant, the datatype of x selects the
	// address of the matching constant value; otherwise alpha's own
	// datatype is used and its buffer is taken at the alpha offset.
	bli_set_scalar_dt_buffer( alpha, x_dt, alpha_dt, alpha_buf );

	// Select the function pointer for this datatype combination.
	axpym_fp = ftypes[alpha_dt][x_dt][y_dt];

	// Hand everything to the selected implementation.
	axpym_fp( x_diagoff,
	          x_diag,
	          x_uplo,
	          x_trans,
	          dim_m,
	          dim_n,
	          alpha_buf,
	          x_buf, x_rs, x_cs,
	          y_buf, y_rs, y_cs );
}
예제 #6
0
// Object-level wrapper for the unblocked norm1m variant 1. It extracts
// the fields of x and the output buffer of norm, then calls the
// implementation stored in ftypes[] for the datatype of x.
//
//   x    - matrix object whose norm is requested
//   norm - object whose buffer is passed to the implementation as the
//          output location
void bli_norm1m_unb_var1( obj_t* x,
                          obj_t* norm )
{
    num_t     dt_x     = bli_obj_datatype( *x );

    doff_t    diagoffx = bli_obj_diag_offset( *x );
    // bli_obj_diag() yields a diag_t (as in the other wrappers); it was
    // previously stored in a uplo_t by mistake.
    diag_t    diagx    = bli_obj_diag( *x );
    uplo_t    uplox    = bli_obj_uplo( *x );

    dim_t     m        = bli_obj_length( *x );
    dim_t     n        = bli_obj_width( *x );

    void*     buf_x    = bli_obj_buffer_at_off( *x );
    inc_t     rs_x     = bli_obj_row_stride( *x );
    inc_t     cs_x     = bli_obj_col_stride( *x );

    void*     buf_norm = bli_obj_buffer_at_off( *norm );

    FUNCPTR_T f;

    // Index into the type combination array to extract the correct
    // function pointer.
    f = ftypes[dt_x];

    // Invoke the function.
    f( diagoffx,
       diagx,
       uplox,
       m,
       n,
       buf_x, rs_x, cs_x,
       buf_norm );
}
예제 #7
0
// Object-level wrapper for the unblocked setid variant 1. It extracts
// the fields of x and the buffer of beta, then calls the implementation
// stored in ftypes[] for the datatype of x.
//
//   beta - scalar object; its buffer is obtained for the datatype of x
//          projected to the real domain
//   x    - matrix object
void bli_setid_unb_var1( obj_t*  beta,
                         obj_t*  x )
{
	// Datatype of x, and that datatype projected to the real domain.
	num_t     x_dt      = bli_obj_datatype( *x );
	num_t     x_dt_real = bli_obj_datatype_proj_to_real( *x );

	// Dimensions and diagonal offset of x.
	dim_t     dim_m     = bli_obj_length( *x );
	dim_t     dim_n     = bli_obj_width( *x );
	doff_t    x_diagoff = bli_obj_diag_offset( *x );

	// Buffer and strides of x.
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_rs      = bli_obj_row_stride( *x );
	inc_t     x_cs      = bli_obj_col_stride( *x );

	// Buffer of beta, requested with the real projection of x's datatype.
	void*     beta_buf  = bli_obj_buffer_for_1x1( x_dt_real, *beta );

	// Select the function pointer for the datatype of x.
	FUNCPTR_T setid_fp  = ftypes[x_dt];

	// Hand everything to the selected implementation.
	setid_fp( x_diagoff,
	          dim_m,
	          dim_n,
	          beta_buf,
	          x_buf, x_rs, x_cs );
}
예제 #8
0
// Object-level wrapper for the unblocked gemv variant 2. It extracts
// the fields of a, x and y, obtains the alpha and beta scalar buffers,
// and dispatches on the (a, x, y) datatype combination.
//
//   alpha - scalar object; buffer requested in the type union of a and x
//   a     - matrix object (supplies m and n)
//   x     - input vector object
//   beta  - scalar object; buffer requested in the datatype of y
//   y     - output vector object
//   cntl  - control tree node (not referenced in this function)
void bli_gemv_unb_var2( obj_t*  alpha,
                        obj_t*  a,
                        obj_t*  x,
                        obj_t*  beta,
                        obj_t*  y,
                        gemv_t* cntl )
{
	num_t     dt_a      = bli_obj_datatype( *a );
	num_t     dt_x      = bli_obj_datatype( *x );
	num_t     dt_y      = bli_obj_datatype( *y );

	// bli_obj_conjtrans_status() yields a trans_t (as in the other
	// wrappers); it was previously stored in a conj_t by mistake.
	trans_t   transa    = bli_obj_conjtrans_status( *a );
	conj_t    conjx     = bli_obj_conj_status( *x );

	dim_t     m         = bli_obj_length( *a );
	dim_t     n         = bli_obj_width( *a );

	void*     buf_a     = bli_obj_buffer_at_off( *a );
	inc_t     rs_a      = bli_obj_row_stride( *a );
	inc_t     cs_a      = bli_obj_col_stride( *a );

	void*     buf_x     = bli_obj_buffer_at_off( *x );
	inc_t     incx      = bli_obj_vector_inc( *x );

	void*     buf_y     = bli_obj_buffer_at_off( *y );
	inc_t     incy      = bli_obj_vector_inc( *y );

	num_t     dt_alpha;
	void*     buf_alpha;

	num_t     dt_beta;
	void*     buf_beta;

	FUNCPTR_T f;

	// The datatype of alpha MUST be the type union of a and x. This is to
	// prevent any unnecessary loss of information during computation.
	dt_alpha  = bli_datatype_union( dt_a, dt_x );
	buf_alpha = bli_obj_scalar_buffer( dt_alpha, *alpha );

	// The datatype of beta MUST be the same as the datatype of y.
	dt_beta   = dt_y;
	buf_beta  = bli_obj_scalar_buffer( dt_beta, *beta );

	// Index into the type combination array to extract the correct
	// function pointer.
	f = ftypes[dt_a][dt_x][dt_y];

	// Invoke the function.
	f( transa,
	   conjx,
	   m,
	   n,
	   buf_alpha,
	   buf_a, rs_a, cs_a,
	   buf_x, incx,
	   buf_beta,
	   buf_y, incy );
}
예제 #9
0
// Object-level wrapper around the gemmtrsm micro-kernels. It extracts
// the operand buffers, fills an auxinfo_t, and calls the lower or upper
// micro-kernel (from ftypes_l[] or ftypes_u[]) depending on whether a11
// is lower, indexed by the datatype of c11.
//
//   alpha          - scalar object; buffer requested in c11's datatype
//   a1x, a11       - A operands (k is the width of a1x)
//   bx1, b11       - B operands
//   c11            - output block (supplies datatype and strides)
void bli_gemmtrsm_ukr( obj_t*  alpha,
                       obj_t*  a1x,
                       obj_t*  a11,
                       obj_t*  bx1,
                       obj_t*  b11,
                       obj_t*  c11 )
{
	num_t     dt        = bli_obj_datatype( *c11 );
	dim_t     dim_k     = bli_obj_width( *a1x );

	// Operand buffers.
	void*     a1x_buf   = bli_obj_buffer_at_off( *a1x );
	void*     a11_buf   = bli_obj_buffer_at_off( *a11 );
	void*     bx1_buf   = bli_obj_buffer_at_off( *bx1 );
	void*     b11_buf   = bli_obj_buffer_at_off( *b11 );

	// Output block and its strides.
	void*     c11_buf   = bli_obj_buffer_at_off( *c11 );
	inc_t     c_rs      = bli_obj_row_stride( *c11 );
	inc_t     c_cs      = bli_obj_col_stride( *c11 );

	// Panel strides of the two panel operands.
	inc_t     a_ps      = bli_obj_panel_stride( *a1x );
	inc_t     b_ps      = bli_obj_panel_stride( *bx1 );

	void*     alpha_buf = bli_obj_buffer_for_1x1( dt, *alpha );

	FUNCPTR_T ukr_fp;

	auxinfo_t aux;

	// Populate the auxinfo_t struct in case the micro-kernel reads it.
	bli_auxinfo_set_next_b( bx1_buf, aux );

	bli_auxinfo_set_ps_a( a_ps, aux );
	bli_auxinfo_set_ps_b( b_ps, aux );

	// The "next a" address and the micro-kernel table both depend on
	// whether a11 is lower.
	if ( bli_obj_is_lower( *a11 ) )
	{
		bli_auxinfo_set_next_a( a1x_buf, aux );
		ukr_fp = ftypes_l[dt];
	}
	else
	{
		bli_auxinfo_set_next_a( a11_buf, aux );
		ukr_fp = ftypes_u[dt];
	}

	// Hand everything to the selected micro-kernel wrapper.
	ukr_fp( dim_k,
	        alpha_buf,
	        a1x_buf,
	        a11_buf,
	        bx1_buf,
	        b11_buf,
	        c11_buf, c_rs, c_cs,
	        &aux );
}
예제 #10
0
// Object-level wrapper for the dotxf kernel. It extracts the fields of
// a, x and y, obtains the alpha and beta scalar buffers, and dispatches
// on the (a, x, y) datatype combination.
//
//   alpha - scalar object; buffer requested in the type union of a and x
//   a     - matrix object
//   x     - vector object (its dimension is passed as m)
//   beta  - scalar object; buffer requested in the datatype of y
//   y     - vector object (its dimension is passed as b_n)
void bli_dotxf_kernel( obj_t*  alpha,
                       obj_t*  a,
                       obj_t*  x,
                       obj_t*  beta,
                       obj_t*  y )
{
	num_t     a_dt      = bli_obj_datatype( *a );
	num_t     x_dt      = bli_obj_datatype( *x );
	num_t     y_dt      = bli_obj_datatype( *y );

	// Dimensions come from the two vectors.
	dim_t     dim_m     = bli_obj_vector_dim( *x );
	dim_t     dim_b_n   = bli_obj_vector_dim( *y );

	// Conjugation status of a and x.
	conj_t    a_conj    = bli_obj_conj_status( *a );
	conj_t    x_conj    = bli_obj_conj_status( *x );

	// Matrix a.
	void*     a_buf     = bli_obj_buffer_at_off( *a );
	inc_t     a_rs      = bli_obj_row_stride( *a );
	inc_t     a_cs      = bli_obj_col_stride( *a );

	// Vector x.
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_inc     = bli_obj_vector_inc( *x );

	// Vector y.
	void*     y_buf     = bli_obj_buffer_at_off( *y );
	inc_t     y_inc     = bli_obj_vector_inc( *y );

	// Alpha MUST use the type union of a and x so that no information is
	// lost unnecessarily during the computation.
	num_t     alpha_dt  = bli_datatype_union( a_dt, x_dt );
	void*     alpha_buf = bli_obj_buffer_for_1x1( alpha_dt, *alpha );

	// Beta MUST use the same datatype as y.
	num_t     beta_dt   = y_dt;
	void*     beta_buf  = bli_obj_buffer_for_1x1( beta_dt, *beta );

	// Select the function pointer for this datatype combination.
	FUNCPTR_T dotxf_fp  = ftypes[a_dt][x_dt][y_dt];

	// Hand everything to the selected implementation.
	dotxf_fp( a_conj,
	          x_conj,
	          dim_m,
	          dim_b_n,
	          alpha_buf,
	          a_buf, a_rs, a_cs,
	          x_buf, x_inc,
	          beta_buf,
	          y_buf, y_inc );
}
예제 #11
0
// Object-level wrapper for the unblocked unpackm variant 1. It extracts
// the fields of the packed object p and the destination c, then calls
// the implementation stored in ftypes[] for the datatype of p.
//
//   p      - packed source object
//   c      - destination object (supplies m, n and the trans status)
//   cntx   - context, forwarded to the implementation
//   cntl   - control tree node (not referenced in this function)
//   thread - thread info (not referenced in this function)
void bli_unpackm_unb_var1
     (
       obj_t*  p,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl,
       thrinfo_t* thread
     )
{
	// Packed source p: datatype, structure, buffer and strides.
	num_t     dt        = bli_obj_datatype( *p );
	doff_t    p_diagoff = bli_obj_diag_offset( *p );
	uplo_t    p_uplo    = bli_obj_uplo( *p );
	void*     p_buf     = bli_obj_buffer_at_off( *p );
	inc_t     p_rs      = bli_obj_row_stride( *p );
	inc_t     p_cs      = bli_obj_col_stride( *p );

	// Destination c: trans status, dimensions, buffer and strides.
	trans_t   c_trans   = bli_obj_onlytrans_status( *c );
	dim_t     c_m       = bli_obj_length( *c );
	dim_t     c_n       = bli_obj_width( *c );
	void*     c_buf     = bli_obj_buffer_at_off( *c );
	inc_t     c_rs      = bli_obj_row_stride( *c );
	inc_t     c_cs      = bli_obj_col_stride( *c );

	// Select the function pointer for the datatype.
	FUNCPTR_T unpack_fp = ftypes[dt];

	// Hand everything to the selected implementation.
	unpack_fp( p_diagoff,
	           p_uplo,
	           c_trans,
	           c_m,
	           c_n,
	           p_buf, p_rs, p_cs,
	           c_buf, c_rs, c_cs,
	           cntx
	         );
}
예제 #12
0
// Object-level wrapper for the unblocked scalm variant 1. It folds
// alpha into the scalar attached to a local alias of x, then calls the
// implementation stored in ftypes[] with that combined scalar.
//
//   alpha - scalar object to apply to x
//   x     - matrix object to be scaled
//   cntx  - context (not referenced in this function)
void bli_scalm_unb_var1( obj_t*  alpha,
                         obj_t*  x,
                         cntx_t* cntx )
{
	num_t     dt_x      = bli_obj_datatype( *x );

	doff_t    diagoffx  = bli_obj_diag_offset( *x );
	// bli_obj_diag() yields a diag_t (as in the other wrappers); it was
	// previously stored in a uplo_t by mistake.
	diag_t    diagx     = bli_obj_diag( *x );
	uplo_t    uplox     = bli_obj_uplo( *x );

	dim_t     m         = bli_obj_length( *x );
	dim_t     n         = bli_obj_width( *x );

	void*     buf_x     = bli_obj_buffer_at_off( *x );
	inc_t     rs_x      = bli_obj_row_stride( *x );
	inc_t     cs_x      = bli_obj_col_stride( *x );

	void*     buf_alpha;

	obj_t     x_local;

	FUNCPTR_T f;

	// Alias x to x_local so we can apply alpha if it is non-unit.
	bli_obj_alias_to( *x, x_local );

	// If alpha is non-unit, apply it to the scalar attached to x_local.
	if ( !bli_obj_equals( alpha, &BLIS_ONE ) )
	{
		bli_obj_scalar_apply_scalar( alpha, &x_local );
	}

	// Grab the address of the internal scalar buffer for the scalar
	// attached to x_local. The alias must be queried (not the caller's
	// x) because alpha was applied to the alias only. The result goes
	// into buf_alpha, which is the variable passed to the implementation
	// below.
	buf_alpha = bli_obj_internal_scalar_buffer( x_local );

	// Index into the type combination array to extract the correct
	// function pointer.
	// NOTE: We use dt_x for both alpha and x because alpha was obtained
	// from the attached scalar of x, which is guaranteed to be of the
	// same datatype as x.
	f = ftypes[dt_x][dt_x];

	// Invoke the function.
	// NOTE: We unconditionally pass in BLIS_NO_CONJUGATE for alpha
	// because it would have already been conjugated by the front-end.
	f( BLIS_NO_CONJUGATE,
	   diagoffx,
	   diagx,
	   uplox,
	   m,
	   n,
	   buf_alpha,
	   buf_x, rs_x, cs_x );
}
예제 #13
0
// Object-level wrapper for the unblocked ger variant 2. It extracts the
// fields of x, y and a, obtains the alpha scalar buffer, and calls the
// implementation stored in ftypes[] for the datatype of a.
//
//   alpha - scalar object; buffer requested in the type union of x and y
//   x, y  - vector objects
//   a     - matrix object (supplies m and n)
//   cntx  - context, forwarded to the implementation
//   cntl  - control tree node (not referenced in this function)
void bli_ger_unb_var2( obj_t*  alpha,
                       obj_t*  x,
                       obj_t*  y,
                       obj_t*  a,
                       cntx_t* cntx,
                       ger_t*  cntl )
{
	num_t     x_dt      = bli_obj_datatype( *x );
	num_t     y_dt      = bli_obj_datatype( *y );
	num_t     a_dt      = bli_obj_datatype( *a );

	// Dimensions are taken from a.
	dim_t     dim_m     = bli_obj_length( *a );
	dim_t     dim_n     = bli_obj_width( *a );

	// Vector x.
	conj_t    x_conj    = bli_obj_conj_status( *x );
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_inc     = bli_obj_vector_inc( *x );

	// Vector y.
	conj_t    y_conj    = bli_obj_conj_status( *y );
	void*     y_buf     = bli_obj_buffer_at_off( *y );
	inc_t     y_inc     = bli_obj_vector_inc( *y );

	// Matrix a.
	void*     a_buf     = bli_obj_buffer_at_off( *a );
	inc_t     a_rs      = bli_obj_row_stride( *a );
	inc_t     a_cs      = bli_obj_col_stride( *a );

	// Alpha MUST use the type union of x and y so that no information is
	// lost unnecessarily during the computation.
	num_t     alpha_dt  = bli_datatype_union( x_dt, y_dt );
	void*     alpha_buf = bli_obj_buffer_for_1x1( alpha_dt, *alpha );

	// Select the function pointer; the lookup uses the datatype of a
	// only.
	FUNCPTR_T ger_fp    = ftypes[a_dt];

	// Hand everything to the selected implementation.
	ger_fp( x_conj,
	        y_conj,
	        dim_m,
	        dim_n,
	        alpha_buf,
	        x_buf, x_inc,
	        y_buf, y_inc,
	        a_buf, a_rs, a_cs,
	        cntx );
}
예제 #14
0
// Object-level wrapper for the unblocked her2 variant 1. It extracts
// the fields of x, y and c, obtains the alpha scalar buffer, and
// dispatches on the (x, y, c) datatype combination.
//
//   conjh      - conjugation parameter, forwarded to the implementation
//   alpha      - scalar object; buffer requested in the type union of
//                x and y
//   alpha_conj - not referenced in this function
//   x, y       - vector objects
//   c          - matrix object (supplies uplo, m and strides)
//   cntl       - control tree node (not referenced in this function)
void bli_her2_unb_var1( conj_t  conjh,
                        obj_t*  alpha,
                        obj_t*  alpha_conj,
                        obj_t*  x,
                        obj_t*  y,
                        obj_t*  c,
                        her2_t* cntl )
{
	num_t     x_dt      = bli_obj_datatype( *x );
	num_t     y_dt      = bli_obj_datatype( *y );
	num_t     c_dt      = bli_obj_datatype( *c );

	// Matrix c: uplo, dimension, buffer and strides.
	uplo_t    c_uplo    = bli_obj_uplo( *c );
	dim_t     dim_m     = bli_obj_length( *c );
	void*     c_buf     = bli_obj_buffer_at_off( *c );
	inc_t     c_rs      = bli_obj_row_stride( *c );
	inc_t     c_cs      = bli_obj_col_stride( *c );

	// Vector x.
	conj_t    x_conj    = bli_obj_conj_status( *x );
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_inc     = bli_obj_vector_inc( *x );

	// Vector y.
	conj_t    y_conj    = bli_obj_conj_status( *y );
	void*     y_buf     = bli_obj_buffer_at_off( *y );
	inc_t     y_inc     = bli_obj_vector_inc( *y );

	// Alpha MUST use the type union of the datatypes of x and y.
	num_t     alpha_dt  = bli_datatype_union( x_dt, y_dt );
	void*     alpha_buf = bli_obj_buffer_for_1x1( alpha_dt, *alpha );

	// Select the function pointer for this datatype combination.
	FUNCPTR_T her2_fp   = ftypes[x_dt][y_dt][c_dt];

	// Hand everything to the selected implementation.
	her2_fp( c_uplo,
	         x_conj,
	         y_conj,
	         conjh,
	         dim_m,
	         alpha_buf,
	         x_buf, x_inc,
	         y_buf, y_inc,
	         c_buf, c_rs, c_cs );
}
예제 #15
0
// Object-level wrapper for the unblocked addm variant 1. It extracts
// the fields of x and y and dispatches on the (x, y) datatype
// combination.
//
//   x    - source matrix object
//   y    - destination matrix object (also supplies m and n)
//   cntx - context (not referenced in this function)
void bli_addm_unb_var1( obj_t*  x,
                        obj_t*  y,
                        cntx_t* cntx )
{
	num_t     x_dt      = bli_obj_datatype( *x );
	num_t     y_dt      = bli_obj_datatype( *y );

	// Dimensions are taken from y.
	dim_t     dim_m     = bli_obj_length( *y );
	dim_t     dim_n     = bli_obj_width( *y );

	// Structural properties and conj/trans status of x.
	doff_t    x_diagoff = bli_obj_diag_offset( *x );
	diag_t    x_diag    = bli_obj_diag( *x );
	uplo_t    x_uplo    = bli_obj_uplo( *x );
	trans_t   x_trans   = bli_obj_conjtrans_status( *x );

	// Buffer and strides of x.
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_rs      = bli_obj_row_stride( *x );
	inc_t     x_cs      = bli_obj_col_stride( *x );

	// Buffer and strides of y.
	void*     y_buf     = bli_obj_buffer_at_off( *y );
	inc_t     y_rs      = bli_obj_row_stride( *y );
	inc_t     y_cs      = bli_obj_col_stride( *y );

	// Select the function pointer for this datatype combination.
	FUNCPTR_T addm_fp   = ftypes[x_dt][y_dt];

	// Hand everything to the selected implementation.
	addm_fp( x_diagoff,
	         x_diag,
	         x_uplo,
	         x_trans,
	         dim_m,
	         dim_n,
	         x_buf, x_rs, x_cs,
	         y_buf, y_rs, y_cs );
}
예제 #16
0
// Writes a human-readable dump of an object's fields to stdout: the
// label, dimensions, offsets, buffer address, element size, strides,
// padded dimensions, and the individual properties encoded in the info
// field.
//
//   label - heading printed before the dump
//   obj   - object to describe
//
// When error checking is enabled, bli_obj_print_check() runs first.
void bli_obj_print( char* label, obj_t* obj )
{
	FILE* out = stdout;

	if ( bli_error_checking_is_enabled() )
	{
		bli_obj_print_check( label, obj );
	}

	// Heading.
	fputs( "\n", out );
	fprintf( out, "%s\n", label );
	fputs( "\n", out );

	// Dimensions.
	fprintf( out, " m x n           %lu x %lu\n",
	         ( unsigned long )bli_obj_length( *obj ),
	         ( unsigned long )bli_obj_width( *obj ) );
	fputs( "\n", out );

	// Offsets.
	fprintf( out, " offm, offn      %lu, %lu\n",
	         ( unsigned long )bli_obj_row_off( *obj ),
	         ( unsigned long )bli_obj_col_off( *obj ) );
	fprintf( out, " diagoff         %ld\n",
	         ( long )bli_obj_diag_offset( *obj ) );
	fputs( "\n", out );

	// Buffer and storage layout.
	fprintf( out, " buf             %p\n",
	         ( void* )bli_obj_buffer( *obj ) );
	fprintf( out, " elem size       %lu\n",
	         ( unsigned long )bli_obj_elem_size( *obj ) );
	fprintf( out, " rs, cs          %ld, %ld\n",
	         ( long )bli_obj_row_stride( *obj ),
	         ( long )bli_obj_col_stride( *obj ) );
	fprintf( out, " is              %ld\n",
	         ( long )bli_obj_imag_stride( *obj ) );
	fprintf( out, " m_padded        %lu\n",
	         ( unsigned long )bli_obj_padded_length( *obj ) );
	fprintf( out, " n_padded        %lu\n",
	         ( unsigned long )bli_obj_padded_width( *obj ) );
	fprintf( out, " ps              %lu\n",
	         ( unsigned long )bli_obj_panel_stride( *obj ) );
	fputs( "\n", out );

	// Raw info field followed by the properties decoded from it. For
	// the shifted entries the cast is applied first, then the shift.
	fprintf( out, " info            %lX\n",
	         ( unsigned long )obj->info );
	fprintf( out, " - is complex    %lu\n",
	         ( unsigned long )bli_obj_is_complex( *obj ) );
	fprintf( out, " - is d. prec    %lu\n",
	         ( unsigned long )bli_obj_is_double_precision( *obj ) );
	fprintf( out, " - datatype      %lu\n",
	         ( unsigned long )bli_obj_datatype( *obj ) );
	fprintf( out, " - target dt     %lu\n",
	         ( unsigned long )bli_obj_target_datatype( *obj ) );
	fprintf( out, " - exec dt       %lu\n",
	         ( unsigned long )bli_obj_execution_datatype( *obj ) );
	fprintf( out, " - has trans     %lu\n",
	         ( unsigned long )bli_obj_has_trans( *obj ) );
	fprintf( out, " - has conj      %lu\n",
	         ( unsigned long )bli_obj_has_conj( *obj ) );
	fprintf( out, " - unit diag?    %lu\n",
	         ( unsigned long )bli_obj_has_unit_diag( *obj ) );
	fprintf( out, " - struc type    %lu\n",
	         ( ( unsigned long )bli_obj_struc( *obj ) ) >> BLIS_STRUC_SHIFT );
	fprintf( out, " - uplo type     %lu\n",
	         ( ( unsigned long )bli_obj_uplo( *obj ) ) >> BLIS_UPLO_SHIFT );
	fprintf( out, "   - is upper    %lu\n",
	         ( unsigned long )bli_obj_is_upper( *obj ) );
	fprintf( out, "   - is lower    %lu\n",
	         ( unsigned long )bli_obj_is_lower( *obj ) );
	fprintf( out, "   - is dense    %lu\n",
	         ( unsigned long )bli_obj_is_dense( *obj ) );
	fprintf( out, " - pack schema   %lu\n",
	         ( ( unsigned long )bli_obj_pack_schema( *obj ) ) >> BLIS_PACK_SCHEMA_SHIFT );
	fprintf( out, " - packinv diag? %lu\n",
	         ( unsigned long )bli_obj_has_inverted_diag( *obj ) );
	fprintf( out, " - pack ordifup  %lu\n",
	         ( unsigned long )bli_obj_is_pack_rev_if_upper( *obj ) );
	fprintf( out, " - pack ordiflo  %lu\n",
	         ( unsigned long )bli_obj_is_pack_rev_if_lower( *obj ) );
	fprintf( out, " - packbuf type  %lu\n",
	         ( ( unsigned long )bli_obj_pack_buffer_type( *obj ) ) >> BLIS_PACK_BUFFER_SHIFT );
	fputs( "\n", out );
}
예제 #17
0
// Prints the matrix object x to the given stream.
//
//   file   - destination stream
//   s1     - leading string; forwarded to the implementation, or printed
//            on its own line when x is a constant
//   x      - object to print
//   format - format string, forwarded to the implementation
//   s2     - trailing string, forwarded to the implementation
//
// When error checking is enabled, bli_fprintm_check() runs first. If x
// has datatype BLIS_CONSTANT, its value is printed once per
// representation (float, double, scomplex, dcomplex, int) and the
// function returns without consulting ftypes[]. Otherwise the
// implementation stored in ftypes[] for the datatype of x is called.
void bli_fprintm( FILE* file, char* s1, obj_t* x, char* format, char* s2 )
{
	num_t     dt_x      = bli_obj_datatype( *x );

	dim_t     m         = bli_obj_length( *x );
	dim_t     n         = bli_obj_width( *x );

	inc_t     rs_x      = bli_obj_row_stride( *x );
	inc_t     cs_x      = bli_obj_col_stride( *x );
	void*     buf_x     = bli_obj_buffer_at_off( *x );

	FUNCPTR_T f;

	if ( bli_error_checking_is_enabled() )
		bli_fprintm_check( file, s1, x, format, s2 );

	// Handle constants up front.
	if ( dt_x == BLIS_CONSTANT )
	{
		float*    sp = bli_obj_buffer_for_const( BLIS_FLOAT,    *x );
		double*   dp = bli_obj_buffer_for_const( BLIS_DOUBLE,   *x );
		scomplex* cp = bli_obj_buffer_for_const( BLIS_SCOMPLEX, *x );
		dcomplex* zp = bli_obj_buffer_for_const( BLIS_DCOMPLEX, *x );
		gint_t*   ip = bli_obj_buffer_for_const( BLIS_INT,      *x );

		fprintf( file, "%s\n", s1 );
		fprintf( file, " float:     %9.2e\n",         bli_sreal( *sp ) );
		fprintf( file, " double:    %9.2e\n",         bli_dreal( *dp ) );
		fprintf( file, " scomplex:  %9.2e + %9.2e\n", bli_creal( *cp ), bli_cimag( *cp ) );
		fprintf( file, " dcomplex:  %9.2e + %9.2e\n", bli_zreal( *zp ), bli_zimag( *zp ) );
		// %ld requires a long argument; cast explicitly because gint_t
		// is a project typedef that need not be long.
		fprintf( file, " int:       %ld\n",           ( long int )*ip );
		fprintf( file, "\n" );
		return;
	}

	// Index into the type combination array to extract the correct
	// function pointer.
	f = ftypes[dt_x];

	// Invoke the function.
	f( file,
	   s1,
	   m,
	   n,
	   buf_x, rs_x, cs_x,
	   format,
	   s2 );
}
예제 #18
0
// Object-level wrapper for the trmv variant "unf_var2". It extracts the
// fields of a and x, obtains the alpha scalar buffer, and dispatches on
// the (a, x) datatype combination.
//
//   alpha - scalar object; buffer requested in the type union of a and x
//   a     - matrix object (supplies uplo, trans, diag and m)
//   x     - vector object
//   cntx  - context (not referenced in this function)
//   cntl  - control tree node (not referenced in this function)
void bli_trmv_unf_var2( obj_t*  alpha,
                        obj_t*  a,
                        obj_t*  x,
                        cntx_t* cntx,
                        trmv_t* cntl )
{
	num_t     a_dt      = bli_obj_datatype( *a );
	num_t     x_dt      = bli_obj_datatype( *x );

	// Matrix a: properties, dimension, buffer and strides.
	uplo_t    a_uplo    = bli_obj_uplo( *a );
	trans_t   a_trans   = bli_obj_conjtrans_status( *a );
	diag_t    a_diag    = bli_obj_diag( *a );
	dim_t     dim_m     = bli_obj_length( *a );
	void*     a_buf     = bli_obj_buffer_at_off( *a );
	inc_t     a_rs      = bli_obj_row_stride( *a );
	inc_t     a_cs      = bli_obj_col_stride( *a );

	// Vector x.
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_inc     = bli_obj_vector_inc( *x );

	// Alpha MUST use the type union of a and x so that no information is
	// lost unnecessarily during the computation.
	num_t     alpha_dt  = bli_datatype_union( a_dt, x_dt );
	void*     alpha_buf = bli_obj_buffer_for_1x1( alpha_dt, *alpha );

	// Select the function pointer for this datatype combination.
	FUNCPTR_T trmv_fp   = ftypes[a_dt][x_dt];

	// Hand everything to the selected implementation.
	trmv_fp( a_uplo,
	         a_trans,
	         a_diag,
	         dim_m,
	         alpha_buf,
	         a_buf, a_rs, a_cs,
	         x_buf, x_inc );
}
예제 #19
0
// Object-level wrapper for the unblocked scald variant 1. It extracts
// the fields of x, resolves the beta scalar, and dispatches on the
// (beta, x) datatype combination.
//
//   beta - scalar object; resolved with bli_set_scalar_dt_buffer()
//   x    - matrix object
void bli_scald_unb_var1( obj_t*  beta,
                         obj_t*  x )
{
	num_t     x_dt      = bli_obj_datatype( *x );

	// Conjugation status of beta.
	conj_t    beta_conj = bli_obj_conj_status( *beta );

	// Dimensions and diagonal offset of x.
	dim_t     dim_m     = bli_obj_length( *x );
	dim_t     dim_n     = bli_obj_width( *x );
	doff_t    x_diagoff = bli_obj_diag_offset( *x );

	// Buffer and strides of x.
	void*     x_buf     = bli_obj_buffer_at_off( *x );
	inc_t     x_rs      = bli_obj_row_stride( *x );
	inc_t     x_cs      = bli_obj_col_stride( *x );

	num_t     beta_dt;
	void*     beta_buf;

	FUNCPTR_T scald_fp;

	// When beta is a scalar constant, the datatype of x selects the
	// address of the matching constant value; otherwise beta's own
	// datatype is used and its buffer is taken at the beta offset.
	bli_set_scalar_dt_buffer( beta, x_dt, beta_dt, beta_buf );

	// Select the function pointer for this datatype combination.
	scald_fp = ftypes[beta_dt][x_dt];

	// Hand everything to the selected implementation.
	scald_fp( beta_conj,
	          x_diagoff,
	          dim_m,
	          dim_n,
	          beta_buf,
	          x_buf, x_rs, x_cs );
}
예제 #20
0
// Object-level wrapper for the unblocked mksymm variant 1. It extracts
// the uplo field, dimension, buffer and strides of a, then calls the
// implementation stored in ftypes[] for the datatype of a.
//
//   a - matrix object
void bli_mksymm_unb_var1( obj_t* a )
{
	// Buffer and strides of a.
	void*     a_buf     = bli_obj_buffer_at_off( *a );
	inc_t     a_rs      = bli_obj_row_stride( *a );
	inc_t     a_cs      = bli_obj_col_stride( *a );

	// Dimension (the length of a) and uplo field.
	dim_t     dim_m     = bli_obj_length( *a );
	uplo_t    a_uplo    = bli_obj_uplo( *a );

	// Select the function pointer for the datatype of a.
	num_t     a_dt      = bli_obj_datatype( *a );
	FUNCPTR_T mksymm_fp = ftypes[a_dt];

	// Hand everything to the selected implementation.
	mksymm_fp( a_uplo,
	           dim_m,
	           a_buf, a_rs, a_cs );
}
예제 #21
0
// Object-level wrapper for the trsm "rl" kernel variant 2 (threaded
// form). It extracts the fields of a, b and c, the two internal scalar
// buffers, and the micro-kernel addresses from the control tree, then
// calls the implementation stored in ftypes[] for C's execution
// datatype.
//
//   a, b, c - matrix operands
//   cntl    - control tree node holding the micro-kernel func_t objects
//   thread  - thread info, forwarded to the implementation
void bli_trsm_rl_ker_var2( obj_t*  a,
                           obj_t*  b,
                           obj_t*  c,
                           trsm_t* cntl,
                           trsm_thrinfo_t* thread )
{
	num_t     dt            = bli_obj_execution_datatype( *c );

	// Problem dimensions: m and n from C, k from the width of A.
	dim_t     dim_m         = bli_obj_length( *c );
	dim_t     dim_n         = bli_obj_width( *c );
	dim_t     dim_k         = bli_obj_width( *a );

	// Diagonal offset of B.
	doff_t    b_diagoff     = bli_obj_diag_offset( *b );

	// Operand A: schema, buffer, column stride, panel dim and stride.
	pack_t    a_schema      = bli_obj_pack_schema( *a );
	void*     a_buf         = bli_obj_buffer_at_off( *a );
	inc_t     a_cs          = bli_obj_col_stride( *a );
	dim_t     a_pd          = bli_obj_panel_dim( *a );
	inc_t     a_ps          = bli_obj_panel_stride( *a );

	// Operand B: schema, buffer, row stride, panel dim and stride.
	pack_t    b_schema      = bli_obj_pack_schema( *b );
	void*     b_buf         = bli_obj_buffer_at_off( *b );
	inc_t     b_rs          = bli_obj_row_stride( *b );
	dim_t     b_pd          = bli_obj_panel_dim( *b );
	inc_t     b_ps          = bli_obj_panel_stride( *b );

	// Operand C: buffer and strides.
	void*     c_buf         = bli_obj_buffer_at_off( *c );
	inc_t     c_rs          = bli_obj_row_stride( *c );
	inc_t     c_cs          = bli_obj_col_stride( *c );

	// Internal scalar attached to A. It serves as the alpha scalar of
	// the gemmtrsm subproblems (ie: the scalar that would be applied to
	// the packed copy of A before the trsm subproblem updates it). It
	// may be unit, for instance if it was applied during packing.
	void*     alpha1_buf    = bli_obj_internal_scalar_buffer( *a );

	// Internal scalar attached to C. It serves as the "beta" scalar of
	// the gemm-only subproblems, which correspond to micro-panels that
	// do not intersect the diagonal. A separate scalar is needed because
	// the alpha attached to B may have been reset if it was applied
	// during packing.
	void*     alpha2_buf    = bli_obj_internal_scalar_buffer( *c );

	// Select the function pointer for the execution datatype.
	FUNCPTR_T ker_fp        = ftypes[dt];

	// Pull the func_t objects holding the gemmtrsm and gemm micro-kernel
	// addresses out of the control tree node, then look up the addresses
	// that correspond to the current datatype.
	func_t*   gemmtrsm_ukrs = cntl_gemmtrsm_u_ukrs( cntl );
	func_t*   gemm_ukrs     = cntl_gemm_ukrs( cntl );
	void*     gemmtrsm_ukr  = bli_func_obj_query( dt, gemmtrsm_ukrs );
	void*     gemm_ukr      = bli_func_obj_query( dt, gemm_ukrs );

	// Hand everything to the selected implementation.
	ker_fp( b_diagoff,
	        a_schema,
	        b_schema,
	        dim_m,
	        dim_n,
	        dim_k,
	        alpha1_buf,
	        a_buf, a_cs, a_pd, a_ps,
	        b_buf, b_rs, b_pd, b_ps,
	        alpha2_buf,
	        c_buf, c_rs, c_cs,
	        gemmtrsm_ukr,
	        gemm_ukr,
	        thread );
}
예제 #22
0
파일: test_trsm.c 프로젝트: flame/blis
// Benchmark driver for trsm.
//
// For every problem size p in [p_begin, p_max] (step p_inc) the driver
// creates a triangular matrix A and a general matrix C, runs the solve
// n_repeats times, and prints one MATLAB-style assignment line holding
// m, n, and the computed GFLOPS rate.
//
// The datatype (DT), induced method (IND), size range (P_BEGIN, P_MAX,
// P_INC) and the output labels (THR_STR, STR) are macros that are not
// defined in this function; they are expected to be supplied elsewhere
// (e.g. by the build). When BLIS is defined the solve is performed by
// bli_trsm(); otherwise the Fortran-77 BLAS routine matching the datatype
// (strsm_/dtrsm_/ctrsm_/ztrsm_) is called instead.
//
// argc/argv are not used. Always returns 0.
int main( int argc, char** argv )
{
	obj_t    a, c;
	obj_t    c_save;
	obj_t    alpha;
	dim_t    m, n;
	dim_t    p;
	dim_t    p_begin, p_max, p_inc;
	int      m_input, n_input;
	ind_t    ind;
	num_t    dt;
	char     dt_ch;
	int      r, n_repeats;
	side_t   side;
	uplo_t   uploa;
	trans_t  transa;
	diag_t   diaga;
	f77_char f77_side;
	f77_char f77_uploa;
	f77_char f77_transa;
	f77_char f77_diaga;

	double   dtime;
	double   dtime_save;
	double   gflops;

	//bli_init();

	//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );

	// Number of timed repetitions per problem size.
	n_repeats = 3;

	dt      = DT;

	ind     = IND;

	p_begin = P_BEGIN;
	p_max   = P_MAX;
	p_inc   = P_INC;

	// A negative *_input value means "derive the dimension from p" (see
	// the top of the main loop, where p is divided by its absolute value);
	// a non-negative value is used as the dimension directly.
	m_input = -1;
	n_input = -1;


	// Suppress compiler warnings about unused variable 'ind'.
	( void )ind;

#if 0

	cntx_t* cntx;

	ind_t ind_mod = ind;

	// A hack to use 3m1 as 1mpb (with 1m as 1mbp).
	if ( ind == BLIS_3M1 ) ind_mod = BLIS_1M;

	// Initialize a context for the current induced method and datatype.
	cntx = bli_gks_query_ind_cntx( ind_mod, dt );

	// Set k to the kc blocksize for the current datatype.
	k_input = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx );

#elif 1

	//k_input = 256;

#endif

	// Choose the char corresponding to the requested datatype.
	if      ( bli_is_float( dt ) )    dt_ch = 's';
	else if ( bli_is_double( dt ) )   dt_ch = 'd';
	else if ( bli_is_scomplex( dt ) ) dt_ch = 'c';
	else                              dt_ch = 'z';

	// Operation parameters are selected at compile time by flipping the
	// "#if 0" toggles below; as written: right side, upper triangular.
#if 0
	side   = BLIS_LEFT;
#else
	side   = BLIS_RIGHT;
#endif
#if 0
	uploa  = BLIS_LOWER;
#else
	uploa  = BLIS_UPPER;
#endif
	transa = BLIS_NO_TRANSPOSE;
	diaga  = BLIS_NONUNIT_DIAG;

	// Translate the parameters into the character codes that are passed
	// to the Fortran-77 BLAS routines in the non-BLIS branch below.
	bli_param_map_blis_to_netlib_side( side, &f77_side );
	bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa );
	bli_param_map_blis_to_netlib_trans( transa, &f77_transa );
	bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga );

	// Begin with initializing the last entry to zero so that
	// matlab allocates space for the entire array once up-front.
	// The empty loop only advances p to the last value that the main loop
	// below will visit; the printf that follows uses it to compute the
	// index of that final row.
	for ( p = p_begin; p + p_inc <= p_max; p += p_inc ) ;

	printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR );
	printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n",
	        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
	        ( unsigned long )0,
	        ( unsigned long )0, 0.0 );


	for ( p = p_begin; p <= p_max; p += p_inc )
	{

		// Derive m and n from p (negative input) or use the fixed value.
		if ( m_input < 0 ) m = p / ( dim_t )abs(m_input);
		else               m =     ( dim_t )    m_input;
		if ( n_input < 0 ) n = p / ( dim_t )abs(n_input);
		else               n =     ( dim_t )    n_input;

		bli_obj_create( dt, 1, 1, 0, 0, &alpha );

		// A is m x m for a left-side solve and n x n otherwise; C (and
		// its backup copy c_save) is always m x n.
		if ( bli_is_left( side ) )
			bli_obj_create( dt, m, m, 0, 0, &a );
        else
			bli_obj_create( dt, n, n, 0, 0, &a );
		bli_obj_create( dt, m, n, 0, 0, &c );
		//bli_obj_create( dt, m, n, n, 1, &c );
		bli_obj_create( dt, m, n, 0, 0, &c_save );

		bli_randm( &a );
		bli_randm( &c );

		// Mark A as triangular with the chosen uplo/trans/diag properties.
		bli_obj_set_struc( BLIS_TRIANGULAR, &a );
		bli_obj_set_uplo( uploa, &a );
		bli_obj_set_conjtrans( transa, &a );
		bli_obj_set_diag( diaga, &a );

		// NOTE(review): A is randomized a second time here, now that its
		// structure has been set, and then passed to bli_mktrim();
		// presumably this makes the stored data explicitly triangular --
		// confirm against the BLIS documentation.
		bli_randm( &a );
		bli_mktrim( &a );

		// Load the diagonal of A to make it more likely to be invertible.
		bli_shiftd( &BLIS_TWO, &a );

		bli_setsc(  (2.0/1.0), 0.0, &alpha );

		// Keep a pristine copy of C so every repetition starts from the
		// same right-hand side.
		bli_copym( &c, &c_save );
	
#if 0 //def BLIS
		bli_ind_disable_all_dt( dt );
		bli_ind_enable_dt( ind, dt );
#endif

		// Start from the largest possible value; it is updated after each
		// repetition via bli_clock_min_diff().
		dtime_save = DBL_MAX;

		for ( r = 0; r < n_repeats; ++r )
		{
			// Restore C, since the solve overwrites it.
			bli_copym( &c_save, &c );

			dtime = bli_clock();

#ifdef PRINT
			bli_printm( "a", &a, "%4.1f", "" );
			bli_printm( "c", &c, "%4.1f", "" );
#endif

#ifdef BLIS

			bli_trsm( side,
			          &alpha,
			          &a,
			          &c );

#else

			// Reference path: call the BLAS routine that matches dt.
			if ( bli_is_float( dt ) )
			{
				f77_int   mm     = bli_obj_length( &c );
				f77_int   kk     = bli_obj_width( &c );
				f77_int   lda    = bli_obj_col_stride( &a );
				f77_int   ldc    = bli_obj_col_stride( &c );
				float*    alphap = ( float* )bli_obj_buffer( &alpha );
				float*    ap     = ( float* )bli_obj_buffer( &a );
				float*    cp     = ( float* )bli_obj_buffer( &c );

				strsm_( &f77_side,
						&f77_uploa,
						&f77_transa,
						&f77_diaga,
						&mm,
						&kk,
						alphap,
						ap, &lda,
						cp, &ldc );
			}
			else if ( bli_is_double( dt ) )
			{
				f77_int   mm     = bli_obj_length( &c );
				f77_int   kk     = bli_obj_width( &c );
				f77_int   lda    = bli_obj_col_stride( &a );
				f77_int   ldc    = bli_obj_col_stride( &c );
				double*   alphap = ( double* )bli_obj_buffer( &alpha );
				double*   ap     = ( double* )bli_obj_buffer( &a );
				double*   cp     = ( double* )bli_obj_buffer( &c );

				dtrsm_( &f77_side,
						&f77_uploa,
						&f77_transa,
						&f77_diaga,
						&mm,
						&kk,
						alphap,
						ap, &lda,
						cp, &ldc );
			}
			else if ( bli_is_scomplex( dt ) )
			{
				f77_int   mm     = bli_obj_length( &c );
				f77_int   kk     = bli_obj_width( &c );
				f77_int   lda    = bli_obj_col_stride( &a );
				f77_int   ldc    = bli_obj_col_stride( &c );
				// When EIGEN is defined the complex buffers are passed as
				// pointers to their real component type instead.
#ifdef EIGEN
				float*    alphap = ( float*    )bli_obj_buffer( &alpha );
				float*    ap     = ( float*    )bli_obj_buffer( &a );
				float*    cp     = ( float*    )bli_obj_buffer( &c );
#else
				scomplex* alphap = ( scomplex* )bli_obj_buffer( &alpha );
				scomplex* ap     = ( scomplex* )bli_obj_buffer( &a );
				scomplex* cp     = ( scomplex* )bli_obj_buffer( &c );
#endif

				ctrsm_( &f77_side,
						&f77_uploa,
						&f77_transa,
						&f77_diaga,
						&mm,
						&kk,
						alphap,
						ap, &lda,
						cp, &ldc );
			}
			else if ( bli_is_dcomplex( dt ) )
			{
				f77_int   mm     = bli_obj_length( &c );
				f77_int   kk     = bli_obj_width( &c );
				f77_int   lda    = bli_obj_col_stride( &a );
				f77_int   ldc    = bli_obj_col_stride( &c );
#ifdef EIGEN
				double*   alphap = ( double*   )bli_obj_buffer( &alpha );
				double*   ap     = ( double*   )bli_obj_buffer( &a );
				double*   cp     = ( double*   )bli_obj_buffer( &c );
#else
				dcomplex* alphap = ( dcomplex* )bli_obj_buffer( &alpha );
				dcomplex* ap     = ( dcomplex* )bli_obj_buffer( &a );
				dcomplex* cp     = ( dcomplex* )bli_obj_buffer( &c );
#endif

				ztrsm_( &f77_side,
						&f77_uploa,
						&f77_transa,
						&f77_diaga,
						&mm,
						&kk,
						alphap,
						ap, &lda,
						cp, &ldc );
			}
#endif

			// In PRINT builds the program prints the result and exits
			// after the first repetition of the first problem size.
#ifdef PRINT
			bli_printm( "c after", &c, "%4.1f", "" );
			exit(1);
#endif

			// NOTE(review): presumably keeps the smaller of the previous
			// best time and the time elapsed since dtime -- confirm
			// against the bli_clock_min_diff() documentation.
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		// Operation count used for the rate: m*m*n (left) or m*n*n
		// (right), scaled by 1e9 to report GFLOPS.
		if ( bli_is_left( side ) )
			gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 );
		else
			gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 );

		// Complex datatypes are credited with four times the count.
		if ( bli_is_complex( dt ) ) gflops *= 4.0;

		printf( "data_%s_%ctrsm_%s", THR_STR, dt_ch, STR );
		printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n",
		        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
		        ( unsigned long )m,
		        ( unsigned long )n, gflops );

		bli_obj_free( &alpha );

		bli_obj_free( &a );
		bli_obj_free( &c );
		bli_obj_free( &c_save );
	}

	//bli_finalize();

	return 0;
}
예제 #23
0
// Object-based front end for the trsm "ru" macro-kernel (variant 2).
//
// Unpacks dimensions, buffers, strides and scalars from the objects a, b
// and c, looks up the gemmtrsm and gemm micro-kernel addresses in the
// control tree node cntl, and forwards everything to the datatype-specific
// implementation selected from ftypes[] by the execution datatype of c.
void bli_trsm_ru_ker_var2( obj_t*  a,
                           obj_t*  b,
                           obj_t*  c,
                           trsm_t* cntl )
{
	// The execution datatype of C selects both the macro-kernel instance
	// and the micro-kernels.
	num_t     exec_dt      = bli_obj_execution_datatype( *c );

	// Macro-kernel instance for that datatype.
	FUNCPTR_T kernel       = ftypes[exec_dt];

	// Micro-kernel tables stored in the control tree node, and the
	// addresses they hold for the execution datatype.
	// NOTE(review): the "_l_" gemmtrsm table is queried although this is
	// the "ru" variant; this looks deliberate -- confirm against the
	// control tree definition before changing it.
	func_t*   trsm_ukr_tab = cntl_gemmtrsm_l_ukrs( cntl );
	func_t*   gemm_ukr_tab = cntl_gemm_ukrs( cntl );
	void*     trsm_ukr_fp  = bli_func_obj_query( exec_dt, trsm_ukr_tab );
	void*     gemm_ukr_fp  = bli_func_obj_query( exec_dt, gemm_ukr_tab );

	// Problem dimensions: C is m x n, and k is the width of A.
	dim_t     m_dim        = bli_obj_length( *c );
	dim_t     n_dim        = bli_obj_width( *c );
	dim_t     k_dim        = bli_obj_width( *a );

	// Diagonal offset of B.
	doff_t    doff_b       = bli_obj_diag_offset( *b );

	// Matrix A: buffer, column stride, panel dimension, panel stride.
	void*     a_ptr        = bli_obj_buffer_at_off( *a );
	inc_t     a_cs         = bli_obj_col_stride( *a );
	inc_t     a_pd         = bli_obj_panel_dim( *a );
	inc_t     a_ps         = bli_obj_panel_stride( *a );

	// Matrix B: buffer, row stride, panel dimension, panel stride.
	void*     b_ptr        = bli_obj_buffer_at_off( *b );
	inc_t     b_rs         = bli_obj_row_stride( *b );
	inc_t     b_pd         = bli_obj_panel_dim( *b );
	inc_t     b_ps         = bli_obj_panel_stride( *b );

	// Matrix C: buffer and both strides.
	void*     c_ptr        = bli_obj_buffer_at_off( *c );
	inc_t     c_rs         = bli_obj_row_stride( *c );
	inc_t     c_cs         = bli_obj_col_stride( *c );

	// Scalar attached to A: the alpha used in the gemmtrsm subproblems.
	// It may be unit, for example if it was already applied during
	// packing.
	void*     alpha_trsm   = bli_obj_internal_scalar_buffer( *a );

	// Scalar attached to C: the "beta" used in the gemm-only subproblems
	// for micro-panels that do not intersect the diagonal. It is kept
	// separately because the alpha attached to B may have been reset if
	// it was applied during packing.
	void*     alpha_gemm   = bli_obj_internal_scalar_buffer( *c );

	// For 4m/3m packing the micro-panels hold real (float/double) values
	// while the macro-kernel indexes them through complex pointer types,
	// so pointer arithmetic would advance twice as far as intended.
	// Halving the two indexing strides compensates for that.
	if ( bli_obj_is_panel_packed_4m( *a ) ||
	     bli_obj_is_panel_packed_3m( *a ) )
	{
		a_cs /= 2;
		b_rs /= 2;
	}

	// Dispatch to the typed macro-kernel.
	kernel( doff_b,
	        m_dim,
	        n_dim,
	        k_dim,
	        alpha_trsm,
	        a_ptr, a_cs, a_pd, a_ps,
	        b_ptr, b_rs, b_pd, b_ps,
	        alpha_gemm,
	        c_ptr, c_rs, c_cs,
	        trsm_ukr_fp,
	        gemm_ukr_fp );
}
예제 #24
0
파일: test_trsm.c 프로젝트: figual/blis
// MPI-distributed benchmark driver for trsm.
//
// Usage: <program> m n k p_begin p_inc p_end
//   m, n     Matrix dimensions. A negative value means "p times the
//            absolute value"; a non-negative value is used directly.
//   k        Accepted for command-line compatibility but not used.
//   p_begin, p_inc, p_end
//            Range of problem sizes. The sizes are distributed round-robin
//            over the MPI ranks: rank r handles p_begin + r*p_inc and then
//            every world_size-th size after that.
//
// For each problem size the solve is run n_repeats times and one
// MATLAB-style line with m, n, the recorded time and the GFLOPS rate is
// printed. When BLIS is defined the solve is bli_trsm(); otherwise a
// Fortran-77 BLAS routine is called.
//
// Returns 0 on success and 1 if too few arguments were given.
int main( int argc, char** argv )
{
	obj_t a, b, c;
	obj_t c_save;
	obj_t alpha, beta;
	dim_t m, n;
	dim_t p;
	dim_t p_begin, p_end, p_inc;
	int   m_input, n_input;
	num_t dt_a, dt_b, dt_c;
	num_t dt_alpha, dt_beta;
	int   r, n_repeats;
	side_t side;
	uplo_t uplo;

	double dtime;
	double dtime_save;
	double gflops;

	bli_init();

	// Number of timed repetitions per problem size.
	n_repeats = 3;

	if ( argc < 7 )
	{
		printf("Usage:\n");
		printf("test_foo.x m n k p_begin p_inc p_end:\n");

		// Stop here: the argument parsing below would otherwise read
		// past the end of argv. (The original statement was a bare
		// "exit;", which evaluates the function name and does nothing.)
		bli_finalize();
		return 1;
	}

	int world_size, world_rank, provided;
	MPI_Init_thread( NULL, NULL, MPI_THREAD_FUNNELED, &provided );
	MPI_Comm_size( MPI_COMM_WORLD, &world_size );
	MPI_Comm_rank( MPI_COMM_WORLD, &world_rank );

	// argv[3] (k) is intentionally skipped; trsm has no k dimension.
	m_input = strtol( argv[1], NULL, 10 );
	n_input = strtol( argv[2], NULL, 10 );
	p_begin = strtol( argv[4], NULL, 10 );
	p_inc   = strtol( argv[5], NULL, 10 );
	p_end   = strtol( argv[6], NULL, 10 );

#if 1
	dt_a = BLIS_DOUBLE;
	dt_b = BLIS_DOUBLE;
	dt_c = BLIS_DOUBLE;
	dt_alpha = BLIS_DOUBLE;
	dt_beta = BLIS_DOUBLE;
#else
	dt_a = dt_b = dt_c = dt_alpha = dt_beta = BLIS_FLOAT; 
	//dt_a = dt_b = dt_c = dt_alpha = dt_beta = BLIS_SCOMPLEX; 
#endif

	side = BLIS_LEFT;
	//side = BLIS_RIGHT;

	uplo = BLIS_LOWER;
	//uplo = BLIS_UPPER;

	// Each rank starts at its own offset and strides over the sizes that
	// belong to the other ranks.
	for ( p = p_begin + world_rank * p_inc; p <= p_end; p += p_inc * world_size )
	{

		if ( m_input < 0 ) m = p * ( dim_t )abs(m_input);
		else               m =     ( dim_t )    m_input;
		if ( n_input < 0 ) n = p * ( dim_t )abs(n_input);
		else               n =     ( dim_t )    n_input;


		bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha );
		bli_obj_create( dt_beta,  1, 1, 0, 0, &beta );

		// A is m x m for a left-side solve and n x n otherwise.
		if ( bli_is_left( side ) )
			bli_obj_create( dt_a, m, m, 0, 0, &a );
		else
			bli_obj_create( dt_a, n, n, 0, 0, &a );
		bli_obj_create( dt_b, m, n, 0, 0, &b );
		bli_obj_create( dt_c, m, n, 0, 0, &c );
		bli_obj_create( dt_c, m, n, 0, 0, &c_save );

		bli_obj_set_struc( BLIS_TRIANGULAR, &a );
		bli_obj_set_uplo( uplo, &a );
		//bli_obj_set_diag( BLIS_UNIT_DIAG, &a );

		bli_randm( &a );
		bli_randm( &c );
		bli_randm( &b );

/*
		{ 
			obj_t a2;

			bli_obj_alias_to( &a, &a2 );
			bli_obj_toggle_uplo( &a2 );
			bli_obj_inc_diag_offset( 1, &a2 );
			bli_setm( &BLIS_ZERO, &a2 );
			bli_obj_inc_diag_offset( -2, &a2 );
			bli_obj_toggle_uplo( &a2 );
			bli_obj_set_diag( BLIS_NONUNIT_DIAG, &a2 );
			bli_scalm( &BLIS_TWO, &a2 );
			//bli_scalm( &BLIS_TWO, &a );
		} 
*/

		bli_setsc(  (2.0/1.0), 0.0, &alpha );
		bli_setsc(  (1.0/1.0), 0.0, &beta );


		// Keep a pristine copy of C so every repetition starts from the
		// same right-hand side.
		bli_copym( &c, &c_save );

		dtime_save = 1.0e9;

		for ( r = 0; r < n_repeats; ++r )
		{
			// Restore C, since the solve overwrites it.
			bli_copym( &c_save, &c );

			dtime = bli_clock();


#ifdef PRINT
/*
			obj_t ar, ai;
			bli_obj_alias_to( &a, &ar );
			bli_obj_alias_to( &a, &ai );
			bli_obj_set_dt( BLIS_DOUBLE, &ar ); ar.rs *= 2; ar.cs *= 2;
			bli_obj_set_dt( BLIS_DOUBLE, &ai ); ai.rs *= 2; ai.cs *= 2; ai.buffer = ( double* )ai.buffer + 1;

			bli_printm( "ar", &ar, "%4.1f", "" );
			bli_printm( "ai", &ai, "%4.1f", "" );
*/

			// A is inverted along its diagonal only for printing and is
			// restored by the second bli_invertd() call.
			bli_invertd( &a );
			bli_printm( "a", &a, "%4.1f", "" );
			bli_invertd( &a );
			bli_printm( "c", &c, "%4.1f", "" );
#endif

#ifdef BLIS
			//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );

			bli_trsm( side,
			//bli_trsm4m( side,
			//bli_trsm3m( side,
			          &alpha,
			          &a,
			          &c );
#else

			// NOTE(review): this reference path hard-codes side/uplo/
			// trans/diag to 'L','L','N','N' (the locals below shadow the
			// outer side/uplo) and always uses the single-precision
			// routines strsm_/ctrsm_, although dt_a is BLIS_DOUBLE in the
			// default configuration above. Verify the datatype selection
			// before relying on non-BLIS results.
			if ( bli_is_real( dt_a ) )
			{
				f77_char side   = 'L';
				f77_char uplo   = 'L';
				f77_char transa = 'N';
				f77_char diag   = 'N';
				f77_int  mm     = bli_obj_length( &c );
				f77_int  nn     = bli_obj_width( &c );
				f77_int  lda    = bli_obj_col_stride( &a );
				f77_int  ldc    = bli_obj_col_stride( &c );
				float *  alphap = bli_obj_buffer( &alpha );
				float *  ap     = bli_obj_buffer( &a );
				float *  cp     = bli_obj_buffer( &c );

				strsm_( &side,
				        &uplo,
				        &transa,
				        &diag,
				        &mm,
				        &nn,
				        alphap,
				        ap, &lda,
				        cp, &ldc );
			}
			else // if ( bli_is_complex( dt_a ) )
			{
				f77_char  side   = 'L';
				f77_char  uplo   = 'L';
				f77_char  transa = 'N';
				f77_char  diag   = 'N';
				f77_int   mm     = bli_obj_length( &c );
				f77_int   nn     = bli_obj_width( &c );
				f77_int   lda    = bli_obj_col_stride( &a );
				f77_int   ldc    = bli_obj_col_stride( &c );
				scomplex* alphap = bli_obj_buffer( &alpha );
				scomplex* ap     = bli_obj_buffer( &a );
				scomplex* cp     = bli_obj_buffer( &c );

				ctrsm_( &side,
				//ztrsm_( &side,
				        &uplo,
				        &transa,
				        &diag,
				        &mm,
				        &nn,
				        alphap,
				        ap, &lda,
				        cp, &ldc );
			}

#endif

			// In PRINT builds the program prints the result and exits
			// after the first repetition.
#ifdef PRINT
			bli_printm( "c after", &c, "%4.1f", "" );
			exit(1);
#endif


			// NOTE(review): presumably keeps the smaller of the previous
			// best time and the time elapsed since dtime -- confirm
			// against the bli_clock_min_diff() documentation.
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		// Operation count used for the rate: m*m*n (left) or m*n*n
		// (right), scaled by 1e9 to report GFLOPS.
		if ( bli_is_left( side ) )
			gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 );
		else
			gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 );

		// Complex datatypes are credited with four times the count.
		if ( bli_is_complex( dt_a ) ) gflops *= 4.0;

#ifdef BLIS
		printf( "data_trsm_blis" );
#else
		printf( "data_trsm_%s", BLAS );
#endif
		printf( "( %2lu, 1:4 ) = [ %4lu %4lu  %10.3e  %6.3f ];\n",
		        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
		        ( unsigned long )m,
		        ( unsigned long )n, dtime_save, gflops );


		bli_obj_free( &alpha );
		bli_obj_free( &beta );

		bli_obj_free( &a );
		bli_obj_free( &b );
		bli_obj_free( &c );
		bli_obj_free( &c_save );
	}

	bli_finalize();

	// Pair the MPI_Init_thread() call above with a shutdown of MPI before
	// the process exits.
	MPI_Finalize();

	return 0;
}
예제 #25
0
파일: test_herk.c 프로젝트: ShawnLess/blis
// Benchmark driver for herk.
//
// For every problem size p in [p_begin, p_end] (step p_inc) the driver
// creates A and an m x m matrix C marked as Hermitian, runs the update
// n_repeats times, and prints one MATLAB-style assignment line holding
// m, k, the recorded time and the computed GFLOPS rate.
//
// When BLIS is defined the operation is bli_herk(); otherwise the
// Fortran-77 BLAS routine matching the datatype is called (ssyrk_/dsyrk_
// for real types, cherk_/zherk_ for complex types). Defining PRINT
// switches to a single small problem, prints the operands, and exits
// after the first run.
//
// argc/argv are not used. Always returns 0.
int main( int argc, char** argv )
{
	obj_t a, c;
	obj_t c_save;
	obj_t alpha, beta;
	dim_t m, k;
	dim_t p;
	dim_t p_begin, p_end, p_inc;
	int   m_input, k_input;
	num_t dt;
	int   r, n_repeats;
	uplo_t uploc;
	trans_t transa;
	f77_char f77_uploc;
	f77_char f77_transa;

	double dtime;
	double dtime_save;
	double gflops;

	bli_init();

	//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );

	// Number of timed repetitions per problem size.
	n_repeats = 3;

	// A negative *_input value means "p times the absolute value"; a
	// non-negative value is used as the dimension directly (see the top
	// of the main loop).
#ifndef PRINT
	p_begin = 200;
	p_end   = 2000;
	p_inc   = 200;

	m_input = -1;
	k_input = -1;
#else
	p_begin = 16;
	p_end   = 16;
	p_inc   = 1;

	m_input = 3;
	k_input = 1;
#endif

	// The datatype is chosen at compile time by editing these toggles.
#if 1
	//dt = BLIS_FLOAT;
	dt = BLIS_DOUBLE;
#else
	//dt = BLIS_SCOMPLEX;
	dt = BLIS_DCOMPLEX;
#endif

	uploc = BLIS_LOWER;
	//uploc = BLIS_UPPER;

	transa = BLIS_NO_TRANSPOSE;

	// Translate the parameters into the character codes that are passed
	// to the Fortran-77 BLAS routines in the non-BLIS branch below.
	bli_param_map_blis_to_netlib_uplo( uploc, &f77_uploc );
	bli_param_map_blis_to_netlib_trans( transa, &f77_transa );


	for ( p = p_begin; p <= p_end; p += p_inc )
	{
		if ( m_input < 0 ) m = p * ( dim_t )abs(m_input);
		else               m =     ( dim_t )    m_input;
		if ( k_input < 0 ) k = p * ( dim_t )abs(k_input);
		else               k =     ( dim_t )    k_input;

		bli_obj_create( dt, 1, 1, 0, 0, &alpha );
		bli_obj_create( dt, 1, 1, 0, 0, &beta );

		// A is created k x m when it is to be transposed and m x k
		// otherwise; C (and its backup copy c_save) is m x m.
		if ( bli_does_trans( transa ) )
			bli_obj_create( dt, k, m, 0, 0, &a );
		else
			bli_obj_create( dt, m, k, 0, 0, &a );
		bli_obj_create( dt, m, m, 0, 0, &c );
		bli_obj_create( dt, m, m, 0, 0, &c_save );

		bli_randm( &a );
		bli_randm( &c );

		// Note that these property setters receive the objects themselves
		// rather than their addresses.
		bli_obj_set_struc( BLIS_HERMITIAN, c );
		bli_obj_set_uplo( uploc, c );

		bli_obj_set_conjtrans( transa, a );


		bli_setsc(  (2.0/1.0), 0.0, &alpha );
		bli_setsc( -(1.0/1.0), 0.0, &beta );


		// Keep a pristine copy of C so every repetition starts from the
		// same input.
		bli_copym( &c, &c_save );
	
		dtime_save = 1.0e9;

		for ( r = 0; r < n_repeats; ++r )
		{
			// Restore C, since the update overwrites it.
			bli_copym( &c_save, &c );


			dtime = bli_clock();


#ifdef PRINT
			bli_printm( "a", &a, "%4.1f", "" );
			bli_printm( "c", &c, "%4.1f", "" );
#endif

#ifdef BLIS

			bli_herk( &alpha,
			          &a,
			          &beta,
			          &c );

#else
		// Reference path: call the BLAS routine that matches dt.
		if ( bli_is_float( dt ) )
		{
			f77_int  mm     = bli_obj_length( c );
			f77_int  kk     = bli_obj_width_after_trans( a );
			f77_int  lda    = bli_obj_col_stride( a );
			f77_int  ldc    = bli_obj_col_stride( c );
			float*   alphap = bli_obj_buffer( alpha );
			float*   ap     = bli_obj_buffer( a );
			float*   betap  = bli_obj_buffer( beta );
			float*   cp     = bli_obj_buffer( c );

			ssyrk_( &f77_uploc,
			        &f77_transa,
			        &mm,
			        &kk,
			        alphap,
			        ap, &lda,
			        betap,
			        cp, &ldc );
		}
		else if ( bli_is_double( dt ) )
		{
			f77_int  mm     = bli_obj_length( c );
			f77_int  kk     = bli_obj_width_after_trans( a );
			f77_int  lda    = bli_obj_col_stride( a );
			f77_int  ldc    = bli_obj_col_stride( c );
			double*  alphap = bli_obj_buffer( alpha );
			double*  ap     = bli_obj_buffer( a );
			double*  betap  = bli_obj_buffer( beta );
			double*  cp     = bli_obj_buffer( c );

			dsyrk_( &f77_uploc,
			        &f77_transa,
			        &mm,
			        &kk,
			        alphap,
			        ap, &lda,
			        betap,
			        cp, &ldc );
		}
		else if ( bli_is_scomplex( dt ) )
		{
			f77_int  mm     = bli_obj_length( c );
			f77_int  kk     = bli_obj_width_after_trans( a );
			f77_int  lda    = bli_obj_col_stride( a );
			f77_int  ldc    = bli_obj_col_stride( c );
			// alpha and beta are handed to cherk_ through real-typed
			// pointers, while A and C are passed as complex buffers.
			float*     alphap = bli_obj_buffer( alpha );
			scomplex*  ap     = bli_obj_buffer( a );
			float*     betap  = bli_obj_buffer( beta );
			scomplex*  cp     = bli_obj_buffer( c );

			cherk_( &f77_uploc,
			        &f77_transa,
			        &mm,
			        &kk,
			        alphap,
			        ap, &lda,
			        betap,
			        cp, &ldc );
		}
		else if ( bli_is_dcomplex( dt ) )
		{
			f77_int  mm     = bli_obj_length( c );
			f77_int  kk     = bli_obj_width_after_trans( a );
			f77_int  lda    = bli_obj_col_stride( a );
			f77_int  ldc    = bli_obj_col_stride( c );
			// As above: real-typed pointers for the scalars.
			double*    alphap = bli_obj_buffer( alpha );
			dcomplex*  ap     = bli_obj_buffer( a );
			double*    betap  = bli_obj_buffer( beta );
			dcomplex*  cp     = bli_obj_buffer( c );

			zherk_( &f77_uploc,
			        &f77_transa,
			        &mm,
			        &kk,
			        alphap,
			        ap, &lda,
			        betap,
			        cp, &ldc );
		}
#endif

			// In PRINT builds the program prints the result and exits
			// after the first repetition.
#ifdef PRINT
			bli_printm( "c after", &c, "%4.1f", "" );
			exit(1);
#endif


			// NOTE(review): presumably keeps the smaller of the previous
			// best time and the time elapsed since dtime -- confirm
			// against the bli_clock_min_diff() documentation.
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		// Operation count used for the rate: m*k*m, scaled by 1e9 to
		// report GFLOPS.
		gflops = ( 1.0 * m * k * m ) / ( dtime_save * 1.0e9 );

		// Complex datatypes are credited with four times the count.
		if ( bli_is_complex( dt ) ) gflops *= 4.0;

#ifdef BLIS
		printf( "data_herk_blis" );
#else
		printf( "data_herk_%s", BLAS );
#endif
		printf( "( %2lu, 1:4 ) = [ %4lu %4lu  %10.3e  %6.3f ];\n",
		        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
		        ( unsigned long )m,
		        ( unsigned long )k, dtime_save, gflops );


		bli_obj_free( &alpha );
		bli_obj_free( &beta );

		bli_obj_free( &a );
		bli_obj_free( &c );
		bli_obj_free( &c_save );
	}

	bli_finalize();

	return 0;
}
예제 #26
0
// MPI-distributed benchmark driver for gemm.
//
// The problem sizes p in [p_begin, p_end] (step p_inc) are distributed
// round-robin over the MPI ranks: rank r handles p_begin + r*p_inc and
// then every world_size-th size after that. For each size the product is
// computed n_repeats times and one MATLAB-style line is printed with
// m, k, n, the recorded time, the GFLOPS rate and the rank that produced
// the measurement.
//
// When BLIS is defined the operation is bli_gemm(); otherwise dgemm_ is
// called. Defining PRINT switches to a single small problem, prints the
// operands, and exits after the first run.
//
// argc/argv are not used. Always returns 0.
int main( int argc, char** argv )
{
	obj_t a, b, c;
	obj_t c_save;
	obj_t alpha, beta;
	dim_t m, n, k;
	dim_t p;
	dim_t p_begin, p_end, p_inc;
	int   m_input, n_input, k_input;
	num_t dt_a, dt_b, dt_c;
	num_t dt_alpha, dt_beta;
	int   r, n_repeats;

	double dtime;
	double dtime_save;
	double gflops;


	int world_size, world_rank, provided;
	MPI_Init_thread( NULL, NULL, MPI_THREAD_FUNNELED, &provided );
	MPI_Comm_size( MPI_COMM_WORLD, &world_size );
	MPI_Comm_rank( MPI_COMM_WORLD, &world_rank );

	bli_init();

	// Number of timed repetitions per problem size.
	n_repeats = 3;

	// A negative *_input value means "p times the absolute value"; a
	// non-negative value is used as the dimension directly (see the top
	// of the main loop).
#ifndef PRINT
	p_begin = 16;
	p_end   = 2048;
	p_inc   = 16;

	m_input = 10240;
	n_input = 10240;
	k_input = -1;
#else
	p_begin = 24;
	p_end   = 24;
	p_inc   = 1;

	m_input = -1;
	k_input = -1;
	n_input = -1;
#endif

	dt_a = BLIS_DOUBLE;
	dt_b = BLIS_DOUBLE;
	dt_c = BLIS_DOUBLE;
	dt_alpha = BLIS_DOUBLE;
	dt_beta = BLIS_DOUBLE;

	// Each rank starts at its own offset and strides over the sizes that
	// belong to the other ranks.
	for ( p = p_begin + world_rank * p_inc ; p <= p_end; p += p_inc * world_size )
	{

		if ( m_input < 0 ) m = p * ( dim_t )abs(m_input);
		else               m =     ( dim_t )    m_input;
		if ( n_input < 0 ) n = p * ( dim_t )abs(n_input);
		else               n =     ( dim_t )    n_input;
		if ( k_input < 0 ) k = p * ( dim_t )abs(k_input);
		else               k =     ( dim_t )    k_input;


		bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha );
		bli_obj_create( dt_beta,  1, 1, 0, 0, &beta );

		// A is m x k, B is k x n, C (and its backup copy) is m x n.
		bli_obj_create( dt_a, m, k, 0, 0, &a );
		bli_obj_create( dt_b, k, n, 0, 0, &b );
		bli_obj_create( dt_c, m, n, 0, 0, &c );
		bli_obj_create( dt_c, m, n, 0, 0, &c_save );

		bli_randm( &a );
		bli_randm( &b );
		bli_randm( &c );


		bli_setsc(  (1.0/1.0), 0.0, &alpha );
		bli_setsc(  (1.0/1.0), 0.0, &beta );

		// Keep a pristine copy of C so every repetition starts from the
		// same input.
		bli_copym( &c, &c_save );

		dtime_save = 1.0e9;

		for ( r = 0; r < n_repeats; ++r )
		{
			// Restore C, since the update overwrites it.
			bli_copym( &c_save, &c );


			dtime = bli_clock();


#ifdef PRINT
			bli_printm( "a", &a, "%4.1f", "" );
			bli_printm( "b", &b, "%4.1f", "" );
			bli_printm( "c", &c, "%4.1f", "" );
#endif

#ifdef BLIS
			//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );
			{
				bli_gemm( &alpha,
				          &a,
				          &b,
				          &beta,
				          &c );
			}

#else
			// Reference path: double-precision BLAS gemm without
			// transposition.
			char    transa = 'N';
			char    transb = 'N';
			int     mm     = bli_obj_length( c );
			int     kk     = bli_obj_width_after_trans( a );
			int     nn     = bli_obj_width( c );
			int     lda    = bli_obj_col_stride( a );
			int     ldb    = bli_obj_col_stride( b );
			int     ldc    = bli_obj_col_stride( c );
			double* alphap = bli_obj_buffer( alpha );
			double* ap     = bli_obj_buffer( a );
			double* bp     = bli_obj_buffer( b );
			double* betap  = bli_obj_buffer( beta );
			double* cp     = bli_obj_buffer( c );

			dgemm_( &transa,
			        &transb,
			        &mm,
			        &nn,
			        &kk,
			        alphap,
			        ap, &lda,
			        bp, &ldb,
			        betap,
			        cp, &ldc );

#endif

			// In PRINT builds the program prints the result and exits
			// after the first repetition.
#ifdef PRINT
			bli_printm( "c after", &c, "%4.1f", "" );
			exit(1);
#endif


			// NOTE(review): presumably keeps the smaller of the previous
			// best time and the time elapsed since dtime -- confirm
			// against the bli_clock_min_diff() documentation.
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		// Operation count used for the rate: 2*m*k*n, scaled by 1e9 to
		// report GFLOPS.
		gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 );

		//if(world_rank == 0){
#ifdef BLIS
		printf( "data_gemm_blis" );
#else
		printf( "data_gemm_%s", BLAS );
#endif
		// Six values are written per row, so the column range is 1:6.
		// The dim_t operands are cast explicitly so that they match the
		// %ld/%lu conversions regardless of how dim_t is defined.
		printf( "( %2ld, 1:6 ) = [ %4lu %4lu %4lu  %10.3e  %6.3f %d ];\n",
		        ( long )( (p - p_begin + 1)/p_inc + 1 ),
		        ( unsigned long )m,
		        ( unsigned long )k,
		        ( unsigned long )n,
		        dtime_save, gflops, world_rank );
		//}

		bli_obj_free( &alpha );
		bli_obj_free( &beta );

		bli_obj_free( &a );
		bli_obj_free( &b );
		bli_obj_free( &c );
		bli_obj_free( &c_save );
	}

	bli_finalize();

	MPI_Finalize();

	return 0;
}
예제 #27
0
// Object-based front end for the herk "l" macro-kernel (variant 2).
//
// Unpacks dimensions, pack schemas, buffers and strides from the objects
// a, b and c, merges the scalars attached to a and b into a single alpha,
// looks up the gemm micro-kernel address in the control tree node cntl,
// and forwards everything (together with the thread info) to the
// datatype-specific implementation selected from ftypes[] by the
// execution datatype of c.
void bli_herk_l_ker_var2( obj_t*  a,
                          obj_t*  b,
                          obj_t*  c,
                          gemm_t* cntl,
                          herk_thrinfo_t* thread )
{
	// The execution datatype of C selects both the macro-kernel instance
	// and the micro-kernel.
	num_t     exec_dt   = bli_obj_execution_datatype( *c );

	// Macro-kernel instance for that datatype.
	FUNCPTR_T kernel    = ftypes[exec_dt];

	// gemm micro-kernel table stored in the control tree node, and the
	// address it holds for the execution datatype.
	func_t*   ukr_table = cntl_gemm_ukrs( cntl );
	void*     ukr_addr  = bli_func_obj_query( exec_dt, ukr_table );

	// Problem dimensions: C is m x n, and k is the width of A.
	dim_t     m_dim     = bli_obj_length( *c );
	dim_t     n_dim     = bli_obj_width( *c );
	dim_t     k_dim     = bli_obj_width( *a );

	// Diagonal offset of C.
	doff_t    doff_c    = bli_obj_diag_offset( *c );

	// Pack schemas of A and B.
	pack_t    a_schema  = bli_obj_pack_schema( *a );
	pack_t    b_schema  = bli_obj_pack_schema( *b );

	// Matrix A: buffer, column stride, panel dimension, panel stride.
	void*     a_ptr     = bli_obj_buffer_at_off( *a );
	inc_t     a_cs      = bli_obj_col_stride( *a );
	inc_t     a_pd      = bli_obj_panel_dim( *a );
	inc_t     a_ps      = bli_obj_panel_stride( *a );

	// Matrix B: buffer, row stride, panel dimension, panel stride.
	void*     b_ptr     = bli_obj_buffer_at_off( *b );
	inc_t     b_rs      = bli_obj_row_stride( *b );
	inc_t     b_pd      = bli_obj_panel_dim( *b );
	inc_t     b_ps      = bli_obj_panel_stride( *b );

	// Matrix C: buffer and both strides.
	void*     c_ptr     = bli_obj_buffer_at_off( *c );
	inc_t     c_rs      = bli_obj_row_stride( *c );
	inc_t     c_cs      = bli_obj_col_stride( *c );

	// Temporaries for the scalars detached from A and B.
	obj_t     alpha_a;
	obj_t     alpha_ab;

	void*     alpha_ptr;
	void*     beta_ptr;

	// Detach the scalars attached to A and B and multiply them together.
	// The merged value is read back from the second operand below.
	bli_obj_scalar_detach( a, &alpha_a );
	bli_obj_scalar_detach( b, &alpha_ab );
	bli_mulsc( &alpha_a, &alpha_ab );

	// alpha is the merged scalar; beta is the scalar attached to C.
	alpha_ptr = bli_obj_internal_scalar_buffer( alpha_ab );
	beta_ptr  = bli_obj_internal_scalar_buffer( *c );

	// Dispatch to the typed macro-kernel.
	kernel( doff_c,
	        a_schema,
	        b_schema,
	        m_dim,
	        n_dim,
	        k_dim,
	        alpha_ptr,
	        a_ptr, a_cs, a_pd, a_ps,
	        b_ptr, b_rs, b_pd, b_ps,
	        beta_ptr,
	        c_ptr, c_rs, c_cs,
	        ukr_addr,
	        thread );
}
예제 #28
0
// Object-based front end for the herk "u" macro-kernel (variant 2).
//
// Unpacks dimensions, pack schemas, buffers and strides (including the
// imaginary strides of a and b) from the objects, merges the scalars
// attached to a and b into a single alpha, and forwards everything,
// together with the context and thread info, to the datatype-specific
// implementation selected from ftypes[] by the execution datatype of c.
// The cntl argument is not referenced in this function.
void bli_herk_u_ker_var2
     (
       obj_t*  a,
       obj_t*  b,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl,
       thrinfo_t* thread
     )
{
	// The execution datatype of C selects the macro-kernel instance.
	num_t     exec_dt   = bli_obj_exec_dt( c );
	FUNCPTR_T kernel    = ftypes[exec_dt];

	// Problem dimensions: C is m x n, and k is the width of A.
	dim_t     m_dim     = bli_obj_length( c );
	dim_t     n_dim     = bli_obj_width( c );
	dim_t     k_dim     = bli_obj_width( a );

	// Diagonal offset of C.
	doff_t    doff_c    = bli_obj_diag_offset( c );

	// Pack schemas of A and B.
	pack_t    a_schema  = bli_obj_pack_schema( a );
	pack_t    b_schema  = bli_obj_pack_schema( b );

	// Matrix A: buffer, column stride, imaginary stride, panel dimension
	// and panel stride.
	void*     a_ptr     = bli_obj_buffer_at_off( a );
	inc_t     a_cs      = bli_obj_col_stride( a );
	inc_t     a_is      = bli_obj_imag_stride( a );
	dim_t     a_pd      = bli_obj_panel_dim( a );
	inc_t     a_ps      = bli_obj_panel_stride( a );

	// Matrix B: buffer, row stride, imaginary stride, panel dimension
	// and panel stride.
	void*     b_ptr     = bli_obj_buffer_at_off( b );
	inc_t     b_rs      = bli_obj_row_stride( b );
	inc_t     b_is      = bli_obj_imag_stride( b );
	dim_t     b_pd      = bli_obj_panel_dim( b );
	inc_t     b_ps      = bli_obj_panel_stride( b );

	// Matrix C: buffer and both strides.
	void*     c_ptr     = bli_obj_buffer_at_off( c );
	inc_t     c_rs      = bli_obj_row_stride( c );
	inc_t     c_cs      = bli_obj_col_stride( c );

	// Temporaries for the scalars detached from A and B.
	obj_t     alpha_a;
	obj_t     alpha_ab;

	void*     alpha_ptr;
	void*     beta_ptr;

	// Detach the scalars attached to A and B and multiply them together.
	// The merged value is read back from the second operand below.
	bli_obj_scalar_detach( a, &alpha_a );
	bli_obj_scalar_detach( b, &alpha_ab );
	bli_mulsc( &alpha_a, &alpha_ab );

	// alpha is the merged scalar; beta is the scalar attached to C.
	alpha_ptr = bli_obj_internal_scalar_buffer( &alpha_ab );
	beta_ptr  = bli_obj_internal_scalar_buffer( c );

	// Dispatch to the typed macro-kernel.
	kernel( doff_c,
	        a_schema,
	        b_schema,
	        m_dim,
	        n_dim,
	        k_dim,
	        alpha_ptr,
	        a_ptr, a_cs, a_is,
	               a_pd, a_ps,
	        b_ptr, b_rs, b_is,
	               b_pd, b_ps,
	        beta_ptr,
	        c_ptr, c_rs, c_cs,
	        cntx,
	        thread );
}
예제 #29
0
// Object-based front end for the trsm "rl" macro-kernel (variant 2).
//
// Unpacks dimensions, pack schemas, buffers, strides and scalars from the
// objects a, b and c and forwards them, together with the context and
// thread info, to the datatype-specific implementation selected from
// ftypes[] by the execution datatype of c. The cntl argument is not
// referenced in this function.
void bli_trsm_rl_ker_var2
     (
       obj_t*  a,
       obj_t*  b,
       obj_t*  c,
       cntx_t* cntx,
       cntl_t* cntl,
       thrinfo_t* thread
     )
{
	// The execution datatype of C selects the macro-kernel instance.
	num_t     exec_dt    = bli_obj_execution_datatype( *c );
	FUNCPTR_T kernel     = ftypes[exec_dt];

	// Problem dimensions: C is m x n, and k is the width of A.
	dim_t     m_dim      = bli_obj_length( *c );
	dim_t     n_dim      = bli_obj_width( *c );
	dim_t     k_dim      = bli_obj_width( *a );

	// Diagonal offset of B.
	doff_t    doff_b     = bli_obj_diag_offset( *b );

	// Pack schemas of A and B.
	pack_t    a_schema   = bli_obj_pack_schema( *a );
	pack_t    b_schema   = bli_obj_pack_schema( *b );

	// Matrix A: buffer, column stride, panel dimension, panel stride.
	void*     a_ptr      = bli_obj_buffer_at_off( *a );
	inc_t     a_cs       = bli_obj_col_stride( *a );
	dim_t     a_pd       = bli_obj_panel_dim( *a );
	inc_t     a_ps       = bli_obj_panel_stride( *a );

	// Matrix B: buffer, row stride, panel dimension, panel stride.
	void*     b_ptr      = bli_obj_buffer_at_off( *b );
	inc_t     b_rs       = bli_obj_row_stride( *b );
	dim_t     b_pd       = bli_obj_panel_dim( *b );
	inc_t     b_ps       = bli_obj_panel_stride( *b );

	// Matrix C: buffer and both strides.
	void*     c_ptr      = bli_obj_buffer_at_off( *c );
	inc_t     c_rs       = bli_obj_row_stride( *c );
	inc_t     c_cs       = bli_obj_col_stride( *c );

	// Scalar attached to A (the non-triangular matrix): the alpha used in
	// the gemmtrsm subproblems, i.e. the scalar that would be applied to
	// the packed copy of A before the trsm subproblem updates it. It may
	// be unit, for example if it was already applied during packing.
	void*     alpha_trsm = bli_obj_internal_scalar_buffer( *a );

	// Scalar attached to C: the "beta" used in the gemm-only subproblems
	// for micro-panels that do not intersect the diagonal. It is kept
	// separately because the alpha attached to B may have been reset if
	// it was applied during packing.
	void*     alpha_gemm = bli_obj_internal_scalar_buffer( *c );

	// Dispatch to the typed macro-kernel.
	kernel( doff_b,
	        a_schema,
	        b_schema,
	        m_dim,
	        n_dim,
	        k_dim,
	        alpha_trsm,
	        a_ptr, a_cs, a_pd, a_ps,
	        b_ptr, b_rs, b_pd, b_ps,
	        alpha_gemm,
	        c_ptr, c_rs, c_cs,
	        cntx,
	        thread );
}
예제 #30
0
/*
 * Benchmark driver for trmm.
 *
 * For each problem size p in [p_begin, p_end] (step p_inc) this program
 * creates a triangular matrix A and a general matrix C, then times either
 * bli_trmm() (when BLIS is defined) or the Fortran-77 BLAS routine dtrmm_()
 * (otherwise). Each size is run n_repeats times and the smallest time is
 * kept. One line per size is printed in the form
 *
 *   data_trmm_<impl>( idx, 1:4 ) = [ m n time gflops ];
 *
 * When PRINT is defined, a single small problem is run, the operands are
 * printed before and after the operation, and the program exits.
 *
 * argc and argv are not used. Returns 0 on normal completion.
 */
int main( int argc, char** argv )
{
	obj_t a, b, c;
	obj_t c_save;
	obj_t alpha, beta;
	dim_t m, n;
	dim_t p;
	dim_t p_begin, p_end, p_inc;
	int   m_input, n_input;
	num_t dt_a, dt_b, dt_c;
	num_t dt_alpha, dt_beta;
	int   r, n_repeats;
	side_t side;
	uplo_t uplo;

	double dtime;
	double dtime_save;
	double gflops;

	bli_init();

	n_repeats = 3;

	// A negative m_input/n_input means "scale with p": the dimension becomes
	// p * |input|. A non-negative value fixes the dimension to that value.
#ifndef PRINT
	p_begin = 40;
	p_end   = 2000;
	p_inc   = 40;

	m_input = -1;
	n_input = -1;
#else
	p_begin = 16;
	p_end   = 16;
	p_inc   = 1;

	m_input = 8;
	n_input = 4;
#endif

	dt_a = BLIS_DOUBLE;
	dt_b = BLIS_DOUBLE;
	dt_c = BLIS_DOUBLE;
	dt_alpha = BLIS_DOUBLE;
	dt_beta = BLIS_DOUBLE;

	side = BLIS_LEFT;
	//side = BLIS_RIGHT;

	uplo = BLIS_LOWER;

	for ( p = p_begin; p <= p_end; p += p_inc )
	{

		if ( m_input < 0 ) m = p * ( dim_t )abs(m_input);
		else               m =     ( dim_t )    m_input;
		if ( n_input < 0 ) n = p * ( dim_t )abs(n_input);
		else               n =     ( dim_t )    n_input;


		bli_obj_create( dt_alpha, 1, 1, 0, 0, &alpha );
		bli_obj_create( dt_beta,  1, 1, 0, 0, &beta );

		// A is square; its order matches the dimension of C that it
		// multiplies (m for a left-side product, n for a right-side one).
		if ( bli_is_left( side ) )
			bli_obj_create( dt_a, m, m, 0, 0, &a );
		else
			bli_obj_create( dt_a, n, n, 0, 0, &a );
		// NOTE: b and beta are created, initialized and freed, but are not
		// passed to either trmm call below.
		bli_obj_create( dt_b, m, n, 0, 0, &b );
		bli_obj_create( dt_c, m, n, 0, 0, &c );
		bli_obj_create( dt_c, m, n, 0, 0, &c_save );

		bli_obj_set_struc( BLIS_TRIANGULAR, a );
		bli_obj_set_uplo( uplo, a );

		bli_randm( &a );
		bli_randm( &c );
		bli_randm( &b );


		bli_setsc(  (2.0/1.0), 0.0, &alpha );
		bli_setsc( -(1.0/1.0), 0.0, &beta );


		// Keep a pristine copy of C so every repeat starts from the same
		// input (trmm overwrites C in place).
		bli_copym( &c, &c_save );
	
		dtime_save = 1.0e9;

		for ( r = 0; r < n_repeats; ++r )
		{
			bli_copym( &c_save, &c );

			dtime = bli_clock();


#ifdef PRINT
			bli_printm( "a", &a, "%11.8f", "" );
			bli_printm( "c", &c, "%14.11f", "" );
#endif

#ifdef BLIS
			//bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING );

			bli_trmm( side,
			          &alpha,
			          &a,
			          &c );

#else

			// Derive the BLAS character arguments from the configured side
			// and uplo so that this path benchmarks the same operation as
			// the BLIS path and stays consistent with the dimensions of A
			// and with the flop count below. Distinct names are used to
			// avoid shadowing the outer side/uplo variables.
			f77_char side_f77 = bli_is_left( side ) ? 'L' : 'R';
			f77_char uplo_f77 = ( uplo == BLIS_LOWER ) ? 'L' : 'U';
			f77_char transa   = 'N';
			f77_char diag     = 'N';
			f77_int  mm       = bli_obj_length( c );
			f77_int  nn       = bli_obj_width( c );
			f77_int  lda      = bli_obj_col_stride( a );
			f77_int  ldc      = bli_obj_col_stride( c );
			double*  alphap   = bli_obj_buffer( alpha );
			double*  ap       = bli_obj_buffer( a );
			double*  cp       = bli_obj_buffer( c );

			dtrmm_( &side_f77,
			        &uplo_f77,
			        &transa,
			        &diag,
			        &mm,
			        &nn,
			        alphap,
			        ap, &lda,
			        cp, &ldc );
#endif

#ifdef PRINT
			bli_printm( "c after", &c, "%14.11f", "" );
			exit(1);
#endif



			// Retain the fastest time observed across the repeats.
			dtime_save = bli_clock_min_diff( dtime_save, dtime );
		}

		// Flop count used here: m*m*n for a left-side product, m*n*n for a
		// right-side product.
		if ( bli_is_left( side ) )
			gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 );
		else
			gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 );

#ifdef BLIS
		printf( "data_trmm_blis" );
#else
		printf( "data_trmm_%s", BLAS );
#endif
		printf( "( %2lu, 1:4 ) = [ %4lu %4lu  %10.3e  %6.3f ];\n",
		        ( unsigned long )(p - p_begin + 1)/p_inc + 1,
		        ( unsigned long )m,
		        ( unsigned long )n, dtime_save, gflops );


		bli_obj_free( &alpha );
		bli_obj_free( &beta );

		bli_obj_free( &a );
		bli_obj_free( &b );
		bli_obj_free( &c );
		bli_obj_free( &c_save );
	}

	bli_finalize();

	return 0;
}