/*
    Recursive LU panel factorization with partial pivoting for a batch of
    matrices.

    For each matrix in the batch the m-by-n panel dA_array[i] (leading
    dimension ldda) is factored in place and its pivots are written to
    dipiv_array[i]. Panels with n <= min_recpnb are handed to
    magma_sgetf2_batched. Wider panels are split as [A1 A2], with A1 holding
    the first n1 = n/2 columns:
      1. factor A1 recursively (pivots at ipiv[0 ..]);
      2. apply A1's row interchanges to A2 and solve with the unit lower
         triangle of A11, writing the result into A12 (TRSM);
      3. trailing update A22 -= A21 * A12 (GEMM);
      4. factor A22 recursively (pivots at ipiv[n1 ..], info offset
         gbstep + n1);
      5. offset A22's pivots by n1 and apply its interchanges to A21.

    dpivinfo_array, dX_array, dinvA_array and dW1..dW5_displ are
    caller-provided workspaces. They are overwritten here and by the recursive
    calls.

    Returns 0 on success, MAGMA_ERR_DEVICE_ALLOC if the local pointer-array
    workspace cannot be allocated, or the first nonzero code returned by a
    recursive call / magma_sgetf2_batched. Local workspace is freed on every
    path.

    NOTE(review): m2 = m - n/2 is not checked for being non-negative; callers
    presumably guarantee m >= n -- confirm.
*/
extern "C" magma_int_t
magma_sgetrf_recpanel_batched_q(
    magma_int_t m, magma_int_t n, magma_int_t min_recpnb,
    float** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, magma_int_t** dpivinfo_array,
    float** dX_array, magma_int_t dX_length,
    float** dinvA_array, magma_int_t dinvA_length,
    float** dW1_displ, float** dW2_displ, float** dW3_displ,
    float** dW4_displ, float** dW5_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t stream, cublasHandle_t myhandle)
{
    // Quick return if possible
    if (m == 0 || n == 0) {
        return 0;
    }

    // Leaf of the recursion: factor the whole panel directly. No local
    // pointer-array workspace is needed, so nothing is allocated here.
    if (n <= min_recpnb) {
        return magma_sgetf2_batched(
                m, n, dA_array, ldda,
                dW1_displ, dW2_displ, dW3_displ,
                dipiv_array, info_array, gbstep, batchCount, myhandle);
    }

    magma_int_t arginfo = 0;

    // Device arrays holding, per matrix, a pointer into a sub-block of A / ipiv.
    float       **dA_displ    = NULL;
    magma_int_t **dipiv_displ = NULL;
    magma_malloc((void**)&dA_displ,    batchCount * sizeof(*dA_displ));
    magma_malloc((void**)&dipiv_displ, batchCount * sizeof(*dipiv_displ));
    if (dA_displ == NULL || dipiv_displ == NULL) {
        magma_free(dA_displ);
        magma_free(dipiv_displ);
        return MAGMA_ERR_DEVICE_ALLOC;
    }

    {
        // Split A = [A1 A2]: A1 = first n1 columns, A2 = remaining n2 columns.
        const magma_int_t n1 = n / 2;
        const magma_int_t n2 = n - n1;
        const magma_int_t m1 = m;
        const magma_int_t m2 = m - n1;
        const magma_int_t p1 = 0;
        const magma_int_t p2 = n1;

        // 1. Factor A1 recursively.
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount);
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p1, 0, batchCount);
        arginfo = magma_sgetrf_recpanel_batched_q(
                m1, n1, min_recpnb, dA_displ, ldda,
                dipiv_displ, dpivinfo_array,
                dX_array, dX_length, dinvA_array, dinvA_length,
                dW1_displ, dW2_displ, dW3_displ, dW4_displ, dW5_displ,
                info_array, gbstep, batchCount, stream, myhandle);
        if (arginfo != 0) {
            goto cleanup;
        }

        // 2. Update A12: build the interchange map from A1's pivots, apply it
        //    to the columns of A2 (gathered into dX_array), then solve with
        //    the unit lower triangle of A11, writing the result into A12.
        setup_pivinfo_batched_q(dpivinfo_array, dipiv_displ, m1, n1, stream, batchCount);
        magma_sdisplace_pointers(dW5_displ, dA_array, ldda, p1, p2, batchCount);
        magma_slaswp_rowparallel_batched(n2, dW5_displ, ldda, dX_array, n1,
                                         0, n1, dpivinfo_array, batchCount);
        magmablas_strsm_outofplace_batched(
                MagmaLeft, MagmaLower, MagmaNoTrans, MagmaUnit, 1,
                n1, n2, MAGMA_S_ONE,
                dA_displ,  ldda,    // dA
                dX_array,  n1,      // dB
                dW5_displ, ldda,    // dX
                dinvA_array, dinvA_length,
                dW1_displ, dW2_displ, dW3_displ, dW4_displ,
                0, batchCount);

        // 3. Trailing update A22 -= A21 * A12.
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0,  batchCount);
        magma_sdisplace_pointers(dA_displ,  dA_array, ldda, p2, p2, batchCount);
        magmablas_sgemm_batched(
                MagmaNoTrans, MagmaNoTrans, m2, n2, n1,
                MAGMA_S_NEG_ONE, dW1_displ, ldda,
                                 dW5_displ, ldda,
                MAGMA_S_ONE,     dA_displ,  ldda,
                batchCount);

        // 4. Factor A22 recursively; its pivots go to ipiv[p2 ..].
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p2, 0, batchCount);
        arginfo = magma_sgetrf_recpanel_batched_q(
                m2, n2, min_recpnb, dA_displ, ldda,
                dipiv_displ, dpivinfo_array,
                dX_array, dX_length, dinvA_array, dinvA_length,
                dW1_displ, dW2_displ, dW3_displ, dW4_displ, dW5_displ,
                info_array, gbstep + p2, batchCount, stream, myhandle);
        if (arginfo != 0) {
            goto cleanup;
        }

        // 5. Build the interchange map from A22's (local) pivots, then offset
        //    those pivots by n1 (presumably making them panel-relative).
        //    adjust_ipiv rewrites the same pivots setup_pivinfo reads, so it
        //    must run on the same queue (`stream`), not the global magma_stream.
        setup_pivinfo_batched_q(dpivinfo_array, dipiv_displ, m2, n2, stream, batchCount);
        adjust_ipiv_batched_q(dipiv_displ, n2, n1, stream, batchCount);

        // Apply A22's interchanges to A21. dW1_displ must be re-pointed here:
        // the recursive call above used it as workspace.
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount);
        magma_slaswp_rowparallel_batched(n1, dW1_displ, ldda, dW1_displ, ldda,
                                         n1, n, dpivinfo_array, batchCount);
    }

cleanup:
    magma_free(dA_displ);
    magma_free(dipiv_displ);
    return arginfo;
}
/*
    Batched construction of the k-by-k factor T of a block reflector from V
    and tau (presumably LAPACK xLARFT semantics, H = I - V T V^H -- TODO
    confirm against the recstrmv/strmv kernels).

    For each matrix in the batch:
      n, k        V (v_array[i], leading dimension ldv) is n-by-k.
      stair_T     0: compute all of T. > 0: the diagonal stair_T-wide
                  triangles of T are already computed; only the blocks above
                  them are filled in, in column blocks of width stair_T
                  (must not exceed max_shared_bsiz). Must not be negative.
      tau_array   per-reflector scalars (k entries).
      T_array     output T, leading dimension ldt.
      work_array  workspace, lwork >= k*ldt; holds V^H V when T is built in
                  more than one column block.

    Returns 0 on success (including the quick returns k <= 0 and
    0 < k <= stair_T), -3 for an invalid stair_T, -10 for a too-small lwork
    (both reported through magma_xerbla), or MAGMA_ERR_DEVICE_ALLOC when the
    pointer-array workspace cannot be allocated.
*/
extern "C" magma_int_t
magma_slarft_batched(magma_int_t n, magma_int_t k, magma_int_t stair_T,
                     float **v_array, magma_int_t ldv,
                     float **tau_array, float **T_array, magma_int_t ldt,
                     float **work_array, magma_int_t lwork,
                     magma_int_t batchCount, cublasHandle_t myhandle,
                     magma_queue_t queue)
{
    if (k <= 0) return 0;
    if (stair_T > 0 && k <= stair_T) return 0;

    magma_int_t maxnb = max_shared_bsiz;

    if (lwork < k*ldt) {
        magma_xerbla(__func__, -(10));
        return -10;
    }
    // A negative stair_T would make the block width nb below negative and the
    // j loop would never terminate, so it is rejected like a too-large one.
    if (stair_T < 0 || (stair_T > 0 && stair_T > maxnb)) {
        magma_xerbla(__func__, -(3));
        return -3;
    }

    // Width of the column blocks T is built in.
    magma_int_t nb = (stair_T == 0) ? min(k, maxnb) : stair_T;
    // Tstep uses T's leading dimension even when it lives in work.
    magma_int_t ldtstep = ldt;

    float **dW1_displ    = NULL;
    float **dW2_displ    = NULL;
    float **dW3_displ    = NULL;
    float **dTstep_array = NULL;
    magma_malloc((void**)&dW1_displ,    batchCount * sizeof(*dW1_displ));
    magma_malloc((void**)&dW2_displ,    batchCount * sizeof(*dW2_displ));
    magma_malloc((void**)&dW3_displ,    batchCount * sizeof(*dW3_displ));
    magma_malloc((void**)&dTstep_array, batchCount * sizeof(*dTstep_array));
    if (dW1_displ == NULL || dW2_displ == NULL ||
        dW3_displ == NULL || dTstep_array == NULL) {
        magma_free(dW1_displ);
        magma_free(dW2_displ);
        magma_free(dW3_displ);
        magma_free(dTstep_array);
        return MAGMA_ERR_DEVICE_ALLOC;
    }

    // With more than one column block, the GEMM in the loop reads Tstep(:, j)
    // while writing T(:, j), so Tstep must live in work; otherwise it is
    // formed directly in T.
    if (k > nb) {
        magma_sdisplace_pointers(dTstep_array, work_array, lwork, 0, 0, batchCount, queue);
    }
    else {
        magma_sdisplace_pointers(dTstep_array, T_array, ldt, 0, 0, batchCount, queue);
    }

    // Phase 1: Tstep = V^H * V (k-by-k), then zero its lower triangle
    // (strictly lower part and diagonal both set to MAGMA_S_ZERO).
#ifdef RFT_MAG_GEM
    magmablas_sgemm_batched(MagmaConjTrans, MagmaNoTrans, k, k, n,
                            one,  v_array, ldv, v_array, ldv,
                            zero, dTstep_array, ldtstep, batchCount, queue);
#else
    cublasSgemmBatched(myhandle, CUBLAS_OP_C, CUBLAS_OP_N, k, k, n,
                       &one,  (const float**) v_array, ldv,
                              (const float**) v_array, ldv,
                       &zero, dTstep_array, ldtstep, batchCount);
#endif
    magmablas_slaset_batched(MagmaLower, k, k, MAGMA_S_ZERO, MAGMA_S_ZERO,
                             dTstep_array, ldtstep, batchCount, queue);

    // Phase 2 (TRMV): T(0:i-1, i) := T(0:i-1, 0:i-1) * W(0:i-1) for i < k,
    // processed left to right in column blocks of width nb. For block j:
    //   a. one GEMM combines the finished columns of T with Tstep to form the
    //      rectangle above block j's triangle;
    //   b. that rectangle is corrected by a sequence of small kernels, nb rows
    //      at a time (the kernels' shared-memory use is limited to nb);
    //   c. when stair_T == 0, block j's own triangle is formed.
    for (magma_int_t j = 0; j < k; j += nb) {
        magma_int_t prev_n = j;               // columns of T already finished
        magma_int_t mycol  = min(nb, k - j);  // width of this column block

        if (prev_n > 0 && mycol > 0) {
            // a. T(0:prev_n-1, j:) = T(0:prev_n-1, 0:prev_n-1) * Tstep(0:prev_n-1, j:)
            magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, 0, j, batchCount, queue);
            magma_sdisplace_pointers(dW2_displ, T_array, ldt, 0, j, batchCount, queue);
#ifdef RFT_MAG_GEM
            magmablas_sgemm_batched(MagmaNoTrans, MagmaNoTrans, prev_n, mycol, prev_n,
                                    one,  T_array, ldt, dW1_displ, ldtstep,
                                    zero, dW2_displ, ldt, batchCount, queue);
#else
            cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N, prev_n, mycol, prev_n,
                               &one,  (const float**) T_array, ldt,
                                      (const float**) dW1_displ, ldtstep,
                               &zero, dW2_displ, ldt, batchCount);
#endif

            // b. Correct the rectangle, nb rows at a time.
            magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
            magma_sdisplace_pointers(dW3_displ, tau_array, 1, j, 0, batchCount, queue);
            for (magma_int_t i = 0; i < prev_n; i += nb) {
                magma_int_t rows = min(nb, prev_n - i);
                if (rows > 0 && mycol > 0) {
                    magma_sdisplace_pointers(dW2_displ, T_array, ldt, i, j, batchCount, queue);
                    magmablas_slarft_recstrmv_sm32x32_batched(rows, mycol, dW3_displ,
                                                             dW2_displ, ldt,
                                                             dW1_displ, ldtstep,
                                                             batchCount, queue);
                }
            }
        }

        // c. Triangular portion of this block (only when T is computed fully).
        if (stair_T == 0 && mycol > 0) {
            magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
            magma_sdisplace_pointers(dW3_displ, tau_array, 1, j, 0, batchCount, queue);
            magma_sdisplace_pointers(dW2_displ, T_array, ldt, j, j, batchCount, queue);
            magmablas_slarft_strmv_sm32x32_batched(mycol, mycol, dW3_displ,
                                                   dW1_displ, ldtstep,
                                                   dW2_displ, ldt,
                                                   batchCount, queue);
        }
    }

    magma_free(dW1_displ);
    magma_free(dW2_displ);
    magma_free(dW3_displ);
    magma_free(dTstep_array);
    return 0;
}
/*
    Recursive LU panel factorization WITHOUT pivoting for a batch of matrices.

    For each matrix in the batch the m-by-n panel dA_array[i] (leading
    dimension ldda) is factored in place. Panels with n <= min_recpnb are
    handed to magma_sgetrf_panel_nopiv_batched_q. Wider panels are split as
    [A1 A2], with A1 holding the first n1 = n/2 columns:
      1. factor A1 recursively;
      2. solve with the unit lower triangle of A11 to update A12 (TRSM);
      3. trailing update A22 -= A21 * A12 (GEMM);
      4. factor A22 recursively (info offset gbstep + n1).

    dX_array, dinvA_array and dW1..dW5_displ are caller-provided workspaces.
    They are overwritten here and by the recursive calls.

    Returns 0 on success, MAGMA_ERR_DEVICE_ALLOC if the local pointer-array
    workspace cannot be allocated, or the first nonzero code from a
    sub-factorization. The local workspace is freed on every path (earlier
    versions leaked it when a sub-factorization failed).

    NOTE(review): m2 = m - n/2 is not checked for being non-negative; callers
    presumably guarantee m >= n -- confirm.
*/
extern "C" magma_int_t
magma_sgetrf_recpanel_nopiv_batched_q(
    magma_int_t m, magma_int_t n, magma_int_t min_recpnb,
    float** dA_array, magma_int_t ldda,
    float** dX_array, magma_int_t dX_length,
    float** dinvA_array, magma_int_t dinvA_length,
    float** dW1_displ, float** dW2_displ, float** dW3_displ,
    float** dW4_displ, float** dW5_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t stream, cublasHandle_t myhandle)
{
    // Quick return if possible
    if (m == 0 || n == 0) {
        return 0;
    }

    // Leaf of the recursion: factor the whole panel directly (no local
    // workspace needed, so nothing is allocated on this path).
    if (n <= min_recpnb) {
        return magma_sgetrf_panel_nopiv_batched_q(
                m, n, dA_array, ldda,
                dX_array, dX_length, dinvA_array, dinvA_length,
                dW1_displ, dW2_displ, dW3_displ, dW4_displ, dW5_displ,
                info_array, gbstep, batchCount, stream, myhandle);
    }

    magma_int_t arginfo = 0;

    // Device array holding, per matrix, a pointer into a sub-block of A.
    float **dA_displ = NULL;
    magma_malloc((void**)&dA_displ, batchCount * sizeof(*dA_displ));
    if (dA_displ == NULL) {
        return MAGMA_ERR_DEVICE_ALLOC;
    }

    {
        // Split A = [A1 A2]: A1 = first n1 columns, A2 = remaining n2 columns.
        const magma_int_t n1 = n / 2;
        const magma_int_t n2 = n - n1;
        const magma_int_t m1 = m;
        const magma_int_t m2 = m - n1;
        const magma_int_t p1 = 0;
        const magma_int_t p2 = n1;

        // 1. Factor A1 recursively.
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount);
        arginfo = magma_sgetrf_recpanel_nopiv_batched_q(
                m1, n1, min_recpnb, dA_displ, ldda,
                dX_array, dX_length, dinvA_array, dinvA_length,
                dW1_displ, dW2_displ, dW3_displ, dW4_displ, dW5_displ,
                info_array, gbstep, batchCount, stream, myhandle);
        if (arginfo != 0) {
            goto cleanup;   // release dA_displ instead of leaking it
        }

        // 2. Update A12 with the unit lower triangle of A11.
        magma_sdisplace_pointers(dW5_displ, dA_array, ldda, p1, p2, batchCount);
        magmablas_strsm_work_batched(
                MagmaLeft, MagmaLower, MagmaNoTrans, MagmaUnit, 1,
                n1, n2, MAGMA_S_ONE,
                dA_displ,  ldda,    // dA
                dW5_displ, ldda,    // dB
                dX_array,  n1,      // dX
                dinvA_array, dinvA_length,
                dW1_displ, dW2_displ, dW3_displ, dW4_displ,
                1, batchCount);

        // 3. Trailing update A22 -= A21 * A12.
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0,  batchCount);
        magma_sdisplace_pointers(dA_displ,  dA_array, ldda, p2, p2, batchCount);
        magmablas_sgemm_batched(
                MagmaNoTrans, MagmaNoTrans, m2, n2, n1,
                MAGMA_S_NEG_ONE, dW1_displ, ldda,
                                 dW5_displ, ldda,
                MAGMA_S_ONE,     dA_displ,  ldda,
                batchCount);

        // 4. Factor A22 recursively; its result is this call's result.
        arginfo = magma_sgetrf_recpanel_nopiv_batched_q(
                m2, n2, min_recpnb, dA_displ, ldda,
                dX_array, dX_length, dinvA_array, dinvA_length,
                dW1_displ, dW2_displ, dW3_displ, dW4_displ, dW5_displ,
                info_array, gbstep + p2, batchCount, stream, myhandle);
    }

cleanup:
    magma_free(dA_displ);
    return arginfo;
}
/*
    LU factorization without pivoting for a batch of m-by-n matrices
    (dA_array[i], leading dimension lda), processed in column panels of
    nb = 32 columns.

    Within a panel, each column is handed to magma_sscal_sger_batched. Judging
    by its name, this scales the column below the diagonal and applies a
    rank-1 update to the rest of the panel. info_array and gbstep are passed
    through to it, presumably so it can record zero pivots -- TODO confirm.
    After each panel, magma_sgetf2trsm_batched updates the rows to its right,
    and a batched GEMM updates the trailing submatrix (A22 -= A21 * A12).

    dW0_displ, dW1_displ, dW2_displ: pointer arrays (batchCount entries) that
    are overwritten with views into dA_array.
    myhandle: unused; kept for interface compatibility.

    Returns 0 on success, -1 / -2 / -4 for an invalid m / n / lda (reported
    via magma_xerbla), or the first nonzero code from
    magma_sscal_sger_batched.
*/
extern "C" magma_int_t
magma_sgetf2_nopiv_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t lda,
    float **dW0_displ, float **dW1_displ, float **dW2_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, cublasHandle_t myhandle)
{
    magma_int_t arginfo = 0;
    if (m < 0) {
        arginfo = -1;
    } else if (n < 0) {
        arginfo = -2;
    } else if (lda < max(1, m)) {
        arginfo = -4;
    }

    if (arginfo != 0) {
        magma_xerbla(__func__, -(arginfo));
        return arginfo;
    }

    // Quick return if possible
    if (m == 0 || n == 0) {
        return arginfo;
    }

    (void) myhandle;   // the trailing update always uses magmablas_sgemm_batched

    const float neg_one = MAGMA_S_NEG_ONE;
    const float one     = MAGMA_S_ONE;
    const magma_int_t nb     = 32;   // panel width (fixed here instead of BATF2_NB)
    const magma_int_t min_mn = min(m, n);

    for (magma_int_t panelj = 0; panelj < min_mn; panelj += nb) {
        const magma_int_t ib = min(nb, min_mn - panelj);

        // Factor the ib columns of the panel one at a time. Because
        // gbj < min_mn <= m, the column always has m - gbj > 0 rows left;
        // the old "rows left <= 0" fallback branch was unreachable.
        for (magma_int_t step = 0; step < ib; step++) {
            const magma_int_t gbj = panelj + step;
            arginfo = magma_sscal_sger_batched(m - gbj, ib - step, gbj,
                                               dA_array, lda, info_array, gbstep,
                                               batchCount);
            if (arginfo != 0) {
                return arginfo;
            }
        }

        const magma_int_t ntrail = n - (panelj + ib);   // columns right of the panel
        if (ntrail > 0) {
            // Update the ib panel rows to the right of the panel (TRSM).
            magma_sgetf2trsm_batched(ib, ntrail, dA_array, panelj, lda, batchCount);

            // Blocked rank-ib update of the trailing submatrix:
            // A(panelj+ib:, panelj+ib:) -= A(panelj+ib:, panelj:) * A(panelj:, panelj+ib:)
            magma_sdisplace_pointers(dW0_displ, dA_array, lda, ib + panelj, panelj,      batchCount);
            magma_sdisplace_pointers(dW1_displ, dA_array, lda, panelj,      ib + panelj, batchCount);
            magma_sdisplace_pointers(dW2_displ, dA_array, lda, ib + panelj, ib + panelj, batchCount);
            magmablas_sgemm_batched(MagmaNoTrans, MagmaNoTrans,
                                    m - (panelj + ib), ntrail, ib,
                                    neg_one, dW0_displ, lda,
                                             dW1_displ, lda,
                                    one,     dW2_displ, lda,
                                    batchCount);
        }
    }

    return 0;
}
/*
    LU factorization with partial pivoting for a batch of m-by-n matrices
    (dA_array[i], leading dimension lda), processed in column panels of
    BATF2_NB columns. Pivots are written through ipiv_array.

    For every column of a panel:
      - panels with more than MAX_NTHREADS rows left use magma_isamax_batched
        to pick the pivot;
      - shorter panels use magma_scomputecolumn_batched (presumably a
        shared-memory kernel that both picks the pivot and updates the
        column -- TODO confirm);
      - the pivot row is then swapped across all n columns;
      - for the tall case only, magma_sscal_sger_batched finishes the column.
    After each panel, magma_sgetf2trsm_batched updates the rows to its right,
    and a batched GEMM updates the trailing submatrix (A22 -= A21 * A12).

    dW0_displ, dW1_displ, dW2_displ: pointer arrays (batchCount entries) that
    are overwritten with views into dA_array.
    myhandle is only referenced by the disabled cuBLAS alternative.

    Returns 0 on success, -1 / -2 / -4 for an invalid m / n / lda (reported
    via magma_xerbla), or the first nonzero code from a per-column kernel.
*/
extern "C" magma_int_t
magma_sgetf2_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t lda,
    float **dW0_displ, float **dW1_displ, float **dW2_displ,
    magma_int_t **ipiv_array, magma_int_t *info_array,
    magma_int_t gbstep, magma_int_t batchCount,
    cublasHandle_t myhandle, magma_queue_t queue)
{
    // LAPACK-style argument checking: report the position of the bad argument.
    magma_int_t bad_arg = 0;
    if (m < 0)
        bad_arg = -1;
    else if (n < 0)
        bad_arg = -2;
    else if (lda < max(1, m))
        bad_arg = -4;

    if (bad_arg != 0) {
        magma_xerbla(__func__, -(bad_arg));
        return bad_arg;
    }

    // Nothing to factor.
    if (m == 0 || n == 0)
        return 0;

    const float alpha_neg = MAGMA_S_NEG_ONE;
    const float alpha_pos = MAGMA_S_ONE;
    const magma_int_t panel_width = BATF2_NB;
    const magma_int_t kmax = min(m, n);

    for (magma_int_t col0 = 0; col0 < kmax; col0 += panel_width) {
        const magma_int_t width = min(panel_width, kmax - col0);

        // The kernel choice depends only on the rows below the panel's top,
        // so it is fixed for the whole panel.
        const bool tall = (m - col0) > MAX_NTHREADS;

        for (magma_int_t s = 0; s < width; s++) {
            const magma_int_t col = col0 + s;

            // Select the pivot of column `col`.
            magma_int_t err = tall
                ? magma_isamax_batched(m - col, dA_array, 1, col, lda,
                                       ipiv_array, info_array, gbstep,
                                       batchCount, queue)
                : magma_scomputecolumn_batched(m - col0, col0, s, dA_array, lda,
                                               ipiv_array, info_array, gbstep,
                                               batchCount, queue);
            if (err != 0)
                return err;

            // Interchange rows across all n columns.
            err = magma_sswap_batched(n, dA_array, lda, col, ipiv_array,
                                      batchCount, queue);
            if (err != 0)
                return err;

            // The tall path still has to scale the column and update the panel.
            if (tall && col < m) {
                err = magma_sscal_sger_batched(m - col, width - s, col, dA_array, lda,
                                               info_array, gbstep, batchCount, queue);
                if (err != 0)
                    return err;
            }
        }

        const magma_int_t right = n - col0 - width;   // columns right of the panel
        if (right <= 0)
            continue;

        // Update the panel rows to the right of the panel (TRSM).
        magma_sgetf2trsm_batched(width, right, dA_array, col0, lda, batchCount, queue);

        // Blocked rank-`width` update of the trailing submatrix.
        const magma_int_t next = col0 + width;
        magma_sdisplace_pointers(dW0_displ, dA_array, lda, next, col0, batchCount, queue);
        magma_sdisplace_pointers(dW1_displ, dA_array, lda, col0, next, batchCount, queue);
        magma_sdisplace_pointers(dW2_displ, dA_array, lda, next, next, batchCount, queue);
#if 1
        magmablas_sgemm_batched(MagmaNoTrans, MagmaNoTrans,
                                m - next, right, width,
                                alpha_neg, dW0_displ, lda,
                                           dW1_displ, lda,
                                alpha_pos, dW2_displ, lda,
                                batchCount, queue);
#else
        cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N,
                           m - next, right, width,
                           &alpha_neg, (const float**) dW0_displ, lda,
                                       (const float**) dW1_displ, lda,
                           &alpha_pos, dW2_displ, lda, batchCount);
#endif
    }

    return 0;
}