//===================================================================================================================
//===================================================================================================================
//===================================================================================================================
extern "C" void
magma_clarft_sm32x32_batched(magma_int_t n, magma_int_t k,
                    magmaFloatComplex **v_array, magma_int_t ldv,
                    magmaFloatComplex **tau_array,
                    magmaFloatComplex **T_array, magma_int_t ldt,
                    magma_int_t batchCount, magma_queue_t queue)
{
    if ( k <= 0) return;

    //==================================
    //          GEMV
    //==================================
#define USE_GEMV2
#define use_gemm_larft_sm32

#if defined(use_gemm_larft_sm32)
    magma_cgemm_batched( MagmaConjTrans, MagmaNoTrans,
                         k, k, n,
                         MAGMA_C_ONE,  v_array, ldv,
                                       v_array, ldv,
                         MAGMA_C_ZERO, T_array, ldt,
                         batchCount, queue );
    magmablas_claset_batched( MagmaLower, k, k,
                         MAGMA_C_ZERO, MAGMA_C_ZERO,
                         T_array, ldt, batchCount, queue );
#else
    #if 1
    for (magma_int_t i=0; i < k; i++) {
        // W(1:i-1) := - tau(i) * V(i:n,1:i-1)' * V(i:n,i)
        // T( i, i ) = tau( i )
        // custom implementation.
        #ifdef USE_GEMV2
        magmablas_clarft_gemvrowwise_batched( n-i, i,
                            tau_array,
                            v_array, ldv,
                            T_array, ldt,
                            batchCount, queue);
        #else
        magmablas_clarft_gemvcolwise_batched( n-i, i,
                            v_array, ldv,
                            T_array, ldt,
                            tau_array,
                            batchCount, queue);
        #endif
    }
    #else
    // seems to be very slow when k=32 while the one-by-one loop above is faster
    clarft_gemv_loop_inside_kernel_batched(n, k,
                            tau_array,
                            v_array, ldv,
                            T_array, ldt,
                            batchCount, queue);
    #endif
#endif

    //==================================
    //          TRMV
    //==================================
    // T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1)   i = [1:k]
    magmablas_clarft_ctrmv_sm32x32_batched(k, k,
                        tau_array,
                        T_array, ldt,
                        T_array, ldt,
                        batchCount, queue);
}
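//-------------------------------------------------------------------------------------------------------------------
// For reference only: a minimal single-matrix host sketch (a hypothetical helper,
// not a MAGMA API) of what the batched routine above computes. It forms the upper
// triangular T of the block reflector H = I - V*T*V^H via the classic LARFT
// recurrence
//     T(i,i)     = tau(i)
//     T(0:i-1,i) = -tau(i) * T(0:i-1,0:i-1) * ( V^H * V(:,i) )(0:i-1)
// assuming column-major V (n-by-k) with explicit 0's and 1's stored, which makes
// the full-column dot products equal to LAPACK's partial ones.
#include <complex>
#include <vector>

static void ref_clarft(int n, int k,
                       const std::complex<float>* V, int ldv,
                       const std::complex<float>* tau,
                       std::complex<float>* T, int ldt)
{
    std::vector<std::complex<float>> w(k);
    for (int i = 0; i < k; ++i) {
        T[i + i*ldt] = tau[i];
        // w(0:i-1) = V(:,0:i-1)^H * V(:,i)   (the GEMV/GEMM phase above)
        for (int j = 0; j < i; ++j) {
            w[j] = 0;
            for (int r = 0; r < n; ++r)
                w[j] += std::conj(V[r + j*ldv]) * V[r + i*ldv];
        }
        // T(0:i-1,i) = -tau(i) * T(0:i-1,0:i-1) * w   (the TRMV phase above)
        for (int j = 0; j < i; ++j) {
            std::complex<float> s = 0;
            for (int l = j; l < i; ++l)
                s += T[j + l*ldt] * w[l];
            T[j + i*ldt] = -tau[i] * s;
        }
    }
}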
/**
    Purpose
    -------
    CLARFB applies a complex block reflector H or its transpose H^H to a
    COMPLEX m by n matrix C, from either the left or the right.

    __Note that this function assumes__ that the upper part of dV_array is 0
    because it is referenced. Same for upper/lower part of dT_array.

    Arguments
    ---------
    @param[in]
    side    magma_side_t
      -     = MagmaLeft:  apply H or H^H from the Left
      -     = MagmaRight: apply H or H^H from the Right

    @param[in]
    trans   magma_trans_t
      -     = MagmaNoTrans:    apply H   (No transpose)
      -     = Magma_ConjTrans: apply H^H (Conjugate transpose)

    @param[in]
    direct  magma_direct_t
            Indicates how H is formed from a product of elementary reflectors
      -     = MagmaForward:  H = H(1) H(2) . . . H(k) (Forward)
      -     = MagmaBackward: H = H(k) . . . H(2) H(1) (Backward)

    @param[in]
    storev  magma_storev_t
            Indicates how the vectors which define the elementary reflectors
            are stored:
      -     = MagmaColumnwise: Columnwise
      -     = MagmaRowwise:    Rowwise

    @param[in]
    m       INTEGER
            The number of rows of the matrix C.

    @param[in]
    n       INTEGER
            The number of columns of the matrix C.

    @param[in]
    k       INTEGER
            The order of the matrix T (= the number of elementary
            reflectors whose product defines the block reflector).

    @param[in]
    dV_array    COMPLEX array on the GPU, dimension
                (LDDV,K) if STOREV = MagmaColumnwise
                (LDDV,M) if STOREV = MagmaRowwise and SIDE = MagmaLeft
                (LDDV,N) if STOREV = MagmaRowwise and SIDE = MagmaRight
            The matrix V. See further details.

    @param[in]
    lddv    INTEGER
            The leading dimension of the array V.
            If STOREV = MagmaColumnwise and SIDE = MagmaLeft,  LDDV >= max(1,M);
            if STOREV = MagmaColumnwise and SIDE = MagmaRight, LDDV >= max(1,N);
            if STOREV = MagmaRowwise, LDDV >= K.

    @param[in]
    dT_array    COMPLEX array on the GPU, dimension (LDDT,K)
            The triangular k by k matrix T in the representation of the
            block reflector.

    @param[in]
    lddt    INTEGER
            The leading dimension of the array T. LDDT >= K.

    @param[in,out]
    dC_array    COMPLEX array on the GPU, dimension (LDDC,N)
            On entry, the m by n matrix C.
            On exit, C is overwritten by H*C, or H^H*C, or C*H, or C*H^H.

    @param[in]
    lddc    INTEGER
            The leading dimension of the array C. LDDC >= max(1,M).

    @param
    dwork_array (workspace) COMPLEX array, dimension (LDWORK,K)

    @param[in]
    ldwork  INTEGER
            The leading dimension of the array WORK.
            If SIDE = MagmaLeft,  LDWORK >= max(1,N);
            if SIDE = MagmaRight, LDWORK >= max(1,M).

    @param
    dworkvt_array (workspace) COMPLEX array, dimension (LDWORKT,K)

    @param[in]
    ldworkvt INTEGER
            The leading dimension of the array WORKVT.
            LDWORKVT >= max(1,min(M,N)).

    @param[in]
    batchCount  INTEGER
                The number of matrices to operate on.

    @param[in]
    queue   magma_queue_t
            Queue to execute in.

    Further Details
    ---------------
    The shape of the matrix V and the storage of the vectors which define
    the H(i) is best illustrated by the following example with n = 5 and
    k = 3. All elements including 0's and 1's are stored, unlike LAPACK.
    DIRECT = MagmaForward and         DIRECT = MagmaForward and
    STOREV = MagmaColumnwise:         STOREV = MagmaRowwise:

    V = (  1  0  0 )                  V = (  1 v1 v1 v1 v1 )
        ( v1  1  0 )                      (  0  1 v2 v2 v2 )
        ( v1 v2  1 )                      (  0  0  1 v3 v3 )
        ( v1 v2 v3 )
        ( v1 v2 v3 )

    DIRECT = MagmaBackward and        DIRECT = MagmaBackward and
    STOREV = MagmaColumnwise:         STOREV = MagmaRowwise:

    V = ( v1 v2 v3 )                  V = ( v1 v1  1  0  0 )
        ( v1 v2 v3 )                      ( v2 v2 v2  1  0 )
        (  1 v2 v3 )                      ( v3 v3 v3 v3  1 )
        (  0  1 v3 )
        (  0  0  1 )

    @ingroup magma_caux3
    ********************************************************************/
extern "C" magma_int_t
magma_clarfb_gemm_batched(
    magma_side_t side, magma_trans_t trans, magma_direct_t direct, magma_storev_t storev,
    magma_int_t m, magma_int_t n, magma_int_t k,
    magmaFloatComplex_const_ptr dV_array[],    magma_int_t lddv,
    magmaFloatComplex_const_ptr dT_array[],    magma_int_t lddt,
    magmaFloatComplex_ptr dC_array[],          magma_int_t lddc,
    magmaFloatComplex_ptr dwork_array[],       magma_int_t ldwork,
    magmaFloatComplex_ptr dworkvt_array[],     magma_int_t ldworkvt,
    magma_int_t batchCount, magma_queue_t queue)
{
    magmaFloatComplex c_zero    = MAGMA_C_ZERO;
    magmaFloatComplex c_one     = MAGMA_C_ONE;
    magmaFloatComplex c_neg_one = MAGMA_C_NEG_ONE;

    /* Function Body */
    magma_int_t info = 0;
    if (m <= 0 || n <= 0) {
        return info;
    }

    // internal variable
    magma_int_t ldwvt = (m > n ? k : m);
    magma_int_t ldw;
    if ( side == MagmaLeft ) {
        ldw = k;
    } else {
        ldw = m;
    }

    // opposite of trans
    magma_trans_t transt;
    if (trans == MagmaNoTrans)
        transt = Magma_ConjTrans;
    else
        transt = MagmaNoTrans;
    MAGMA_UNUSED( transt );  // TODO: is this a bug that it isn't used?

    // whether V is stored transposed or not
    magma_trans_t notransV, transV;
    if (storev == MagmaColumnwise) {
        notransV = MagmaNoTrans;
        transV   = Magma_ConjTrans;
    }
    else {
        notransV = Magma_ConjTrans;
        transV   = MagmaNoTrans;
    }

    if ( side == MagmaLeft ) {
        // Form H C or H^H C
        // Comments assume H C.
        // When forming H^H C, T gets transposed via transt for m >= n or by trans for m < n.

        // W = V' C
        magma_cgemm_batched( Magma_ConjTrans, notransV, /*NontransLeft*/
                             k, n, m,
                             c_one,  dV_array,    lddv,
                                     dC_array,    lddc,
                             c_zero, dwork_array, ldw,
                             batchCount, queue );

        if (m <= n) {
            // W2 = V T
            magma_cgemm_batched( notransV, trans, /* (NoTrans), trans(ConjTrans),*/
                                 m, k, k,
                                 c_one,  dV_array,      lddv,
                                         dT_array,      lddt,
                                 c_zero, dworkvt_array, ldwvt,
                                 batchCount, queue );
            // C = C - W2 W = C - V T V' C = (I - V T V') C = H C
            magma_cgemm_batched( MagmaNoTrans, MagmaNoTrans,
                                 m, n, k,
                                 c_neg_one, dworkvt_array, ldwvt,
                                            dwork_array,   ldw,
                                 c_one,     dC_array,      lddc,
                                 batchCount, queue );
        }
        else {
            // W2 = T W = T V' C
            magma_cgemm_batched( trans, MagmaNoTrans,
                                 k, n, k,
                                 c_one,  dT_array,      lddt,
                                         dwork_array,   ldw,
                                 c_zero, dworkvt_array, ldwvt,
                                 batchCount, queue );
            // C = C - V W2 = C - V T V' C = (I - V T V') C = H C
            magma_cgemm_batched( notransV, MagmaNoTrans,
                                 m, n, k,
                                 c_neg_one, dV_array,      lddv,
                                            dworkvt_array, ldwvt,
                                 c_one,     dC_array,      lddc,
                                 batchCount, queue );
        }
    }
    else {
        // Form C H or C H^H
        // Comments assume C H.
        // When forming C H^H, T gets transposed via trans.
        // W = C V
        magma_cgemm_batched( MagmaNoTrans, notransV,
                             m, k, n,
                             c_one,  dC_array,    lddc,
                                     dV_array,    lddv,
                             c_zero, dwork_array, ldw,
                             batchCount, queue );

        if (m <= n) {
            // W2 = W T = C V T
            magma_cgemm_batched( MagmaNoTrans, trans,
                                 m, k, k,
                                 c_one,  dwork_array,   ldw,
                                         dT_array,      lddt,
                                 c_zero, dworkvt_array, ldwvt,
                                 batchCount, queue );
            // C = C - W2 V' = C - C V T V' = C (I - V T V') = C H
            magma_cgemm_batched( MagmaNoTrans, transV,
                                 m, n, k,
                                 c_neg_one, dworkvt_array, ldwvt,
                                            dV_array,      lddv,
                                 c_one,     dC_array,      lddc,
                                 batchCount, queue );
        }
        else {
            // W2 = T V'
            magma_cgemm_batched( trans, transV,
                                 k, n, k,
                                 c_one,  dT_array,      lddt,
                                         dV_array,      lddv,
                                 c_zero, dworkvt_array, ldwvt,
                                 batchCount, queue );
            // C = C - W W2 = C - C V T V' = C (I - V T V') = C H
            magma_cgemm_batched( MagmaNoTrans, MagmaNoTrans,
                                 m, n, k,
                                 c_neg_one, dwork_array,   ldw,
                                            dworkvt_array, ldwvt,
                                 c_one,     dC_array,      lddc,
                                 batchCount, queue );
        }
    }
    return MAGMA_SUCCESS;
} /* magma_clarfb */
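//-------------------------------------------------------------------------------------------------------------------
// For reference only: a minimal single-matrix host sketch (a hypothetical helper,
// not a MAGMA API) of the side = MagmaLeft, trans = MagmaNoTrans,
// storev = MagmaColumnwise case, following the m > n branch above:
//     W = V^H C,   W2 = T W,   C = C - V W2  =  (I - V T V^H) C  =  H C,
// with T the k-by-k upper triangular factor and all matrices column-major.
#include <complex>
#include <vector>

static void ref_clarfb_left(int m, int n, int k,
                            const std::complex<float>* V, int ldv,
                            const std::complex<float>* T, int ldt,
                            std::complex<float>* C, int ldc)
{
    std::vector<std::complex<float>> W(k*n), W2(k*n);
    // W = V^H * C   (k-by-n, leading dimension k)
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < k; ++i) {
            std::complex<float> s = 0;
            for (int r = 0; r < m; ++r)
                s += std::conj(V[r + i*ldv]) * C[r + j*ldc];
            W[i + j*k] = s;
        }
    // W2 = T * W   (T is upper triangular, so the sum starts at l = i)
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < k; ++i) {
            std::complex<float> s = 0;
            for (int l = i; l < k; ++l)
                s += T[i + l*ldt] * W[l + j*k];
            W2[i + j*k] = s;
        }
    // C = C - V * W2
    for (int j = 0; j < n; ++j)
        for (int r = 0; r < m; ++r) {
            std::complex<float> s = 0;
            for (int l = 0; l < k; ++l)
                s += V[r + l*ldv] * W2[l + j*k];
            C[r + j*ldc] -= s;
        }
}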
/**
    Purpose
    -------
    This is an internal routine that makes a number of assumptions about its
    inputs; its documentation is not complete.

    CGETRF_PANEL computes an LU factorization of a general M-by-N matrix A
    using partial pivoting with row interchanges.

    The factorization has the form
        A = P * L * U
    where P is a permutation matrix, L is lower triangular with unit
    diagonal elements (lower trapezoidal if m > n), and U is upper
    triangular (upper trapezoidal if m < n).

    This is the right-looking Level 3 BLAS version of the algorithm.

    This is a batched version that factors batchCount M-by-N matrices in parallel.
    dA, ipiv, and info become arrays with one entry per matrix.

    Arguments
    ---------
    @param[in]
    m       INTEGER
            The number of rows of each matrix A.  M >= 0.

    @param[in]
    n       INTEGER
            The number of columns of each matrix A.  N >= 0.

    @param[in]
    min_recpnb INTEGER
            Internal use. The recursive nb.

    @param[in,out]
    dA_array    Array of pointers, dimension (batchCount).
            Each is a COMPLEX array on the GPU, dimension (LDDA,N).
            On entry, each pointer is an M-by-N matrix to be factored.
            On exit, the factors L and U from the factorization
            A = P*L*U; the unit diagonal elements of L are not stored.

    @param[in]
    ldda    INTEGER
            The leading dimension of each array A.  LDDA >= max(1,M).

    @param[out]
    dipiv_array  Array of pointers, dimension (batchCount), for corresponding matrices.
            Each is an INTEGER array, dimension (min(M,N)).
            The pivot indices; for 1 <= i <= min(M,N), row i of the
            matrix was interchanged with row IPIV(i).

    @param[out]
    dpivinfo_array  Array of pointers, dimension (batchCount), for internal use.

    @param[in,out]
    dX_array    Array of pointers, dimension (batchCount).
            Each is a COMPLEX array X of dimension ( lddx, n ).
            On entry, should be set to 0.
            On exit, the solution matrix X.

    @param[in]
    dX_length   INTEGER
            The size of each workspace matrix dX.

    @param[in,out]
    dinvA_array Array of pointers, dimension (batchCount).
            Each is a COMPLEX array dinvA, a workspace on device.
            If side == MagmaLeft,  dinvA must be of size >= ceil(m/TRI_NB)*TRI_NB*TRI_NB;
            if side == MagmaRight, dinvA must be of size >= ceil(n/TRI_NB)*TRI_NB*TRI_NB,
            where TRI_NB = 128.

    @param[in]
    dinvA_length INTEGER
            The size of each workspace matrix dinvA.

    @param[in]
    dW1_displ   Workspace array of pointers, for internal use.

    @param[in]
    dW2_displ   Workspace array of pointers, for internal use.

    @param[in]
    dW3_displ   Workspace array of pointers, for internal use.

    @param[in]
    dW4_displ   Workspace array of pointers, for internal use.

    @param[in]
    dW5_displ   Workspace array of pointers, for internal use.

    @param[out]
    info_array  Array of INTEGERs, dimension (batchCount), for corresponding matrices.
      -     = 0:  successful exit
      -     < 0:  if INFO = -i, the i-th argument had an illegal value
                  or another error occurred, such as memory allocation failed.
      -     > 0:  if INFO = i, U(i,i) is exactly zero. The factorization
                  has been completed, but the factor U is exactly
                  singular, and division by zero will occur if it is used
                  to solve a system of equations.

    @param[in]
    gbstep  INTEGER
            Internal use.

    @param[in]
    batchCount  INTEGER
                The number of matrices to operate on.

    @param[in]
    queue   magma_queue_t
            Queue to execute in.
    @ingroup magma_cgesv_comp
    ********************************************************************/
extern "C" magma_int_t
magma_cgetrf_recpanel_batched(
    magma_int_t m, magma_int_t n, magma_int_t min_recpnb,
    magmaFloatComplex** dA_array,    magma_int_t ldda,
    magma_int_t** dipiv_array, magma_int_t** dpivinfo_array,
    magmaFloatComplex** dX_array,    magma_int_t dX_length,
    magmaFloatComplex** dinvA_array, magma_int_t dinvA_length,
    magmaFloatComplex** dW1_displ, magmaFloatComplex** dW2_displ,
    magmaFloatComplex** dW3_displ, magmaFloatComplex** dW4_displ,
    magmaFloatComplex** dW5_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t queue)
{
    //magma_int_t DEBUG = 3;
    // Quick return if possible
    if (m == 0 || n == 0) {
        return 0;
    }

    magmaFloatComplex **dA_displ = NULL;
    magma_malloc((void**)&dA_displ, batchCount * sizeof(*dA_displ));
    magma_int_t **dipiv_displ = NULL;
    magma_malloc((void**)&dipiv_displ, batchCount * sizeof(*dipiv_displ));

    magma_int_t panel_nb = n;
    if (panel_nb <= min_recpnb) {
        //if (DEBUG > 0) printf("calling bottom panel recursive with m=%d nb=%d\n", m, n);
        // panel factorization
        //magma_cdisplace_pointers(dA_displ, dA_array, ldda, 0, 0, batchCount);
        magma_cgetf2_batched(m, panel_nb,
                             dA_array, ldda,
                             dW1_displ, dW2_displ, dW3_displ,
                             dipiv_array, info_array, gbstep, batchCount, queue);
    }
    else {
        // split A in two: [A1 A2]
        // panel on A1, update of A2, then panel on A2
        magma_int_t n1 = n/2;
        magma_int_t n2 = n-n1;
        magma_int_t m1 = m;
        magma_int_t m2 = m-n1;
        magma_int_t p1 = 0;
        magma_int_t p2 = n1;

        // panel on A1
        //if (DEBUG > 0) printf("calling recursive panel on A1 with m=%d nb=%d min_recpnb %d\n", m1, n1, min_recpnb);
        magma_cdisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount, queue);
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p1, 0, batchCount, queue);
        magma_cgetrf_recpanel_batched(
                           m1, n1, min_recpnb,
                           dA_displ, ldda,
                           dipiv_displ, dpivinfo_array,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ, dW3_displ,
                           dW4_displ, dW5_displ,
                           info_array, gbstep, batchCount, queue);

        // update A2
        //if (DEBUG > 0) printf("calling TRSM with m=%d n=%d\n", m1, n2);

        // setup pivinfo
        setup_pivinfo_batched(dpivinfo_array, dipiv_displ, m1, n1, batchCount, queue);
        magma_cdisplace_pointers(dW5_displ, dA_array, ldda, p1, p2, batchCount, queue);
        magma_claswp_rowparallel_batched( n2, dW5_displ, ldda,
                           dX_array, n1,
                           0, n1,
                           dpivinfo_array, batchCount, queue );
        magmablas_ctrsm_outofplace_batched( MagmaLeft, MagmaLower, MagmaNoTrans, MagmaUnit, 1,
                           n1, n2,
                           MAGMA_C_ONE,
                           dA_displ,    ldda, // dA
                           dX_array,    n1,   // dB
                           dW5_displ,   ldda, // dX
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ,
                           0, batchCount, queue );

        magma_cdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0,  batchCount, queue);
        magma_cdisplace_pointers(dA_displ,  dA_array, ldda, p2, p2, batchCount, queue);
        //if (DEBUG > 0) printf("calling update A2(%d,%d) -= A(%d,%d)*A(%d,%d) with m=%d n=%d k=%d ldda %d\n", p2, p2, p2, 0, p1, p2, m2, n2, n1, ldda);
        magma_cgemm_batched( MagmaNoTrans, MagmaNoTrans,
                             m2, n2, n1,
                             MAGMA_C_NEG_ONE, dW1_displ, ldda,
                                              dW5_displ, ldda,
                             MAGMA_C_ONE,     dA_displ,  ldda,
                             batchCount, queue );

        // panel on A2
        //if (DEBUG > 0) printf("calling recursive panel on A2 with m=%d nb=%d min_recpnb %d\n", m2, n2, min_recpnb);
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p2, 0, batchCount, queue);
        magma_cgetrf_recpanel_batched(
                           m2, n2, min_recpnb,
                           dA_displ, ldda,
                           dipiv_displ, dpivinfo_array,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ, dW3_displ,
                           dW4_displ, dW5_displ,
                           info_array, gbstep+p2, batchCount, queue);

        // setup pivinfo
        setup_pivinfo_batched(dpivinfo_array, dipiv_displ, m2, n2, batchCount, queue);
        adjust_ipiv_batched(dipiv_displ, n2, n1, batchCount, queue);

        magma_cdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount, queue); // no need since it is above
        magma_claswp_rowparallel_batched( n1, dW1_displ, ldda,
                           dW1_displ, ldda,
                           n1, n,
                           dpivinfo_array, batchCount, queue );
    }

    magma_free(dA_displ);
    magma_free(dipiv_displ);
    return 0;
}
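//-------------------------------------------------------------------------------------------------------------------
// For reference only: a minimal single-matrix host sketch (a hypothetical helper,
// not a MAGMA API) of the recursion implemented above: factor the left half,
// apply its pivots plus a TRSM and GEMM update to the right half, factor the
// trailing block, then apply the trailing pivots back to the left columns
// (cf. setup_pivinfo/adjust_ipiv above). Pivot indices are 0-based and local to
// each block; the sketch assumes m >= n.
#include <complex>
#include <utility>

static int ref_getrf_rec(int m, int n, std::complex<float>* A, int lda, int* ipiv)
{
    if (m <= 0 || n <= 0) return 0;
    if (n == 1) {                              // unblocked base case: one column
        int p = 0;
        for (int r = 1; r < m; ++r)            // partial pivoting: largest |A(r,0)|
            if (std::abs(A[r]) > std::abs(A[p])) p = r;
        ipiv[0] = p;
        if (A[p] == std::complex<float>(0))
            return 1;                          // U(1,1) exactly zero
        std::swap(A[0], A[p]);
        for (int r = 1; r < m; ++r)
            A[r] /= A[0];                      // scale the L column
        return 0;
    }
    int n1 = n/2, n2 = n - n1;
    // panel on A1 (left n1 columns)
    int info = ref_getrf_rec(m, n1, A, lda, ipiv);
    // apply A1's pivots to A2 (cf. claswp above)
    for (int i = 0; i < n1; ++i)
        if (ipiv[i] != i)
            for (int j = n1; j < n; ++j)
                std::swap(A[i + j*lda], A[ipiv[i] + j*lda]);
    // A12 := L11^{-1} * A12, unit lower triangular forward solve (cf. ctrsm above)
    for (int j = n1; j < n; ++j)
        for (int i = 0; i < n1; ++i)
            for (int r = i+1; r < n1; ++r)
                A[r + j*lda] -= A[r + i*lda] * A[i + j*lda];
    // A22 -= A21 * A12 (cf. cgemm above)
    for (int j = n1; j < n; ++j)
        for (int r = n1; r < m; ++r)
            for (int l = 0; l < n1; ++l)
                A[r + j*lda] -= A[r + l*lda] * A[l + j*lda];
    // panel on A2 (trailing (m-n1)-by-n2 block)
    int info2 = ref_getrf_rec(m - n1, n2, A + n1 + n1*lda, lda, ipiv + n1);
    if (info == 0 && info2 != 0) info = info2 + n1;
    // shift A2's pivots to block-global indices and apply them to the left columns
    for (int i = n1; i < n; ++i) {
        ipiv[i] += n1;
        if (ipiv[i] != i)
            for (int j = 0; j < n1; ++j)
                std::swap(A[i + j*lda], A[ipiv[i] + j*lda]);
    }
    return info;
}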
//===================================================================================================================
//===================================================================================================================
//===================================================================================================================
extern "C" magma_int_t
magma_clarft_batched(magma_int_t n, magma_int_t k, magma_int_t stair_T,
                magmaFloatComplex **v_array, magma_int_t ldv,
                magmaFloatComplex **tau_array,
                magmaFloatComplex **T_array, magma_int_t ldt,
                magmaFloatComplex **work_array, magma_int_t lwork,
                magma_int_t batchCount, magma_queue_t queue)
{
    magmaFloatComplex c_one  = MAGMA_C_ONE;
    magmaFloatComplex c_zero = MAGMA_C_ZERO;

    if ( k <= 0) return 0;
    if ( stair_T > 0 && k <= stair_T) return 0;

    magma_int_t maxnb = max_shared_bsiz;

    if ( lwork < k*ldt) {
        magma_xerbla( __func__, -(10) );
        return -10;
    }

    if ( stair_T > 0 && stair_T > maxnb) {
        magma_xerbla( __func__, -(3) );
        return -3;
    }

    magma_int_t DEBUG = 0;
    magma_int_t nb = stair_T == 0 ? min(k, maxnb) : stair_T;

    magma_int_t i, j, prev_n, mycol, rows;

    magmaFloatComplex **dW1_displ    = NULL;
    magmaFloatComplex **dW2_displ    = NULL;
    magmaFloatComplex **dW3_displ    = NULL;
    magmaFloatComplex **dTstep_array = NULL;

    magma_malloc((void**)&dW1_displ,    batchCount * sizeof(*dW1_displ));
    magma_malloc((void**)&dW2_displ,    batchCount * sizeof(*dW2_displ));
    magma_malloc((void**)&dW3_displ,    batchCount * sizeof(*dW3_displ));
    magma_malloc((void**)&dTstep_array, batchCount * sizeof(*dTstep_array));

    //magmaFloatComplex *Tstep = k > nb ? work : T;
    if (k > nb) {
        magma_cdisplace_pointers(dTstep_array, work_array, lwork, 0, 0, batchCount, queue);
    }
    else {
        magma_cdisplace_pointers(dTstep_array, T_array, ldt, 0, 0, batchCount, queue);
    }

    //magma_int_t ldtstep = k > nb ? k : ldt;
    magma_int_t ldtstep = ldt; // to remove

    // stair_T = 0 means all of T;
    // stair_T > 0 means the triangular portion of T has already been computed;
    // the value of stair_T is the nb of those triangles.

    // GEMM computes the whole upper triangular portion of T (phase 1)
    // TODO: add cublas to check perf
    magma_cgemm_batched( MagmaConjTrans, MagmaNoTrans,
                         k, k, n,
                         c_one,  v_array, ldv,
                                 v_array, ldv,
                         c_zero, dTstep_array, ldtstep,
                         batchCount, queue );

    magmablas_claset_batched( MagmaLower, k, k, MAGMA_C_ZERO, MAGMA_C_ZERO, dTstep_array, ldtstep, batchCount, queue );
    // no need for it as T is expected to be lower zero
    //if (k > nb) magmablas_claset_batched( MagmaLower, k, k, MAGMA_C_ZERO, MAGMA_C_ZERO, dTstep_array, ldtstep, batchCount, queue );

    // TRMV
    // T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1)   i = [1:k]
    // TRMV is split over blocks of columns of size nb.
    // The update must be done from top to bottom, so:
    //  1- a GEMM using the previously computed columns of T updates the
    //     rectangular upper portion above the triangle of my columns;
    //  2- the columns are then updated by a serial loop of GEMVs over
    //     themselves; since we limit the shared memory to nb, these nb
    //     columns are split vertically into chunks of nb rows.

    dim3 grid(1, 1, batchCount);

    for (j=0; j < k; j += nb) {
        prev_n = j;
        mycol  = min(nb, k-j);
        // note that myrow = prev_n + mycol;
        if (prev_n > 0 && mycol > 0) {
            if (DEBUG == 3) {
                printf("doing gemm on the rectangular portion of size %d %d of T(%d,%d)\n",
                        (int) prev_n, (int) mycol, 0, (int) j );
            }

            magma_cdisplace_pointers(dW1_displ, dTstep_array, ldtstep, 0, j, batchCount, queue);
            magma_cdisplace_pointers(dW2_displ, T_array,      ldt,     0, j, batchCount, queue);
            magma_cgemm_batched( MagmaNoTrans, MagmaNoTrans,
                                 prev_n, mycol, prev_n,
                                 c_one,  T_array,   ldt,
                                         dW1_displ, ldtstep,
                                 c_zero, dW2_displ, ldt,
                                 batchCount, queue );

            // update my rectangular portion (prev_n, mycol) using a sequence of gemv
            magma_cdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
            magma_cdisplace_pointers(dW3_displ, tau_array, 1, j, 0, batchCount, queue);

            for (i=0; i < prev_n; i += nb) {
                rows = min(nb, prev_n-i);
                if (DEBUG == 3) {
                    printf("    doing recctrmv on the rectangular portion of size %d %d of T(%d,%d)\n",
                            (int) rows, (int) mycol, (int) i, (int) j );
                }

                if (rows > 0 && mycol > 0) {
                    magma_cdisplace_pointers(dW2_displ, T_array, ldt, i, j, batchCount, queue);
                    magmablas_clarft_recctrmv_sm32x32_batched(rows, mycol, dW3_displ, dW2_displ, ldt, dW1_displ, ldtstep, batchCount, queue);
                }
            }
        }

        // the upper rectangular portion is updated; now, if needed, update the triangular portion
        if (stair_T == 0) {
            if (DEBUG == 3) {
                printf("doing ctrmv on the triangular portion of size %d %d of T(%d,%d)\n",
                        (int) mycol, (int) mycol, (int) j, (int) j );
            }

            if (mycol > 0) {
                magma_cdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
                magma_cdisplace_pointers(dW3_displ, tau_array, 1, j, 0, batchCount, queue);
                magma_cdisplace_pointers(dW2_displ, T_array, ldt, j, j, batchCount, queue);
                magmablas_clarft_ctrmv_sm32x32_batched(mycol, mycol, dW3_displ, dW1_displ, ldtstep, dW2_displ, ldt, batchCount, queue);
            }
        }
    } // end of j

    magma_free(dW1_displ);
    magma_free(dW2_displ);
    magma_free(dW3_displ);
    magma_free(dTstep_array);

    return 0;
}
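//-------------------------------------------------------------------------------------------------------------------
// A hypothetical host-side driver (not part of MAGMA) illustrating how the
// pointer-array arguments of magma_clarft_batched might be set up for a batch
// of equally sized problems. Names and sizes are illustrative; it assumes the
// MAGMA 2.x interface (magma_v2.h), ldv = n, ldt = k, stair_T = 0, and the
// lwork >= k*ldt requirement checked above.
#include <vector>
#include "magma_v2.h"

void example_clarft_batched(magma_int_t n, magma_int_t k, magma_int_t batchCount)
{
    magma_queue_t queue;
    magma_queue_create(0 /* device */, &queue);

    // one contiguous slab per argument, carved into batchCount matrices
    magmaFloatComplex *dV, *dtau, *dT, *dwork;
    magma_cmalloc(&dV,    batchCount * n * k);
    magma_cmalloc(&dtau,  batchCount * k);
    magma_cmalloc(&dT,    batchCount * k * k);
    magma_cmalloc(&dwork, batchCount * k * k);   // lwork = k*ldt per matrix
    // ... fill dV and dtau, e.g. with the output of a batched QR panel ...

    // build the per-matrix device-pointer arrays on the host, then upload them;
    // magma_setvector is blocking, so the host buffer can be reused
    magmaFloatComplex **dV_array, **dtau_array, **dT_array, **dwork_array;
    magma_malloc((void**)&dV_array,    batchCount * sizeof(magmaFloatComplex*));
    magma_malloc((void**)&dtau_array,  batchCount * sizeof(magmaFloatComplex*));
    magma_malloc((void**)&dT_array,    batchCount * sizeof(magmaFloatComplex*));
    magma_malloc((void**)&dwork_array, batchCount * sizeof(magmaFloatComplex*));

    std::vector<magmaFloatComplex*> hptr(batchCount);
    for (magma_int_t b = 0; b < batchCount; ++b) hptr[b] = dV + b * n * k;
    magma_setvector(batchCount, sizeof(magmaFloatComplex*), hptr.data(), 1, dV_array, 1, queue);
    for (magma_int_t b = 0; b < batchCount; ++b) hptr[b] = dtau + b * k;
    magma_setvector(batchCount, sizeof(magmaFloatComplex*), hptr.data(), 1, dtau_array, 1, queue);
    for (magma_int_t b = 0; b < batchCount; ++b) hptr[b] = dT + b * k * k;
    magma_setvector(batchCount, sizeof(magmaFloatComplex*), hptr.data(), 1, dT_array, 1, queue);
    for (magma_int_t b = 0; b < batchCount; ++b) hptr[b] = dwork + b * k * k;
    magma_setvector(batchCount, sizeof(magmaFloatComplex*), hptr.data(), 1, dwork_array, 1, queue);

    // form all k-by-k T factors in one call
    magma_clarft_batched(n, k, 0 /* stair_T */,
                         dV_array, n, dtau_array,
                         dT_array, k,
                         dwork_array, k*k,
                         batchCount, queue);
    magma_queue_sync(queue);

    magma_free(dV_array);  magma_free(dtau_array);
    magma_free(dT_array);  magma_free(dwork_array);
    magma_free(dV); magma_free(dtau); magma_free(dT); magma_free(dwork);
    magma_queue_destroy(queue);
}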