Exemple #1
0
//===================================================================================================================
//===================================================================================================================
//===================================================================================================================
/**
    Batched construction of the k-by-k triangular factor T of a block
    reflector from the reflector vectors V and the scalars tau (LARFT-style).

    With the macros defined below, this runs one batched GEMM (T = V^T V),
    clears the lower triangle of T, and then calls the
    magmablas_dlarft_dtrmv_sm32x32_batched kernel to finish T.

    NOTE(review): the "sm32x32" name suggests the TRMV kernel only supports
    k <= 32; this routine does not check that. Confirm with the kernel.

    @param[in]     n          Number of rows of each V (inner dimension of V^T V).
    @param[in]     k          Number of reflectors = order of T. The routine
                              returns immediately when k <= 0.
    @param[in]     v_array    Batch of pointers to V (n-by-k), leading dimension ldv.
    @param[in]     tau_array  Batch of pointers to tau. Used by the TRMV step
                              (and by the row-wise GEMV path when it is compiled in).
    @param[in,out] T_array    Batch of pointers to T, leading dimension ldt.
                              Overwritten by the GEMM, then read and written by TRMV.
    @param[in]     batchCount Number of problems in the batch.
    @param[in]     queue      Queue to execute in.
*/
extern "C" void
magma_dlarft_sm32x32_batched(magma_int_t n, magma_int_t k, 
                    double **v_array, magma_int_t ldv,
                    double **tau_array, 
                    double **T_array, magma_int_t ldt, 
                    magma_int_t batchCount, magma_queue_t queue)
{
    // No reflectors: nothing to build.
    if ( k <= 0) return;

     //==================================
     //          GEMV
     //==================================
     // NOTE(review): these #defines are not scoped to this function. They stay
     // defined for the rest of the translation unit. Because
     // use_gemm_larft_sm32 is defined, only the GEMM branch below is compiled;
     // the GEMV branches under #else are dead code.
#define USE_GEMV2
#define use_gemm_larft_sm32

    #if defined(use_gemm_larft_sm32)
    // T = V^T V: a k-by-k product reduced over the n rows of V. beta = 0 means T is overwritten.
    magma_dgemm_batched( MagmaConjTrans, MagmaNoTrans, 
                         k, k, n, 
                         MAGMA_D_ONE, v_array, ldv, 
                         v_array, ldv, 
                         MAGMA_D_ZERO, T_array, ldt, 
                         batchCount, queue );
    // Clear the lower triangle of every T (off-diagonal and diagonal values are
    // both MAGMA_D_ZERO), so that only the upper part of V^T V reaches the TRMV.
    // NOTE(review): confirm that MagmaLower in dlaset also covers the diagonal.
    magmablas_dlaset_batched( MagmaLower, k, k, 
            MAGMA_D_ZERO, MAGMA_D_ZERO, 
            T_array, ldt, batchCount, queue );
    #else
    #if 1
    // Column-by-column alternative: one GEMV per reflector i.
    for (magma_int_t i=0; i < k; i++)
    {
        //W(1:i-1) := - tau(i) * V(i:n,1:i-1)' * V(i:n,i)
        //T( i, i ) = tau( i ) 
        //custom implementation.
        #ifdef USE_GEMV2
        magmablas_dlarft_gemvrowwise_batched( n-i, i, 
                            tau_array,
                            v_array, ldv, 
                            T_array, ldt,
                            batchCount, queue);
                            
        #else       
        magmablas_dlarft_gemvcolwise_batched( n-i, i, v_array, ldv, T_array, ldt, tau_array, batchCount, queue);
        #endif
    }
    #else
        //seems to be very slow when k=32 while the one by one loop above is faster
        dlarft_gemv_loop_inside_kernel_batched(n, k, tau_array, v_array, ldv, T_array, ldt, batchCount, queue); 
    #endif
    #endif
     //==================================
     //          TRMV
     //==================================
     //T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1) i=[1:k]
     // tau_array is passed to the kernel as well. NOTE(review): presumably the
     // kernel applies the -tau scaling and writes the diagonal of T (the GEMM
     // path above does neither). Confirm against the kernel.
     magmablas_dlarft_dtrmv_sm32x32_batched(k, k, tau_array, T_array, ldt, T_array, ldt, batchCount, queue);
}
/***************************************************************************//**
    Purpose
    -------
    DLARFB applies a real block reflector H or its transpose H^H to a
    DOUBLE PRECISION m by n matrix C, from either the left or the right.
    
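    __Note that this function assumes__ that V is stored in full: its zero
    entries (e.g., the upper part of dV_array for MagmaForward/MagmaColumnwise)
    and its unit diagonal must be stored explicitly, because the whole array is
    referenced. The same holds for the zero triangle of dT_array.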

    Arguments
    ---------
    @param[in]
    side    magma_side_t
      -     = MagmaLeft:      apply H or H^H from the Left
      -     = MagmaRight:     apply H or H^H from the Right

    @param[in]
    trans   magma_trans_t
      -     = MagmaNoTrans:    apply H   (No transpose)
      -     = MagmaTrans:      apply H^H (Transpose; H^H = H^T for real data)

    @param[in]
    direct  magma_direct_t
            Indicates how H is formed from a product of elementary
            reflectors
      -     = MagmaForward:  H = H(1) H(2) . . . H(k) (Forward)
      -     = MagmaBackward: H = H(k) . . . H(2) H(1) (Backward)

    @param[in]
    storev  magma_storev_t
            Indicates how the vectors which define the elementary
            reflectors are stored:
      -     = MagmaColumnwise: Columnwise
      -     = MagmaRowwise:    Rowwise

    @param[in]
    m       INTEGER
            The number of rows of the matrix C.

    @param[in]
    n       INTEGER
            The number of columns of the matrix C.

    @param[in]
    k       INTEGER
            The order of the matrix T (= the number of elementary
            reflectors whose product defines the block reflector).

    @param[in]
    dV_array      DOUBLE PRECISION array on the GPU, dimension
                (LDDV,K) if STOREV = MagmaColumnwise
                (LDDV,M) if STOREV = MagmaRowwise and SIDE = MagmaLeft
                (LDDV,N) if STOREV = MagmaRowwise and SIDE = MagmaRight
            The matrix V. See further details.

    @param[in]
    lddv    INTEGER
            The leading dimension of the array V.
            If STOREV = MagmaColumnwise and SIDE = MagmaLeft, LDDV >= max(1,M);
            if STOREV = MagmaColumnwise and SIDE = MagmaRight, LDDV >= max(1,N);
            if STOREV = MagmaRowwise, LDDV >= K.

    @param[in]
    dT_array      DOUBLE PRECISION array on the GPU, dimension (LDDT,K)
            The triangular k by k matrix T in the representation of the
            block reflector.

    @param[in]
    lddt    INTEGER
            The leading dimension of the array T. LDDT >= K.

    @param[in,out]
    dC_array      DOUBLE PRECISION array on the GPU, dimension (LDDC,N)
            On entry, the m by n matrix C.
            On exit, C is overwritten by H*C, or H^H*C, or C*H, or C*H^H.

    @param[in]
    lddc    INTEGER
            The leading dimension of the array C. LDDC >= max(1,M).

    @param
    dwork_array   (workspace) DOUBLE PRECISION array, dimension (LDWORK,K)

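    @param[in]
    ldwork  INTEGER
            The leading dimension of the array WORK.
            If SIDE = MagmaLeft,  LDWORK >= max(1,N);
            if SIDE = MagmaRight, LDWORK >= max(1,M);
            Note: LDWORK is currently not referenced. WORK is addressed with a
            compact leading dimension (K if SIDE = MagmaLeft, M if
            SIDE = MagmaRight), and the bound above only ensures WORK is large
            enough.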

    @param
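    dworkvt_array (workspace) DOUBLE PRECISION array, dimension (LDWORKVT,K)

    @param[in]
    ldworkvt INTEGER
            The leading dimension of the array WORKVT.
            LDWORKVT >= max(1,min(M,N));
            Note: LDWORKVT is currently not referenced. WORKVT is addressed
            with leading dimension M if M <= N and K otherwise, and the bound
            above only ensures WORKVT is large enough.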

    @param[in]
    batchCount  INTEGER
                The number of matrices to operate on.

    @param[in]
    queue   magma_queue_t
            Queue to execute in.

    Further Details
    ---------------
    The shape of the matrix V and the storage of the vectors which define
    the H(i) is best illustrated by the following example with n = 5 and
    k = 3.
    All elements including 0's and 1's are stored, unlike LAPACK.

        DIRECT = MagmaForward and         DIRECT = MagmaForward and
        STOREV = MagmaColumnwise:         STOREV = MagmaRowwise:

                 V = (  1  0  0 )                 V = (  1 v1 v1 v1 v1 )
                     ( v1  1  0 )                     (  0  1 v2 v2 v2 )
                     ( v1 v2  1 )                     (  0  0  1 v3 v3 )
                     ( v1 v2 v3 )
                     ( v1 v2 v3 )

        DIRECT = MagmaBackward and        DIRECT = MagmaBackward and
        STOREV = MagmaColumnwise:         STOREV = MagmaRowwise:

                 V = ( v1 v2 v3 )                 V = ( v1 v1  1  0  0 )
                     ( v1 v2 v3 )                     ( v2 v2 v2  1  0 )
                     (  1 v2 v3 )                     ( v3 v3 v3 v3  1 )
                     (  0  1 v3 )
                     (  0  0  1 )

    @ingroup magma_larfb_batched
*******************************************************************************/
extern "C" magma_int_t
magma_dlarfb_gemm_batched(
    magma_side_t side, magma_trans_t trans, magma_direct_t direct, magma_storev_t storev,
    magma_int_t m, magma_int_t n, magma_int_t k,
    magmaDouble_const_ptr dV_array[],    magma_int_t lddv,
    magmaDouble_const_ptr dT_array[],    magma_int_t lddt,
    magmaDouble_ptr dC_array[],          magma_int_t lddc,
    magmaDouble_ptr dwork_array[],       magma_int_t ldwork,
    magmaDouble_ptr dworkvt_array[],     magma_int_t ldworkvt,
    magma_int_t batchCount, magma_queue_t queue)
{
    // H = I - V T V^H is applied as three batched GEMMs. V and T are read as
    // full stored arrays (their zero entries must be stored explicitly), so
    // the formula does not depend on `direct`, which is not referenced.
    MAGMA_UNUSED( direct );
    // The workspaces are addressed with the compact leading dimensions
    // ldw / ldwvt chosen below, not with the caller's ldwork / ldworkvt.
    MAGMA_UNUSED( ldwork );
    MAGMA_UNUSED( ldworkvt );

    // Constants
    const double c_zero    = MAGMA_D_ZERO;
    const double c_one     = MAGMA_D_ONE;
    const double c_neg_one = MAGMA_D_NEG_ONE;

    magma_int_t info = 0;

    // Quick return. k == 0 means H = I. It must be caught here because it would
    // also give the workspaces a zero leading dimension below, which the GEMM
    // argument checks would presumably reject.
    if (m <= 0 || n <= 0 || k == 0) {
        return info;
    }

    // Leading dimensions of the intermediate products:
    //   W  (dwork)  : k x n for SIDE = Left,  m x k for SIDE = Right;
    //   W2 (dworkvt): m x k when m <= n,      k x n when m > n.
    const magma_int_t ldw   = (side == MagmaLeft) ? k : m;
    const magma_int_t ldwvt = (m > n) ? k : m;

    // op() to apply to the stored V so that it acts as V (notransV) or as V^H
    // (transV). Row-wise storage holds V^H, so the two roles swap.
    magma_trans_t notransV, transV;
    if (storev == MagmaColumnwise) {
        notransV = MagmaNoTrans;
        transV   = MagmaTrans;
    }
    else {
        notransV = MagmaTrans;
        transV   = MagmaNoTrans;
    }

    // H^H = I - V T^H V^H, so `trans` is forwarded as op(T) in every branch.
    // That single switch is what distinguishes applying H^H from applying H.

    if ( side == MagmaLeft ) {
        // Form H C or H^H C (comments assume H C).

        // W = V^H C  (k x n). op(V) is transV, as for every V^H in this
        // routine; C itself is never transposed.
        magma_dgemm_batched( transV, MagmaNoTrans,
                     k, n, m,
                     c_one,  dV_array,    lddv,
                             dC_array,    lddc,
                     c_zero, dwork_array, ldw,
                     batchCount, queue );

        if (m <= n) {
            // W2 = V T  (m x k)
            magma_dgemm_batched( notransV, trans,
                         m, k, k,
                         c_one,  dV_array, lddv,
                                 dT_array, lddt,
                         c_zero, dworkvt_array, ldwvt,
                         batchCount, queue );

            // C = C - W2 W = C - V T V^H C = (I - V T V^H) C = H C
            magma_dgemm_batched( MagmaNoTrans, MagmaNoTrans,
                         m, n, k,
                         c_neg_one, dworkvt_array,  ldwvt,
                                    dwork_array,    ldw,
                         c_one,     dC_array,       lddc,
                         batchCount, queue );
        }
        else {
            // W2 = T W = T V^H C  (k x n)
            magma_dgemm_batched( trans, MagmaNoTrans,
                         k, n, k,
                         c_one,  dT_array, lddt,
                                 dwork_array, ldw,
                         c_zero, dworkvt_array, ldwvt,
                         batchCount, queue );

            // C = C - V W2 = C - V T V^H C = (I - V T V^H) C = H C
            magma_dgemm_batched( notransV, MagmaNoTrans,
                         m, n, k,
                         c_neg_one, dV_array,       lddv,
                                    dworkvt_array,  ldwvt,
                         c_one,     dC_array,       lddc,
                         batchCount, queue );
        }
    }
    else {
        // Form C H or C H^H (comments assume C H).

        // W = C V  (m x k)
        magma_dgemm_batched( MagmaNoTrans, notransV,
                     m, k, n,
                     c_one,  dC_array,    lddc,
                             dV_array,    lddv,
                     c_zero, dwork_array, ldw,
                     batchCount, queue );

        if (m <= n) {
            // W2 = W T = C V T  (m x k)
            magma_dgemm_batched( MagmaNoTrans, trans,
                         m, k, k,
                         c_one,  dwork_array, ldw,
                                 dT_array, lddt,
                         c_zero, dworkvt_array, ldwvt,
                         batchCount, queue );

            // C = C - W2 V^H = C - C V T V^H = C (I - V T V^H) = C H
            magma_dgemm_batched( MagmaNoTrans, transV,
                         m, n, k,
                         c_neg_one, dworkvt_array, ldwvt,
                                    dV_array,      lddv,
                         c_one,     dC_array,      lddc,
                         batchCount, queue );
        }
        else {
            // W2 = T V^H  (k x n)
            magma_dgemm_batched( trans, transV,
                         k, n, k,
                         c_one,  dT_array, lddt,
                                 dV_array, lddv,
                         c_zero, dworkvt_array, ldwvt,
                         batchCount, queue );

            // C = C - W W2 = C - C V T V^H = C (I - V T V^H) = C H
            magma_dgemm_batched( MagmaNoTrans, MagmaNoTrans,
                         m, n, k,
                         c_neg_one, dwork_array,   ldw,
                                    dworkvt_array, ldwvt,
                         c_one,     dC_array,      lddc,
                         batchCount, queue );
        }
    }

    return info;
} /* magma_dlarfb_gemm_batched */
Exemple #3
0
/**
    \n
    This is an internal routine.

    Recursive Cholesky factorization (lower triangular only) of the m-by-n
    column panel of every matrix in the batch, with m >= n.
    The panel is split column-wise into A1 (n1 = n/2 columns) and
    A2 (n2 = n - n1 columns):
      1. A1 is factored recursively (all m rows);
      2. the trailing (m - n1)-by-n2 part of A2 is updated as
         A2 -= A21 * A21(0:n2-1, :)^T, where A21 is the (m - n1)-by-n1 part
         of A1 below its diagonal tile;
      3. that trailing part of A2 is factored recursively.
    Panels with n <= min_recpnb are passed to magma_dpotrf_panel_batched.

    info_array and gbstep are forwarded to the panel routine. gbstep is
    presumably the global column offset used when reporting a failing minor,
    so the second half is passed gbstep + n1.

    Returns 0 on success, -1 if uplo == MagmaUpper (not implemented),
    -101 if m < n, or the nonzero status of magma_malloc or a nested call.
    ********************************************************************/
extern "C" magma_int_t
magma_dpotrf_recpanel_batched(
    magma_uplo_t uplo, magma_int_t m, magma_int_t n, 
    magma_int_t min_recpnb,    
    double** dA_array,    magma_int_t ldda,
    double** dX_array,    magma_int_t dX_length,
    double** dinvA_array, magma_int_t dinvA_length,
    double** dW0_displ, double** dW1_displ,  
    double** dW2_displ, double** dW3_displ,
    double** dW4_displ,
    magma_int_t *info_array, magma_int_t gbstep, 
    magma_int_t batchCount, magma_queue_t queue)
{
    magma_int_t arginfo = 0;
    // Quick return if possible
    if (m == 0 || n == 0) {
        return arginfo;
    }
    if (uplo == MagmaUpper) {
        printf("Upper side is unavailable \n");
        arginfo = -1;
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }
    if (m < n) {
        printf("error m < n %d < %d \n", (int) m, (int) n);
        arginfo = -101;
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    // Array of per-matrix pointers, re-aimed at each sub-block below.
    double **dA_displ = NULL;
    arginfo = magma_malloc((void**)&dA_displ, batchCount * sizeof(*dA_displ));
    if (arginfo != 0) {
        return arginfo;
    }

    const double alpha = MAGMA_D_NEG_ONE;
    const double beta  = MAGMA_D_ONE;

    if (n <= min_recpnb) {
        // Small enough: factor the whole m-by-n panel directly.
        magma_ddisplace_pointers(dA_displ, dA_array, ldda, 0, 0, batchCount, queue);
        arginfo = magma_dpotrf_panel_batched(uplo, m, n,
                           dA_displ, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW0_displ, dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ,
                           info_array, gbstep,
                           batchCount, queue);
    }
    else {
        // Split A into [A1 A2]: panel on A1, update A2, then panel on A2.
        const magma_int_t n1 = n/2;
        const magma_int_t n2 = n-n1;
        const magma_int_t m1 = m;
        const magma_int_t m2 = m-n1;
        const magma_int_t p1 = 0;
        const magma_int_t p2 = n1;

        // Panel on A1 (starts at column p1 of this panel).
        magma_ddisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount, queue);
        arginfo = magma_dpotrf_recpanel_batched(
                           uplo, m1, n1, min_recpnb,
                           dA_displ, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW0_displ, dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, 
                           info_array, gbstep+p1,
                           batchCount, queue);
        if (arginfo != 0) {
            magma_free(dA_displ);
            return arginfo;
        }

        // Update A2: A2(n1:m, :) -= A21 * A21(0:n2-1, :)^T
        magma_ddisplace_pointers(dA_displ,  dA_array, ldda, p1+n1, p1, batchCount, queue);
        magma_ddisplace_pointers(dW0_displ, dA_array, ldda, p1+n1, p2, batchCount, queue);
        magma_dgemm_batched( MagmaNoTrans, MagmaConjTrans, m2, n2, n1,
                             alpha, dA_displ, ldda, 
                             dA_displ, ldda, 
                             beta,  dW0_displ, ldda, 
                             batchCount, queue );

        // Panel on A2. Its columns start p2 columns into this panel, so the
        // global step passed down must be offset accordingly.
        magma_ddisplace_pointers(dA_displ, dA_array, ldda, p2, p2, batchCount, queue);
        arginfo = magma_dpotrf_recpanel_batched(
                                      uplo, m2, n2, min_recpnb,
                                      dA_displ, ldda,
                                      dX_array, dX_length,
                                      dinvA_array, dinvA_length,
                                      dW0_displ, dW1_displ, dW2_displ,
                                      dW3_displ, dW4_displ,
                                      info_array, gbstep+p2,
                                      batchCount, queue);
    }

    magma_free(dA_displ);
    return arginfo;
}
Exemple #4
0
/**
    \n
    This is an internal routine.

    Recursive Cholesky factorization (lower triangular only) of the m-by-n
    panel of every matrix in the batch, with m >= n. The panel is processed as
    a square n-by-n tile plus the rows below it:
      1. factor A11 (n1-by-n1, n1 = n/2) recursively;
      2. A21 := A21 * L11^{-T}      (batched TRSM);
      3. A22 := A22 - A21 * A21^T    (batched GEMM; a SYRK would suffice);
      4. factor A22 (n2-by-n2, n2 = n - n1) recursively.
    Tiles with n <= min_recpnb are factored by magma_dpotrf_panel_batched.
    When m > n, the (m - n)-by-n block below the tile is then solved with a
    single TRSM against the factored tile.

    info_array and gbstep are forwarded to the panel routine. gbstep is
    presumably the global column offset used when reporting a failing minor,
    so the A22 recursion is passed gbstep + n1.

    Returns 0 on success (including the quick return for m == 0 or n == 0),
    -100 if uplo == MagmaUpper (not implemented), -101 if m < n, or the
    nonzero status of magma_malloc or a nested factorization.
    ********************************************************************/
extern "C" magma_int_t
magma_dpotrf_rectile_batched(
    magma_uplo_t uplo, magma_int_t m, magma_int_t n, 
    magma_int_t min_recpnb,    
    double** dA_array,    magma_int_t ldda,
    double** dX_array,    magma_int_t dX_length,
    double** dinvA_array, magma_int_t dinvA_length,
    double** dW0_displ, double** dW1_displ,  
    double** dW2_displ, double** dW3_displ,
    double** dW4_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t queue)
{
    // Quick return: success (0), consistent with the other routines in this file.
    if (m == 0 || n == 0) {
        return 0;
    }
    if (uplo == MagmaUpper) {
        printf("Upper side is unavailable \n");
        return -100;
    }
    if (m < n) {
        printf("error m < n %d < %d \n", (int) m, (int) n);
        return -101;
    }

    // Array of per-matrix pointers, re-aimed at each sub-block below.
    double **dA_displ = NULL;
    magma_int_t arginfo = magma_malloc((void**)&dA_displ, batchCount * sizeof(*dA_displ));
    if (arginfo != 0) {
        return arginfo;
    }

    const double alpha = MAGMA_D_NEG_ONE;
    const double beta  = MAGMA_D_ONE;

    if (n <= min_recpnb) {
        // Factor only the leading n-by-n tile. The rows below it (if any) are
        // solved once by the final TRSM. Handing the panel routine all m rows
        // would make it solve those rows as well (as recpanel relies on), and
        // the final TRSM would then apply L^{-T} to them a second time.
        magma_ddisplace_pointers(dA_displ, dA_array, ldda, 0, 0, batchCount, queue);
        arginfo = magma_dpotrf_panel_batched(
                           uplo, n, n,
                           dA_displ, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW0_displ, dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ,
                           info_array, gbstep,
                           batchCount, queue);
        if (arginfo != 0) {
            magma_free(dA_displ);
            return arginfo;
        }
    }
    else {
        // Split the tile as [A11 .; A21 A22].
        const magma_int_t n1 = n/2;
        const magma_int_t n2 = n-n1;
        const magma_int_t p1 = 0;
        const magma_int_t p2 = n1;

        // Factor A11.
        magma_ddisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount, queue);
        arginfo = magma_dpotrf_rectile_batched(
                           uplo, n1, n1, min_recpnb,
                           dA_displ, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW0_displ, dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, 
                           info_array, gbstep+p1,
                           batchCount, queue);
        if (arginfo != 0) {
            magma_free(dA_displ);
            return arginfo;
        }

        // TRSM on A21 using the factored A11.
        magma_ddisplace_pointers(dA_displ,  dA_array, ldda, p1, p1, batchCount, queue);
        magma_ddisplace_pointers(dW0_displ, dA_array, ldda, p2, p1, batchCount, queue);
        magmablas_dtrsm_work_batched( MagmaRight, MagmaLower, MagmaConjTrans, MagmaNonUnit,
                              1, n2, n1, 
                              MAGMA_D_ONE,
                              dA_displ,    ldda, 
                              dW0_displ,   ldda, 
                              dX_array,    n2, 
                              dinvA_array, dinvA_length, 
                              dW1_displ,   dW2_displ, 
                              dW3_displ,   dW4_displ,
                              0, batchCount, queue );

        // Update A22 -= A21 * A21^T (a SYRK would do; GEMM updates both triangles).
        magma_ddisplace_pointers(dA_displ,  dA_array, ldda, p2, p1, batchCount, queue);
        magma_ddisplace_pointers(dW0_displ, dA_array, ldda, p2, p2, batchCount, queue);
        magma_dgemm_batched( MagmaNoTrans, MagmaConjTrans, n2, n2, n1,
                             alpha, dA_displ, ldda, 
                             dA_displ, ldda, 
                             beta,  dW0_displ, ldda, 
                             batchCount, queue );

        // Factor A22. Its columns start p2 columns into this tile, so the
        // global step passed down must be offset accordingly.
        magma_ddisplace_pointers(dA_displ, dA_array, ldda, p2, p2, batchCount, queue);
        arginfo = magma_dpotrf_rectile_batched(
                           uplo, n2, n2, min_recpnb,
                           dA_displ, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW0_displ, dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, 
                           info_array, gbstep+p2,
                           batchCount, queue);
        if (arginfo != 0) {
            magma_free(dA_displ);
            return arginfo;
        }
    }

    if (m > n) {
        // TRSM on A3 (rows n..m-1) against the whole factored n-by-n tile.
        magma_ddisplace_pointers(dA_displ,  dA_array, ldda, 0, 0, batchCount, queue);
        magma_ddisplace_pointers(dW0_displ, dA_array, ldda, n, 0, batchCount, queue);
        magmablas_dtrsm_work_batched( MagmaRight, MagmaLower, MagmaConjTrans, MagmaNonUnit,
                              1, m-n, n, 
                              MAGMA_D_ONE,
                              dA_displ,    ldda, 
                              dW0_displ,   ldda, 
                              dX_array,    m-n, 
                              dinvA_array, dinvA_length, 
                              dW1_displ,   dW2_displ, 
                              dW3_displ,   dW4_displ,
                              0, batchCount, queue );
    }

    magma_free(dA_displ);
    return 0;
}
Exemple #5
0
/*
    Batched LU factorization with partial pivoting of the m-by-n matrices
    dA_array[i] (leading dimension ldda >= max(1,m)), processed in column
    panels of BATF2_NB columns.

    Inside a panel, each column is pivoted and eliminated in one of two ways:
      - while more than MAX_NTHREADS rows remain from the panel top:
        idamax (pivot search) + row swap across all n columns + scal/ger;
      - otherwise: the single "shared" column kernel
        (magma_dcomputecolumn_batched) followed by the row swap.
    After the panel, the block to its right is solved (getf2trsm) and the
    trailing matrix receives one batched GEMM update, using dW0/dW1/dW2 as
    pointer workspaces.

    Returns 0 on success; -1, -2 or -4 (after magma_xerbla) for an invalid m,
    n or ldda; otherwise the first nonzero status of a pivot/swap/scale kernel.
    Pivots go to ipiv_array. info_array and gbstep are forwarded to the kernels.
*/
extern "C" magma_int_t
magma_dgetf2_batched(
    magma_int_t m, magma_int_t n,
    double **dA_array, magma_int_t ldda,
    double **dW0_displ,
    double **dW1_displ,
    double **dW2_displ,
    magma_int_t **ipiv_array,
    magma_int_t *info_array,
    magma_int_t gbstep,
    magma_int_t batchCount,
    magma_queue_t queue)
{
    // Argument validation: the code is the negated position of the bad argument.
    magma_int_t status = 0;
    if (m < 0)
        status = -1;
    else if (n < 0 )
        status = -2;
    else if (ldda < max(1,m))
        status = -4;

    if (status != 0) {
        magma_xerbla( __func__, -(status) );
        return status;
    }

    // Nothing to factor.
    if (m == 0 || n == 0)
        return status;

    const double neg_one          = MAGMA_D_NEG_ONE;
    const double one              = MAGMA_D_ONE;
    const magma_int_t panel_width = BATF2_NB;
    const magma_int_t diag_len    = min(m, n);

    for (magma_int_t col0 = 0; col0 < diag_len; col0 += panel_width) {
        const magma_int_t width = min(panel_width, diag_len - col0);
        // Kernel path depends only on the rows remaining below the panel top.
        const int tall_panel = (m - col0) > MAX_NTHREADS;

        for (magma_int_t off = 0; off < width; ++off) {
            const magma_int_t col = col0 + off;

            if (!tall_panel) {
                // Pivot and eliminate the column in one kernel, then swap rows.
                status = magma_dcomputecolumn_batched(m-col0, col0, off, dA_array, ldda, ipiv_array, info_array, gbstep, batchCount, queue);
                if (status != 0 ) return status;
                status = magma_dswap_batched(n, dA_array, ldda, col, ipiv_array, batchCount, queue);
                if (status != 0 ) return status;
                continue;
            }

            // Pivot search over rows col..m-1 of column col.
            status = magma_idamax_batched(m-col, dA_array, 1, col, ldda, ipiv_array, info_array, gbstep, batchCount, queue);
            if (status != 0 ) return status;
            // Swap the pivot row across all n columns.
            status = magma_dswap_batched(n, dA_array, ldda, col, ipiv_array, batchCount, queue);
            if (status != 0 ) return status;
            // Scale below the pivot and update the rest of the panel.
            if (col < m) {
                status = magma_dscal_dger_batched( m-col, width-off, col, dA_array, ldda, info_array, gbstep, batchCount, queue );
                if (status != 0 ) return status;
            }
        }

        const magma_int_t next = col0 + width;   // first column right of the panel
        if (n - next <= 0)
            continue;

        // TRSM on the panel rows to the right of the panel (U12).
        magma_dgetf2trsm_batched(width, n - next, dA_array, col0, ldda, batchCount, queue);

        // Trailing update A22 -= A21 * A12 as a single batched GEMM.
        magma_ddisplace_pointers(dW0_displ, dA_array, ldda, next, col0, batchCount, queue);
        magma_ddisplace_pointers(dW1_displ, dA_array, ldda, col0, next, batchCount, queue);
        magma_ddisplace_pointers(dW2_displ, dA_array, ldda, next, next, batchCount, queue);

        magma_dgemm_batched( MagmaNoTrans, MagmaNoTrans, m - next, n - next, width,
                             neg_one, dW0_displ, ldda,
                                      dW1_displ, ldda,
                             one,     dW2_displ, ldda,
                             batchCount, queue );
    }

    return 0;
}
Exemple #6
0
//===================================================================================================================
//===================================================================================================================
//===================================================================================================================
// Batched construction of the k x k triangular factor T for k reflectors
// stored column-wise in V (n x k) with scalars tau; one problem per batch entry.
//
// Phase 1: Tstep := V' * V (batched DGEMM), then the lower part of Tstep is
//          cleared with magmablas_dlaset_batched (both fill values are zero).
// Phase 2: T is assembled column block by column block (block width nb):
//          the rectangular part above each diagonal block is formed with a
//          batched DGEMM followed by recursive TRMV kernels working on chunks
//          of at most nb rows; when stair_T == 0 the diagonal triangle of the
//          block is then formed by a TRMV kernel.
//
// Arguments:
//   n          number of rows of V.
//   k          number of reflectors (columns of V, order of T).
//   stair_T    0: compute all of T. > 0: the diagonal triangles of width
//              stair_T are already present in T and are not recomputed; only
//              the rectangular blocks above them are. Must be <= max_shared_bsiz.
//   v_array    batch of pointers to V, leading dimension ldv.
//   tau_array  batch of pointers to tau.
//   T_array    batch of pointers to T, leading dimension ldt.
//   work_array batch of pointers to a workspace; it holds V' * V (with
//              leading dimension ldt) when k > nb.
//   lwork      workspace size; must be >= k*ldt.
//   batchCount number of problems in the batch.
//   queue      MAGMA queue passed to every batched call.
//
// Returns 0 on success (including the quick returns), -10 if lwork is too
// small, -3 if stair_T exceeds max_shared_bsiz, and MAGMA_ERR_DEVICE_ALLOC if
// the device pointer workspace cannot be allocated.
extern "C" magma_int_t
magma_dlarft_batched(magma_int_t n, magma_int_t k, magma_int_t stair_T, 
                double **v_array, magma_int_t ldv,
                double **tau_array, double **T_array, magma_int_t ldt, 
                double **work_array, magma_int_t lwork, 
                magma_int_t batchCount, magma_queue_t queue)
{
    double c_one  = MAGMA_D_ONE;
    double c_zero = MAGMA_D_ZERO;
    magma_int_t info = 0;

    if ( k <= 0) return 0;
    // With k <= stair_T there is a single column block and its triangle is
    // already computed, so there is nothing left to do.
    if ( stair_T > 0 && k <= stair_T) return 0;

    magma_int_t maxnb = max_shared_bsiz;

    if ( lwork < k*ldt ) {
        info = -10;
    }
    else if ( stair_T > 0 && stair_T > maxnb ) {
        info = -3;
    }
    if ( info != 0 ) {
        // magma_xerbla takes -(info), i.e. the positive argument position.
        magma_xerbla( __func__, -(info) );
        return info;
    }

    magma_int_t DEBUG=0;
    magma_int_t nb = stair_T == 0 ? min(k,maxnb) : stair_T;

    magma_int_t i, j, prev_n, mycol, rows;

    // One device allocation backs the four batchCount-long pointer arrays:
    // a single failure check and a single free instead of four unchecked ones.
    double **dptr_ws = NULL;
    if ( MAGMA_SUCCESS != magma_malloc((void**)&dptr_ws, 4 * (size_t)batchCount * sizeof(*dptr_ws)) ) {
        info = MAGMA_ERR_DEVICE_ALLOC;
        magma_xerbla( __func__, -(info) );
        return info;
    }
    double **dW1_displ    = dptr_ws;
    double **dW2_displ    = dptr_ws +     batchCount;
    double **dW3_displ    = dptr_ws + 2 * batchCount;
    double **dTstep_array = dptr_ws + 3 * batchCount;

    // When a single column block covers T, V'*V is formed directly in T.
    // Otherwise it goes to the workspace: the block update below reads
    // Tstep(0:j, j:j+mycol) while writing T(0:j, j:j+mycol).
    if (k > nb) {
        magma_ddisplace_pointers(dTstep_array, work_array, lwork, 0, 0, batchCount, queue);
    }
    else {
        magma_ddisplace_pointers(dTstep_array, T_array, ldt, 0, 0, batchCount, queue);
    }

    // The workspace is addressed with T's leading dimension (lwork >= k*ldt
    // was checked above).
    magma_int_t ldtstep = ldt;

    // Phase 1: Tstep = V' * V; then clear its lower part.
    magma_dgemm_batched( MagmaConjTrans, MagmaNoTrans, 
                         k, k, n, 
                         c_one,  v_array, ldv, 
                                 v_array, ldv, 
                         c_zero, dTstep_array, ldtstep, 
                         batchCount, queue );

    magmablas_dlaset_batched( MagmaLower, k, k, MAGMA_D_ZERO, MAGMA_D_ZERO, dTstep_array, ldtstep, batchCount, queue );

    // Phase 2 (TRMV): T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1), i = 1..k,
    // processed over column blocks of width nb, top to bottom:
    //  1- a GEMM using the previously computed columns of T updates the
    //     rectangular portion above the triangle of the current columns;
    //  2- that rectangle is then refined by a serial sequence of TRMV
    //     kernels; because shared memory is limited to nb, the rectangle is
    //     split vertically into chunks of at most nb rows.
    for (j=0; j < k; j += nb)
    {
        prev_n =  j;
        mycol  =  min(nb, k-j);
        // note that myrow = prev_n + mycol;
        if (prev_n > 0 && mycol > 0) {
            if (DEBUG == 3) {
                printf("doing gemm on the rectangular portion of size %d %d of T(%d,%d)\n",
                        (int) prev_n, (int) mycol, 0, (int) j );
            }

            magma_ddisplace_pointers(dW1_displ, dTstep_array, ldtstep, 0, j, batchCount, queue);
            magma_ddisplace_pointers(dW2_displ, T_array,     ldt, 0, j, batchCount, queue);
            magma_dgemm_batched( MagmaNoTrans, MagmaNoTrans, 
                                 prev_n, mycol, prev_n, 
                                 c_one,  T_array, ldt, 
                                         dW1_displ, ldtstep, 
                                 c_zero, dW2_displ, ldt, 
                                 batchCount, queue );

            // update my rectangular portion (prev_n,mycol) using sequence of gemv 
            magma_ddisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
            magma_ddisplace_pointers(dW3_displ, tau_array,  1, j, 0, batchCount, queue);

            for (i=0; i < prev_n; i += nb)
            {
                rows = min(nb,prev_n-i);
                if (DEBUG == 3) {
                    printf("        doing recdtrmv on the rectangular portion of size %d %d of T(%d,%d)\n",
                            (int) rows, (int) mycol, (int) i, (int) j );
                }

                if (rows > 0 && mycol > 0)
                {
                    magma_ddisplace_pointers(dW2_displ, T_array,     ldt, i, j, batchCount, queue);
                    magmablas_dlarft_recdtrmv_sm32x32_batched(rows, mycol, dW3_displ, dW2_displ, ldt, dW1_displ, ldtstep, batchCount, queue);
                }
            }
        }

        // The upper rectangular portion is done; now form the diagonal
        // triangle unless the caller already supplied it (stair_T > 0).
        if (stair_T == 0) {
            if (DEBUG == 3) {
                printf("doing dtrmv on the triangular portion of size %d %d of T(%d,%d)\n",
                        (int) mycol, (int) mycol, (int) j, (int) j );
            }

            if (mycol > 0)
            {
                magma_ddisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
                magma_ddisplace_pointers(dW3_displ, tau_array,  1, j, 0, batchCount, queue);
                magma_ddisplace_pointers(dW2_displ, T_array,     ldt, j, j, batchCount, queue);
                magmablas_dlarft_dtrmv_sm32x32_batched(mycol, mycol, dW3_displ, dW1_displ, ldtstep, dW2_displ, ldt, batchCount, queue);
            }
        }
    }// end of j

    magma_free(dptr_ws);

    return 0;
}
Exemple #7
0
// Batched panel factorization for the dpotrf family (presumably Cholesky, as
// the naming suggests), lower case only. Each of the batchCount matrices in
// dA_array is m x n with leading dimension lda.
//
// For n <= magma_get_dpotrf_batched_crossover() the work is delegated to
// magma_dpotrf_lpout_batched. Otherwise the n columns are processed in blocks
// of POTF2_NB:
//   - the rows x ib block at (j, j) goes to magma_dpotf2_tile_batched when it
//     fits in POTF2_TILE_SIZE x POTF2_TILE_SIZE, else to
//     magma_dpotf2_dtrsm_batched;
//   - the next column block is then updated with a batched DGEMM
//     (left-looking by default; define RIGHT_LOOKING for the other variant).
//
// dA_displ, dW_displ, dB_displ, dC_displ are caller-provided device arrays
// of batchCount pointers used as scratch for displaced views.
// Per-matrix numerical status is reported through info_array (offset gbstep).
//
// Returns 0 on success, the first nonzero code returned by a panel routine,
// or -1 for uplo == MagmaUpper (not implemented).
extern "C" magma_int_t
magma_dpotf2_batched(
    magma_uplo_t uplo, magma_int_t m, magma_int_t n,
    double **dA_array, magma_int_t lda,
    double **dA_displ, 
    double **dW_displ,
    double **dB_displ, 
    double **dC_displ, 
    magma_int_t *info_array, magma_int_t gbstep, 
    magma_int_t batchCount, magma_queue_t queue)
{
    magma_int_t arginfo=0;

    // Quick return if possible.
    // NOTE(review): this returns 1 rather than the conventional 0; kept as-is
    // because callers may depend on it — confirm.
    if (n == 0) {
        return 1;
    }

    if (uplo == MagmaUpper) {
        // Only the lower variant exists: report failure instead of returning
        // success with dA left untouched.
        printf("Upper side is unavailable\n");
        arginfo = -1;
        return arginfo;
    }

    double alpha = MAGMA_D_NEG_ONE;
    double beta  = MAGMA_D_ONE;

    magma_int_t nb = POTF2_NB;
    magma_int_t j, ib, rows;
    magma_int_t crossover = magma_get_dpotrf_batched_crossover();

    if ( n <= crossover ) {
        return magma_dpotrf_lpout_batched(uplo, n, dA_array, lda, gbstep, info_array, batchCount, queue);
    }

    for (j = 0; j < n; j += nb) {
        ib   = min(nb, n-j);
        rows = m-j;

        // Panel: the rows x ib block starting at (j, j).
        magma_ddisplace_pointers(dA_displ, dA_array, lda, j, j, batchCount, queue);
        if ( (rows <= POTF2_TILE_SIZE) && (ib <= POTF2_TILE_SIZE) ) {
            arginfo = magma_dpotf2_tile_batched(
                           uplo, rows, ib,
                           dA_displ, lda,
                           info_array, gbstep, batchCount, queue);
        }
        else {
            arginfo = magma_dpotf2_dtrsm_batched(
                           uplo, rows, ib,
                           dA_displ, lda,
                           dW_displ, dB_displ, dC_displ, 
                           info_array, gbstep, batchCount, queue);
        }
        // Stop at the first failure instead of updating the trailing columns
        // with an unfactored panel and overwriting the error code later.
        if (arginfo != 0) return arginfo;

        //#define RIGHT_LOOKING
        if ( (n-j-ib) > 0) {
            #ifdef RIGHT_LOOKING
            // Update the whole trailing matrix with the just-factored panel.
            magma_ddisplace_pointers(dA_displ, dA_array, lda, j+ib, j, batchCount, queue);
            magma_ddisplace_pointers(dC_displ, dA_array, lda, j+ib, j+ib, batchCount, queue);
            magma_dgemm_batched( MagmaNoTrans, MagmaConjTrans,
                         m-j-ib, n-j-ib, ib,
                         alpha, dA_displ, lda,
                                dA_displ, lda,
                         beta,  dC_displ, lda, batchCount, queue );
            #else
            // Left-looking: update only the next subpanel using all j+ib
            // previously factored columns.
            magma_ddisplace_pointers(dA_displ, dA_array, lda, j+ib, 0, batchCount, queue);
            magma_ddisplace_pointers(dC_displ, dA_array, lda, j+ib, j+ib, batchCount, queue);
            magma_dgemm_batched( MagmaNoTrans, MagmaConjTrans,
                         m-j-ib, min((n-j-ib),ib), j+ib,
                         alpha, dA_displ, lda,
                                dA_displ, lda,
                         beta,  dC_displ, lda, batchCount, queue );
            #endif
        } // end of if ( (n-j-ib) > 0)
    }

    return arginfo;
}
// Batched panel factorization without pivoting (getf2_nopiv) of the m x n
// matrices in dA_array (leading dimension ldda), in column blocks of BATF2_NB.
//
// For each block of ib columns starting at panelj:
//   - each column is processed by magma_dscal_dger_batched (scale plus rank-1
//     update of the remaining panel columns);
//   - the ib pivot rows are extended to the right with magma_dgetf2trsm_batched;
//   - the trailing matrix is updated with a batched DGEMM:
//       A22 -= A21 * A12.
//
// dW0_displ, dW1_displ, dW2_displ are caller-provided device arrays of
// batchCount pointers used as scratch for displaced views. Per-matrix
// numerical status is reported through info_array (offset gbstep).
//
// Returns 0 on success; -1/-2/-4 for an invalid m/n/ldda (after magma_xerbla);
// or the first nonzero code returned by magma_dscal_dger_batched.
extern "C" magma_int_t
magma_dgetf2_nopiv_batched(
    magma_int_t m, magma_int_t n,
    double **dA_array, magma_int_t ldda,
    double **dW0_displ,
    double **dW1_displ,
    double **dW2_displ,
    magma_int_t *info_array,            
    magma_int_t gbstep, 
    magma_int_t batchCount, magma_queue_t queue)
{
    magma_int_t arginfo = 0;
    if (m < 0) {
        arginfo = -1;
    } else if (n < 0 ) {
        arginfo = -2;
    } else if (ldda < max(1,m)) {
        arginfo = -4;
    }

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    // Quick return if possible
    if (m == 0 || n == 0) {
        return arginfo;
    }

    double c_neg_one = MAGMA_D_NEG_ONE;
    double c_one     = MAGMA_D_ONE;
    magma_int_t nb = BATF2_NB;

    magma_int_t min_mn = min(m, n);
    magma_int_t gbj, panelj, step, ib;

    for (panelj = 0; panelj < min_mn; panelj += nb)
    {
        ib = min(nb, min_mn-panelj);

        // Column-by-column factorization of the panel. gbj < min_mn <= m, so
        // every column still has m-gbj >= 1 rows to process.
        for (step = 0; step < ib; step++) {
            gbj = panelj+step;
            arginfo = magma_dscal_dger_batched( m-gbj, ib-step, gbj, dA_array, ldda, info_array, gbstep, batchCount, queue );
            if (arginfo != 0) return arginfo;
        }

        if ( (n-panelj-ib) > 0) {
            // Continue the update of the selected ib rows, columns panelj+ib:n (TRSM).
            magma_dgetf2trsm_batched(ib, n-panelj-ib, dA_array, panelj, ldda, batchCount, queue);
            // Blocked DGER = DGEMM for the remaining panelj+ib:n columns.
            magma_ddisplace_pointers(dW0_displ, dA_array, ldda, ib+panelj, panelj, batchCount, queue);
            magma_ddisplace_pointers(dW1_displ, dA_array, ldda, panelj, ib+panelj, batchCount, queue);
            magma_ddisplace_pointers(dW2_displ, dA_array, ldda, ib+panelj, ib+panelj, batchCount, queue);

            magma_dgemm_batched( MagmaNoTrans, MagmaNoTrans, m-(panelj+ib), n-(panelj+ib), ib, 
                                 c_neg_one, dW0_displ, ldda, 
                                            dW1_displ, ldda, 
                                 c_one,     dW2_displ, ldda, 
                                 batchCount, queue );
        }
    }

    return 0;
}