/**
    Panel LU factorization with partial pivoting for a batch of m-by-n
    matrices, processed in column panels of width BATF2_NB.

    For every column of a panel the routine picks one of two paths:
      - (m - panel start) > MAX_NTHREADS: separate pivot search
        (magma_icamax_batched), row swap, then scale + rank-1 update
        (magma_cscal_cgeru_batched);
      - otherwise: magma_ccomputecolumn_batched followed by the row swap.
    After each panel, the columns to its right (if any) are updated with
    magma_cgetf2trsm_batched and one batched GEMM.

    @param[in]     m, n        Matrix dimensions; both must be >= 0.
    @param[in,out] dA_array    Array of batchCount matrix pointers.
    @param[in]     lda         Leading dimension; must be >= max(1,m).
    @param         dW0_displ, dW1_displ, dW2_displ
                               Pointer-array workspaces; overwritten with
                               displaced views of dA_array for the GEMM.
    @param[out]    ipiv_array  Per-matrix pivot arrays, passed to the helpers.
    @param[out]    info_array  Per-matrix status, passed to the helpers.
    @param[in]     gbstep      Forwarded unchanged to the helper routines.
    @param[in]     batchCount  Number of matrices.
    @param[in]     myhandle    Only referenced by the disabled cuBLAS GEMM path.
    @param[in]     queue       Queue forwarded to every helper.

    @return 0 on success; -1, -2 or -4 for an illegal m, n or lda (after
            calling magma_xerbla); otherwise the first non-zero status
            returned by a helper routine.
*/
extern "C" magma_int_t
magma_cgetf2_batched(
    magma_int_t m, magma_int_t n,
    magmaFloatComplex **dA_array, magma_int_t lda,
    magmaFloatComplex **dW0_displ,
    magmaFloatComplex **dW1_displ,
    magmaFloatComplex **dW2_displ,
    magma_int_t **ipiv_array,
    magma_int_t *info_array,
    magma_int_t gbstep,
    magma_int_t batchCount,
    cublasHandle_t myhandle, magma_queue_t queue)
{
    // Position (1-based) of the first illegal argument, or 0 when all are fine.
    const magma_int_t bad_arg = (m < 0)            ? 1
                              : (n < 0)            ? 2
                              : (lda < max(1,m))   ? 4
                              : 0;
    if (bad_arg != 0) {
        magma_xerbla( __func__, bad_arg );
        return -bad_arg;
    }

    // Nothing to factor for an empty matrix.
    if (m == 0 || n == 0) {
        return 0;
    }

    magmaFloatComplex c_neg_one = MAGMA_C_NEG_ONE;
    magmaFloatComplex c_one     = MAGMA_C_ONE;

    const magma_int_t panel_nb = BATF2_NB;
    const magma_int_t diag_len = min(m, n);
    magma_int_t status = 0;

    for (magma_int_t panel = 0; panel < diag_len; panel += panel_nb) {
        const magma_int_t width = min(panel_nb, diag_len - panel);
        const magma_int_t next  = panel + width;   // first row/column after this panel

        // The choice of path depends only on the panel, not on the column.
        const bool tall_panel = (m - panel) > MAX_NTHREADS;

        for (magma_int_t s = 0; s < width; ++s) {
            const magma_int_t col = panel + s;

            // Step 1: locate the pivot (tall panels) or process the whole
            // column in a single helper call (short panels).
            if (tall_panel) {
                status = magma_icamax_batched(m-col, dA_array, 1, col, lda,
                                              ipiv_array, info_array,
                                              gbstep, batchCount, queue);
            }
            else {
                status = magma_ccomputecolumn_batched(m-panel, panel, s,
                                                      dA_array, lda,
                                                      ipiv_array, info_array,
                                                      gbstep, batchCount, queue);
            }
            if (status != 0) {
                return status;
            }

            // Step 2: apply the interchange across the full row (all n columns).
            status = magma_cswap_batched(n, dA_array, lda, col, ipiv_array,
                                         batchCount, queue);
            if (status != 0) {
                return status;
            }

            // Step 3 (tall panels only): scale the column below the pivot and
            // update the remaining columns of the panel.
            if (tall_panel && col < m) {
                status = magma_cscal_cgeru_batched(m-col, width-s, col,
                                                   dA_array, lda, info_array,
                                                   gbstep, batchCount, queue);
                if (status != 0) {
                    return status;
                }
            }
        }

        const magma_int_t right_cols = n - panel - width;
        if (right_cols > 0) {
            // Triangular solve on the `width` panel rows of the columns to the right.
            magma_cgetf2trsm_batched(width, right_cols, dA_array, panel, lda,
                                     batchCount, queue);

            // Views used by the GEMM: block below the panel, block right of
            // the panel, and the trailing block that gets updated.
            magma_cdisplace_pointers(dW0_displ, dA_array, lda, next,  panel, batchCount, queue);
            magma_cdisplace_pointers(dW1_displ, dA_array, lda, panel, next,  batchCount, queue);
            magma_cdisplace_pointers(dW2_displ, dA_array, lda, next,  next,  batchCount, queue);

#if 1
            magmablas_cgemm_batched( MagmaNoTrans, MagmaNoTrans,
                                     m-next, n-next, width,
                                     c_neg_one, dW0_displ, lda,
                                                dW1_displ, lda,
                                     c_one,     dW2_displ, lda,
                                     batchCount, queue);
#else
            cublasCgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N,
                               m-next, n-next, width,
                               &c_neg_one, (const magmaFloatComplex**) dW0_displ, lda,
                                           (const magmaFloatComplex**) dW1_displ, lda,
                               &c_one,     dW2_displ, lda,
                               batchCount );
#endif
        }
    }

    return 0;
}
/**
    Purpose
    -------
    This is an internal routine that relies on several assumptions about its
    arguments; its documentation is not complete.

    CGETRF_PANEL computes an LU factorization of a general M-by-N matrix A
    using partial pivoting with row interchanges.

    The factorization has the form
        A = P * L * U
    where P is a permutation matrix, L is lower triangular with unit
    diagonal elements (lower trapezoidal if m > n), and U is upper
    triangular (upper trapezoidal if m < n).

    This is the right-looking Level 3 BLAS version of the algorithm.

    This is a batched version that factors batchCount M-by-N matrices in parallel.
    dA, ipiv, and info become arrays with one entry per matrix.

    The routine is recursive: a panel with n <= min_recpnb is factored
    directly by magma_cgetf2_batched; a wider panel is split into a left half
    (n/2 columns) and a right half, the left half is factored recursively,
    the right half is updated (row swaps, triangular solve, GEMM) and then
    factored recursively, and finally the row swaps of the right half are
    applied to the left half.

    Arguments
    ---------
    @param[in]
    m       INTEGER
            The number of rows of each matrix A.  M >= 0.

    @param[in]
    n       INTEGER
            The number of columns of each matrix A.  N >= 0.

    @param[in]
    min_recpnb   INTEGER.
                 Internal use. The recursive nb: panels with at most this
                 many columns are factored without further splitting.

    @param[in,out]
    dA_array    Array of pointers, dimension (batchCount).
            Each is a COMPLEX array on the GPU, dimension (LDDA,N).
            On entry, each pointer is an M-by-N matrix to be factored.
            On exit, the factors L and U from the factorization
            A = P*L*U; the unit diagonal elements of L are not stored.

    @param[in]
    ldda    INTEGER
            The leading dimension of each array A.  LDDA >= max(1,M).

    @param[out]
    dipiv_array  Array of pointers, dimension (batchCount), for corresponding matrices.
            Each is an INTEGER array, dimension (min(M,N))
            The pivot indices; for 1 <= i <= min(M,N), row i of the
            matrix was interchanged with row IPIV(i).

    @param[out]
    dpivinfo_array  Array of pointers, dimension (batchCount), for internal use.

    @param[in,out]
    dX_array       Array of pointers, dimension (batchCount).
            Each is a COMPLEX array X of dimension ( lddx, n ).
            On entry, should be set to 0
            On exit, the solution matrix X

    @param[in]
    dX_length    INTEGER.
                 The size of each workspace matrix dX

    @param[in,out]
    dinvA_array    Array of pointers, dimension (batchCount).
            Each is a COMPLEX array dinvA, a workspace on device.
            If side == MagmaLeft,  dinvA must be of size >= ceil(m/TRI_NB)*TRI_NB*TRI_NB,
            If side == MagmaRight, dinvA must be of size >= ceil(n/TRI_NB)*TRI_NB*TRI_NB,
            where TRI_NB = 128.
            (This routine calls the triangular solve with side = MagmaLeft.)

    @param[in]
    dinvA_length    INTEGER
                   The size of each workspace matrix dinvA

    @param[in]
    dW1_displ  Workspace array of pointers, for internal use.

    @param[in]
    dW2_displ  Workspace array of pointers, for internal use.

    @param[in]
    dW3_displ  Workspace array of pointers, for internal use.

    @param[in]
    dW4_displ  Workspace array of pointers, for internal use.

    @param[in]
    dW5_displ  Workspace array of pointers, for internal use.

    @param[out]
    info_array  Array of INTEGERs, dimension (batchCount), for corresponding matrices.
      -     = 0:  successful exit
      -     < 0:  if INFO = -i, the i-th argument had an illegal value
                  or another error occurred, such as memory allocation failed.
      -     > 0:  if INFO = i, U(i,i) is exactly zero. The factorization
                  has been completed, but the factor U is exactly
                  singular, and division by zero will occur if it is used
                  to solve a system of equations.

    @param[in]
    gbstep  INTEGER
            internal use.

    @param[in]
    batchCount  INTEGER
                The number of matrices to operate on.

    @param[in]
    queue   magma_queue_t
            Queue to execute in.

    @return 0 on success (including the m == 0 or n == 0 quick return);
            MAGMA_ERR_DEVICE_ALLOC if the internal pointer arrays cannot be
            allocated; otherwise the non-zero status of the leaf
            factorization or of a recursive call.

    @ingroup magma_cgesv_comp
    ********************************************************************/
extern "C" magma_int_t
magma_cgetrf_recpanel_batched(
    magma_int_t m, magma_int_t n, magma_int_t min_recpnb,
    magmaFloatComplex** dA_array,    magma_int_t ldda,
    magma_int_t** dipiv_array, magma_int_t** dpivinfo_array,
    magmaFloatComplex** dX_array,    magma_int_t dX_length,
    magmaFloatComplex** dinvA_array, magma_int_t dinvA_length,
    magmaFloatComplex** dW1_displ, magmaFloatComplex** dW2_displ,
    magmaFloatComplex** dW3_displ, magmaFloatComplex** dW4_displ,
    magmaFloatComplex** dW5_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t queue)
{
    // Quick return if possible
    if (m == 0 || n == 0) {
        return 0;
    }

    magma_int_t panel_nb = n;
    if (panel_nb <= min_recpnb) {
        // Leaf of the recursion: factor the panel directly.  No scratch
        // pointer arrays are needed on this path, so none are allocated.
        //
        // NOTE(review): this call passes 12 arguments.  If the
        // magma_cgetf2_batched definition in scope is the one that takes a
        // cublasHandle_t before the queue (13 parameters), the call does not
        // match it -- confirm against the header this file is built with.
        return magma_cgetf2_batched(m, panel_nb,
                                    dA_array, ldda,
                                    dW1_displ, dW2_displ, dW3_displ,
                                    dipiv_array, info_array, gbstep, batchCount, queue);
    }

    // Split A into [A1 A2]: panel on A1, update of A2, then panel on A2.
    // All locals are declared before the first `goto` so that no
    // initialization is jumped over.
    magma_int_t n1 = n/2;
    magma_int_t n2 = n-n1;
    magma_int_t m1 = m;
    magma_int_t m2 = m-n1;
    magma_int_t p1 = 0;
    magma_int_t p2 = n1;
    magma_int_t status = 0;

    // Scratch pointer arrays holding displaced views of dA_array / dipiv_array.
    magmaFloatComplex **dA_displ    = NULL;
    magma_int_t       **dipiv_displ = NULL;
    magma_malloc((void**)&dA_displ,    batchCount * sizeof(*dA_displ));
    magma_malloc((void**)&dipiv_displ, batchCount * sizeof(*dipiv_displ));

    /* check allocation */
    if (dA_displ == NULL || dipiv_displ == NULL) {
        status = MAGMA_ERR_DEVICE_ALLOC;
        magma_xerbla( __func__, -(status) );
        goto cleanup;
    }

    // ---- panel on A1 ----
    magma_cdisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount, queue);
    magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p1, 0, batchCount, queue);
    status = magma_cgetrf_recpanel_batched(
                       m1, n1, min_recpnb,
                       dA_displ, ldda,
                       dipiv_displ, dpivinfo_array,
                       dX_array, dX_length,
                       dinvA_array, dinvA_length,
                       dW1_displ, dW2_displ,
                       dW3_displ, dW4_displ, dW5_displ,
                       info_array, gbstep, batchCount, queue);
    if (status != 0) {
        goto cleanup;
    }

    // ---- update A2 ----
    // Build the pivot info of A1 and apply its row swaps to A2.
    setup_pivinfo_batched(dpivinfo_array, dipiv_displ, m1, n1, batchCount, queue);
    magma_cdisplace_pointers(dW5_displ, dA_array, ldda, p1, p2, batchCount, queue);
    magma_claswp_rowparallel_batched( n2, dW5_displ, ldda,
                       dX_array, n1,
                       0, n1,
                       dpivinfo_array, batchCount, queue );

    // Out-of-place triangular solve with the unit-lower factor of A1.
    magmablas_ctrsm_outofplace_batched( MagmaLeft, MagmaLower, MagmaNoTrans, MagmaUnit, 1,
                          n1, n2,
                          MAGMA_C_ONE,
                          dA_displ,    ldda, // dA
                          dX_array,    n1,   // dB
                          dW5_displ,   ldda, // dX
                          dinvA_array, dinvA_length,
                          dW1_displ,   dW2_displ,
                          dW3_displ,   dW4_displ,
                          0, batchCount, queue );

    // A2(p2:, p2:) -= A(p2:, 0:n1) * A(p1:, p2:)
    magma_cdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount, queue);
    magma_cdisplace_pointers(dA_displ, dA_array, ldda, p2, p2, batchCount, queue);
    magma_cgemm_batched( MagmaNoTrans, MagmaNoTrans, m2, n2, n1,
                         MAGMA_C_NEG_ONE, dW1_displ, ldda,
                                          dW5_displ, ldda,
                         MAGMA_C_ONE,     dA_displ,  ldda,
                         batchCount, queue );

    // ---- panel on A2 ----
    magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p2, 0, batchCount, queue);
    status = magma_cgetrf_recpanel_batched(
                       m2, n2, min_recpnb,
                       dA_displ, ldda,
                       dipiv_displ, dpivinfo_array,
                       dX_array, dX_length,
                       dinvA_array, dinvA_length,
                       dW1_displ, dW2_displ,
                       dW3_displ, dW4_displ, dW5_displ,
                       info_array, gbstep+p2, batchCount, queue);
    if (status != 0) {
        goto cleanup;
    }

    // Build the pivot info of A2, shift its pivot indices by n1, and apply
    // the A2 row swaps to the block below A1.
    setup_pivinfo_batched(dpivinfo_array, dipiv_displ, m2, n2, batchCount, queue);
    adjust_ipiv_batched(dipiv_displ, n2, n1, batchCount, queue);

    // dW1_displ was handed to the recursive call as workspace, so it is
    // re-pointed at A(p2, 0) before being used again.
    magma_cdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount, queue);
    magma_claswp_rowparallel_batched( n1, dW1_displ, ldda,
                       dW1_displ, ldda,
                       n1, n,
                       dpivinfo_array, batchCount, queue );

cleanup:
    magma_free(dA_displ);
    magma_free(dipiv_displ);
    return status;
}
/**
    Blocked QR factorization of batchCount m-by-n matrices (judging by the
    routine name and the geqrf-panel / larft / larfb helpers it drives).

    The matrix is processed in panels of nb = 32 columns.  Each panel is
    factored by magma_cgeqrf_panel_batched (inner block size 8); if columns
    remain to the right, the block reflector factor is formed with
    magma_clarft_batched and applied to the trailing matrix, either with one
    larfb call per matrix spread over 32 streams (when both trailing
    dimensions exceed 100) or with a single batched larfb call.

    @param[in]     m, n        Matrix dimensions; both must be >= 0.
    @param[in,out] dA_array    Array of batchCount matrix pointers.  The array
                               is copied to the host with magma_getvector.
    @param[in]     ldda        Leading dimension; must be >= max(1,m).
    @param[out]    tau_array   Array of batchCount pointers forwarded to the
                               panel / larft helpers (presumably the
                               Householder scalars -- confirm).
    @param[out]    info_array  batchCount integers; zeroed with cudaMemset on
                               entry and then forwarded to the panel routine.
    @param[in]     batchCount  Number of matrices.
    @param[in]     queue       Queue forwarded to the batched helpers.

    @return 0 on success or when min(m,n) == 0; -1, -2 or -4 for an illegal
            m, n or ldda (after magma_xerbla); MAGMA_ERR_DEVICE_ALLOC when
            any device or host workspace allocation fails.
*/
extern "C" magma_int_t
magma_cgeqrf_batched(
    magma_int_t m, magma_int_t n,
    magmaFloatComplex **dA_array,
    magma_int_t ldda,
    magmaFloatComplex **tau_array,
    magma_int_t *info_array, magma_int_t batchCount, magma_queue_t queue)
{
// Not referenced anywhere in this function; kept because the macro stays
// defined for the rest of the translation unit.
#define dA(i, j)   (dA + (i) + (j)*ldda)   // A(i, j) means at i row, j column

    magma_int_t min_mn = min(m, n);
    cudaMemset(info_array, 0, batchCount*sizeof(magma_int_t));

    /* Check arguments */
    magma_int_t arginfo = 0;
    if (m < 0)
        arginfo = -1;
    else if (n < 0)
        arginfo = -2;
    else if (ldda < max(1,m))
        arginfo = -4;

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    /* Quick return if possible (m and n are known to be >= 0 here, so
       min_mn == 0 is exactly "m == 0 or n == 0") */
    if (min_mn == 0)
        return arginfo;

    if( m > 2048 || n > 2048 ){
        printf("=========================================================================================\n");
        printf(" WARNING batched routines are designed for small sizes it might be better to use the\n Native/Hybrid classical routines if you want performance\n");
        printf("=========================================================================================\n");
    }

    magma_int_t nb  = 32;   // outer panel width
    magma_int_t nnb = 8;    // inner block size handed to the panel routine
    magma_int_t i, k, ib=nb, jb=nnb;
    magma_int_t ldw, ldt, ldr, offset;

    magmaFloatComplex **dW0_displ = NULL;
    magmaFloatComplex **dW1_displ = NULL;
    magmaFloatComplex **dW2_displ = NULL;
    magmaFloatComplex **dW3_displ = NULL;
    magmaFloatComplex **dW4_displ = NULL;
    magmaFloatComplex **dW5_displ = NULL;

    magmaFloatComplex *dwork = NULL;
    magmaFloatComplex *dT    = NULL;
    magmaFloatComplex *dR    = NULL;
    magmaFloatComplex **dR_array = NULL;
    magmaFloatComplex **dT_array = NULL;
    magmaFloatComplex **cpuAarray = NULL;
    magmaFloatComplex **cpuTarray = NULL;

    magma_malloc((void**)&dW0_displ, batchCount * sizeof(*dW0_displ));
    magma_malloc((void**)&dW1_displ, batchCount * sizeof(*dW1_displ));
    magma_malloc((void**)&dW2_displ, batchCount * sizeof(*dW2_displ));
    magma_malloc((void**)&dW3_displ, batchCount * sizeof(*dW3_displ));
    magma_malloc((void**)&dW4_displ, batchCount * sizeof(*dW4_displ));  // used in clarfb
    magma_malloc((void**)&dW5_displ, batchCount * sizeof(*dW5_displ));
    magma_malloc((void**)&dR_array,  batchCount * sizeof(*dR_array));
    magma_malloc((void**)&dT_array,  batchCount * sizeof(*dT_array));

    ldt = ldr = min(nb, min_mn);
    magma_cmalloc(&dwork, (2 * nb * n) * batchCount);
    magma_cmalloc(&dR,  ldr * n   * batchCount);
    magma_cmalloc(&dT,  ldt * ldt * batchCount);
    magma_malloc_cpu((void**) &cpuAarray, batchCount*sizeof(magmaFloatComplex*));
    magma_malloc_cpu((void**) &cpuTarray, batchCount*sizeof(magmaFloatComplex*));

    /* check allocation */
    if ( dW0_displ == NULL || dW1_displ == NULL || dW2_displ == NULL || dW3_displ == NULL ||
         dW4_displ == NULL || dW5_displ == NULL || dR_array  == NULL || dT_array  == NULL ||
         dR == NULL || dT == NULL || dwork == NULL || cpuAarray == NULL || cpuTarray == NULL ) {
        magma_free(dW0_displ);
        magma_free(dW1_displ);
        magma_free(dW2_displ);
        magma_free(dW3_displ);
        magma_free(dW4_displ);
        magma_free(dW5_displ);
        magma_free(dR_array);
        magma_free(dT_array);
        magma_free(dR);
        magma_free(dT);
        magma_free(dwork);
        // Host buffers come from magma_malloc_cpu, so they are released with
        // the matching magma_free_cpu rather than plain free().
        magma_free_cpu(cpuAarray);
        magma_free_cpu(cpuTarray);
        magma_int_t info = MAGMA_ERR_DEVICE_ALLOC;
        magma_xerbla( __func__, -(info) );
        return info;
    }

    // The cuBLAS handle is created only after the allocation check, so the
    // early error return above has no handle to release.
    cublasHandle_t myhandle;
    cublasCreate_v2(&myhandle);

    // Zero the R and T storage and point each batch entry at its own slice.
    magmablas_claset_q(MagmaFull, ldr, n*batchCount  , MAGMA_C_ZERO, MAGMA_C_ZERO, dR, ldr, queue);
    magmablas_claset_q(MagmaFull, ldt, ldt*batchCount, MAGMA_C_ZERO, MAGMA_C_ZERO, dT, ldt, queue);
    cset_pointer(dR_array, dR, 1, 0, 0, ldr*min(nb, min_mn), batchCount, queue);
    cset_pointer(dT_array, dT, 1, 0, 0, ldt*min(nb, min_mn), batchCount, queue);

    // Remember the current kernel stream so it can be restored on exit.
    magma_queue_t cstream;
    magmablasGetKernelStream(&cstream);
    magma_int_t streamid;
    const magma_int_t nbstreams=32;
    magma_queue_t stream[nbstreams];
    for(i=0; i<nbstreams; i++){
        magma_queue_create( &stream[i] );
    }

    // Host copies of the matrix / T pointers: the per-matrix (streamed)
    // larfb path below needs host-side pointers.
    magma_getvector( batchCount, sizeof(magmaFloatComplex*), dA_array, 1, cpuAarray, 1);
    magma_getvector( batchCount, sizeof(magmaFloatComplex*), dT_array, 1, cpuTarray, 1);

    magmablasSetKernelStream(NULL);

    for(i=0; i<min_mn; i+=nb)
    {
        ib = min(nb, min_mn-i);

        //===============================================
        // panel factorization
        //===============================================
        magma_cdisplace_pointers(dW0_displ, dA_array, ldda, i, i, batchCount, queue);
        magma_cdisplace_pointers(dW2_displ, tau_array, 1, i, 0, batchCount, queue);

        // dwork is used in panel factorization and trailing matrix update;
        // dW4_displ, dW5_displ are used as workspace and configured inside.
        magma_cgeqrf_panel_batched(m-i, ib, jb,
                                   dW0_displ, ldda,
                                   dW2_displ,
                                   dT_array, ldt,
                                   dR_array, ldr,
                                   dW1_displ,
                                   dW3_displ,
                                   dwork,
                                   dW4_displ, dW5_displ,
                                   info_array,
                                   batchCount, myhandle, queue);
        //===============================================
        // end of panel
        //===============================================

        // Point dW0_displ at the panel matrix V again.
        magma_cdisplace_pointers(dW0_displ, dA_array, ldda, i, i, batchCount, queue);

        // Copy the upper part of V into dR.
        cgeqrf_copy_upper_batched(ib, jb, dW0_displ, ldda, dR_array, ldr, batchCount, queue);

        //===============================================
        // update trailing matrix
        //===============================================

        // dwork is shared by the panel and the trailing update, so the two
        // workspace pointer arrays are reset to the two halves of dwork.
        ldw = nb;
        cset_pointer(dW4_displ, dwork, 1, 0, 0,  ldw*n, batchCount, queue );
        offset = ldw*n*batchCount;
        cset_pointer(dW5_displ, dwork + offset, 1, 0, 0,  ldw*n, batchCount, queue );

        if( (n-ib-i) > 0)
        {
            // Set the diagonal of V to one and the upper triangular part to zero.
            magmablas_claset_batched(MagmaUpper, ib, ib, MAGMA_C_ZERO, MAGMA_C_ONE, dW0_displ, ldda, batchCount, queue);
            magma_cdisplace_pointers(dW2_displ, tau_array, 1, i, 0, batchCount, queue);

            // Original author's note: faster than the LAPACK formulation
            // because it relies on BLAS-3 GEMM routines.
            //
            // NOTE(review): this call passes 13 arguments (a cublasHandle_t
            // before the queue).  If the magma_clarft_batched definition in
            // scope is the 12-parameter one without a handle, the call does
            // not match it -- confirm against the header in use.
            magma_clarft_batched(m-i, ib, 0,
                                 dW0_displ, ldda,
                                 dW2_displ,
                                 dT_array, ldt,
                                 dW4_displ, nb*ldt,
                                 batchCount, myhandle, queue);

            // Perform C = (I - V T^H V^H) * C, where C is the trailing matrix.
            //-------------------------------------------
            //          USE STREAM  GEMM
            //-------------------------------------------
            if( (m-i) > 100  && (n-i-ib) > 100)
            {
                // Original author's note: no explicit device sync is issued
                // before or after this loop, on the grounds that the rest of
                // the code runs on the NULL stream, which synchronizes with
                // the other streams by itself.
                for(k=0; k<batchCount; k++)
                {
                    streamid = k%nbstreams;
                    magmablasSetKernelStream(stream[streamid]);

                    // The stream gemm must take host-side (cpu) pointers.
                    magma_clarfb_gpu_gemm(MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise,
                                m-i, n-i-ib, ib,
                                cpuAarray[k] + i + i * ldda, ldda,
                                cpuTarray[k], ldt,
                                cpuAarray[k] + i + (i+ib) * ldda, ldda,
                                dwork + nb * n * k, -1,
                                dwork + nb * n * batchCount + nb * n * k, -1);
                }

                magmablasSetKernelStream(NULL);
            }
            //-------------------------------------------
            //          USE BATCHED GEMM
            //-------------------------------------------
            else
            {
                // Point dW1_displ at the trailing matrix.
                magma_cdisplace_pointers(dW1_displ, dA_array, ldda, i, i+ib, batchCount, queue);

                magma_clarfb_gemm_batched(
                            MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise,
                            m-i, n-i-ib, ib,
                            (const magmaFloatComplex**)dW0_displ, ldda,
                            (const magmaFloatComplex**)dT_array, ldt,
                            dW1_displ,  ldda,
                            dW4_displ,  ldw,
                            dW5_displ, ldw,
                            batchCount, myhandle, queue);
            }
        }// update the trailing matrix
        //===============================================

        // Copy dR back into V after the trailing matrix update.
        magmablas_clacpy_batched(MagmaUpper, ib, ib, dR_array, ldr, dW0_displ, ldda, batchCount, queue);
    }

    for(k=0; k<nbstreams; k++){
        magma_queue_destroy( stream[k] );
    }
    magmablasSetKernelStream(cstream);
    cublasDestroy_v2(myhandle);

    magma_free(dW0_displ);
    magma_free(dW1_displ);
    magma_free(dW2_displ);
    magma_free(dW3_displ);
    magma_free(dW4_displ);
    magma_free(dW5_displ);
    magma_free(dR_array);
    magma_free(dT_array);
    magma_free(dR);
    magma_free(dT);
    magma_free(dwork);
    magma_free_cpu(cpuAarray);
    magma_free_cpu(cpuTarray);

    return arginfo;
}
/**
    Factors an m-by-n panel for every matrix of the batch, working through it
    in sub-panels of nb columns.

    For each sub-panel starting at column `col` the routine:
      1. factors it with magma_cgeqr2_batched,
      2. copies its upper triangle into dR_array,
      3. overwrites the upper triangle in dA_array with zeros and a unit
         diagonal,
      4. if panel columns remain to the right, builds T with
         magma_clarft_sm32x32_batched and applies the block reflector to them
         with magma_clarfb_gemm_batched.

    dW0_displ .. dW3_displ are pointer-array workspaces that get overwritten;
    dwork supplies the two larfb workspaces (each nb*n elements per matrix).
    info_array is forwarded to magma_cgeqr2_batched.  Always returns 0.
*/
extern "C" magma_int_t
magma_cgeqrf_panel_batched(
        magma_int_t m, magma_int_t n, magma_int_t nb,
        magmaFloatComplex** dA_array,    magma_int_t ldda,
        magmaFloatComplex** tau_array,
        magmaFloatComplex** dT_array, magma_int_t ldt,
        magmaFloatComplex** dR_array, magma_int_t ldr,
        magmaFloatComplex** dW0_displ,
        magmaFloatComplex** dW1_displ,
        magmaFloatComplex   *dwork,
        magmaFloatComplex** dW2_displ,
        magmaFloatComplex** dW3_displ,
        magma_int_t *info_array,
        magma_int_t batchCount, cublasHandle_t myhandle, magma_queue_t queue)
{
    const magma_int_t work_ld     = nb;           // leading dimension of the larfb workspaces
    const magma_int_t work_stride = work_ld * n;  // workspace elements per matrix

    for (magma_int_t col = 0; col < n; col += nb)
    {
        const magma_int_t width     = min(nb, n - col);
        const magma_int_t remaining = n - col - width;   // panel columns right of this sub-panel

        // Views of the current sub-panel, its tau entries and its R block.
        magma_cdisplace_pointers(dW0_displ, dA_array,  ldda, col, col, batchCount, queue);
        magma_cdisplace_pointers(dW2_displ, tau_array, 1,    col, 0,   batchCount, queue);
        magma_cdisplace_pointers(dW3_displ, dR_array,  ldr,  col, col, batchCount, queue);

        // Sub-panel factorization.
        magma_cgeqr2_batched(m - col, width,
                             dW0_displ, ldda,
                             dW2_displ,
                             info_array,
                             batchCount, queue);

        // Save the upper triangle of the factored sub-panel into dR.
        // (The views are re-established before each use, as in the original.)
        magma_cdisplace_pointers(dW0_displ, dA_array, ldda, col, col, batchCount, queue);
        magma_cdisplace_pointers(dW3_displ, dR_array, ldr,  col, col, batchCount, queue);
        magmablas_clacpy_batched(MagmaUpper, width, width,
                                 dW0_displ, ldda,
                                 dW3_displ, ldr,
                                 batchCount, queue);

        // Zero the upper triangle of the sub-panel and put ones on its diagonal.
        magma_cdisplace_pointers(dW0_displ, dA_array, ldda, col, col, batchCount, queue);
        magma_cdisplace_pointers(dW3_displ, dR_array, ldr,  col, col, batchCount, queue);
        magmablas_claset_batched(MagmaUpper, width, width,
                                 MAGMA_C_ZERO, MAGMA_C_ONE,
                                 dW0_displ, ldda,
                                 batchCount, queue);

        // Last sub-panel: nothing inside the panel is left to update.
        if (remaining <= 0) {
            continue;
        }

        // Update the rest of the panel: form T, then apply the block reflector.
        magma_clarft_sm32x32_batched(m - col, width,
                                     dW0_displ, ldda,
                                     dW2_displ,
                                     dT_array, ldt,
                                     batchCount, myhandle, queue);

        magma_cdisplace_pointers(dW1_displ, dA_array, ldda, col, col + width, batchCount, queue);

        // dW2_displ / dW3_displ are reused here as the two larfb workspaces,
        // carved from the first and second halves of dwork.
        cset_pointer(dW2_displ, dwork,                            1, 0, 0, work_stride, batchCount, queue );
        cset_pointer(dW3_displ, dwork + work_stride * batchCount, 1, 0, 0, work_stride, batchCount, queue );

        magma_clarfb_gemm_batched(
                    MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise,
                    m - col, remaining, width,
                    (const magmaFloatComplex**)dW0_displ, ldda,
                    (const magmaFloatComplex**)dT_array,  ldt,
                    dW1_displ, ldda,
                    dW2_displ, work_ld,
                    dW3_displ, work_ld,
                    batchCount, myhandle, queue);
    }

    return 0;
}
//=================================================================================================================== //=================================================================================================================== //=================================================================================================================== extern "C" magma_int_t magma_clarft_batched(magma_int_t n, magma_int_t k, magma_int_t stair_T, magmaFloatComplex **v_array, magma_int_t ldv, magmaFloatComplex **tau_array, magmaFloatComplex **T_array, magma_int_t ldt, magmaFloatComplex **work_array, magma_int_t lwork, magma_int_t batchCount, magma_queue_t queue) { magmaFloatComplex c_one = MAGMA_C_ONE; magmaFloatComplex c_zero = MAGMA_C_ZERO; if ( k <= 0) return 0; if ( stair_T > 0 && k <= stair_T) return 0; magma_int_t maxnb = max_shared_bsiz; if ( lwork < k*ldt) { magma_xerbla( __func__, -(10) ); return -10; } if ( stair_T > 0 && stair_T > maxnb) { magma_xerbla( __func__, -(3) ); return -3; } magma_int_t DEBUG=0; magma_int_t nb = stair_T == 0 ? min(k,maxnb) : stair_T; magma_int_t i, j, prev_n, mycol, rows; magmaFloatComplex **dW1_displ = NULL; magmaFloatComplex **dW2_displ = NULL; magmaFloatComplex **dW3_displ = NULL; magmaFloatComplex **dTstep_array = NULL; magma_malloc((void**)&dW1_displ, batchCount * sizeof(*dW1_displ)); magma_malloc((void**)&dW2_displ, batchCount * sizeof(*dW2_displ)); magma_malloc((void**)&dW3_displ, batchCount * sizeof(*dW3_displ)); magma_malloc((void**)&dTstep_array, batchCount * sizeof(*dTstep_array)); //magmaFloatComplex *Tstep = k > nb ? work : T; if (k > nb) { magma_cdisplace_pointers(dTstep_array, work_array, lwork, 0, 0, batchCount, queue); } else { magma_cdisplace_pointers(dTstep_array, T_array, ldt, 0, 0, batchCount, queue); } //magma_int_t ldtstep = k > nb ? k : ldt; magma_int_t ldtstep = ldt; //a enlever // stair_T = 0 meaning all T // stair_T > 0 meaning the triangular portion of T has been computed. 
// the value of stair_T is the nb of these triangulars //GEMV compute the whole triangular upper portion of T (phase 1) // TODO addcublas to check perf magma_cgemm_batched( MagmaConjTrans, MagmaNoTrans, k, k, n, c_one, v_array, ldv, v_array, ldv, c_zero, dTstep_array, ldtstep, batchCount, queue ); magmablas_claset_batched( MagmaLower, k, k, MAGMA_C_ZERO, MAGMA_C_ZERO, dTstep_array, ldtstep, batchCount, queue ); // no need for it as T is expected to be lower zero //if (k > nb) magmablas_claset_batched( MagmaLower, k, k, MAGMA_C_ZERO, MAGMA_C_ZERO, dTstep_array, ldtstep, batchCount, queue ); //TRMV //T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1) i=[1:k] // TRMV is split over block of column of size nb // the update should be done from top to bottom so: // 1- a gemm using the previous computed columns // of T to update rectangular upper protion above // the triangle of my columns // 2- the columns need to be updated by a serial // loop over of gemv over itself. since we limit the // shared memory to nb, this nb column // are split vertically by chunk of nb rows dim3 grid(1, 1, batchCount); for (j=0; j < k; j += nb) { prev_n = j; mycol = min(nb, k-j); // note that myrow = prev_n + mycol; if (prev_n > 0 && mycol > 0) { if (DEBUG == 3) { printf("doing gemm on the rectangular portion of size %d %d of T(%d,%d)\n", (int) prev_n, (int) mycol, 0, (int) j ); } magma_cdisplace_pointers(dW1_displ, dTstep_array, ldtstep, 0, j, batchCount, queue); magma_cdisplace_pointers(dW2_displ, T_array, ldt, 0, j, batchCount, queue); magma_cgemm_batched( MagmaNoTrans, MagmaNoTrans, prev_n, mycol, prev_n, c_one, T_array, ldt, dW1_displ, ldtstep, c_zero, dW2_displ, ldt, batchCount, queue ); // update my rectangular portion (prev_n,mycol) using sequence of gemv magma_cdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue); magma_cdisplace_pointers(dW3_displ, tau_array, 1, j, 0, batchCount, queue); for (i=0; i < prev_n; i += nb) { rows = min(nb,prev_n-i); if (DEBUG == 3) { 
printf(" doing recctrmv on the rectangular portion of size %d %d of T(%d,%d)\n", (int) rows, (int) mycol, (int) i, (int) j ); } if (rows > 0 && mycol > 0) { magma_cdisplace_pointers(dW2_displ, T_array, ldt, i, j, batchCount, queue); magmablas_clarft_recctrmv_sm32x32_batched(rows, mycol, dW3_displ, dW2_displ, ldt, dW1_displ, ldtstep, batchCount, queue); } } } // the upper rectangular protion is updated, now if needed update the triangular portion if (stair_T == 0) { if (DEBUG == 3) { printf("doing ctrmv on the triangular portion of size %d %d of T(%d,%d)\n", (int) mycol, (int) mycol, (int) j, (int) j ); } if (mycol > 0) { magma_cdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue); magma_cdisplace_pointers(dW3_displ, tau_array, 1, j, 0, batchCount, queue); magma_cdisplace_pointers(dW2_displ, T_array, ldt, j, j, batchCount, queue); magmablas_clarft_ctrmv_sm32x32_batched(mycol, mycol, dW3_displ, dW1_displ, ldtstep, dW2_displ, ldt, batchCount, queue); } } }// end of j magma_free(dW1_displ); magma_free(dW2_displ); magma_free(dW3_displ); magma_free(dTstep_array); return 0; }