extern "C" magma_int_t magma_sgeqrf_expert_batched( magma_int_t m, magma_int_t n, float **dA_array, magma_int_t ldda, float **dR_array, magma_int_t lddr, float **dT_array, magma_int_t lddt, float **dtau_array, magma_int_t provide_RT, magma_int_t *info_array, magma_int_t batchCount, magma_queue_t queue) { #define dA(i, j) (dA + (i) + (j)*ldda) // A(i, j) means at i row, j column /* Local Parameter */ magma_int_t nb = magma_get_sgeqrf_batched_nb(m); magma_int_t nnb = 8; magma_int_t min_mn = min(m, n); /* Check arguments */ cudaMemset(info_array, 0, batchCount*sizeof(magma_int_t)); magma_int_t arginfo = 0; if (m < 0) arginfo = -1; else if (n < 0) arginfo = -2; else if (ldda < max(1,m)) arginfo = -4; else if (lddr < min_mn && provide_RT == 1) arginfo = -6; else if (lddr < min(min_mn, nb)) arginfo = -6; else if (lddt < min(min_mn, nb)) arginfo = -8; if (arginfo != 0) { magma_xerbla( __func__, -(arginfo) ); return arginfo; } /* Quick return if possible */ if (m == 0 || n == 0) if (min_mn == 0 ) return arginfo; if ( m > 2048 || n > 2048 ) { printf("=========================================================================================\n"); printf(" WARNING batched routines are designed for small sizes it might be better to use the\n Native/Hybrid classical routines if you want performance\n"); printf("=========================================================================================\n"); } magma_int_t i, k, ib=nb, jb=nnb, offset_RT=0, use_stream; magma_int_t ldw, offset; float **dW0_displ = NULL; float **dW1_displ = NULL; float **dW2_displ = NULL; float **dW3_displ = NULL; float **dW4_displ = NULL; float **dW5_displ = NULL; float **dR_displ = NULL; float **dT_displ = NULL; float *dwork = NULL; float **cpuAarray = NULL; float **cpuTarray = NULL; magma_malloc((void**)&dW0_displ, batchCount * sizeof(*dW0_displ)); magma_malloc((void**)&dW1_displ, batchCount * sizeof(*dW1_displ)); magma_malloc((void**)&dW2_displ, batchCount * sizeof(*dW2_displ)); magma_malloc((void**)&dW3_displ, batchCount * sizeof(*dW3_displ)); magma_malloc((void**)&dW4_displ, batchCount * sizeof(*dW4_displ)); magma_malloc((void**)&dW5_displ, batchCount * sizeof(*dW5_displ)); magma_malloc((void**)&dR_displ, batchCount * sizeof(*dR_displ)); magma_malloc((void**)&dT_displ, batchCount * sizeof(*dT_displ)); magma_smalloc(&dwork, (2 * nb * n) * batchCount); magma_malloc_cpu((void**) &cpuAarray, batchCount*sizeof(float*)); magma_malloc_cpu((void**) &cpuTarray, batchCount*sizeof(float*)); /* check allocation */ if ( dW0_displ == NULL || dW1_displ == NULL || dW2_displ == NULL || dW3_displ == NULL || dW4_displ == NULL || dW5_displ == NULL || dR_displ == NULL || dT_displ == NULL || dwork == NULL || cpuAarray == NULL || cpuTarray == NULL ) { magma_free(dW0_displ); magma_free(dW1_displ); magma_free(dW2_displ); magma_free(dW3_displ); magma_free(dW4_displ); magma_free(dW5_displ); magma_free(dR_displ); magma_free(dT_displ); magma_free(dwork); magma_free_cpu(cpuAarray); magma_free_cpu(cpuTarray); magma_int_t info = MAGMA_ERR_DEVICE_ALLOC; magma_xerbla( __func__, -(info) ); return info; } magma_sdisplace_pointers(dR_displ, dR_array, lddr, 0, 0, batchCount, queue); magma_sdisplace_pointers(dT_displ, dT_array, lddt, 0, 0, batchCount, queue); // set dR and dT to zero. if provide_RT == 0 only a tile of size nbxnb is used and overwritten at each step magmablas_slaset_batched( MagmaFull, lddr, (provide_RT > 0 ? 
n:min(min_mn,nb)), MAGMA_S_ZERO, MAGMA_S_ZERO, dR_displ, lddr, batchCount, queue ); magmablas_slaset_batched( MagmaFull, lddt, (provide_RT > 0 ? n:min(min_mn,nb)), MAGMA_S_ZERO, MAGMA_S_ZERO, dT_displ, lddt, batchCount, queue ); /* if ( provide_RT > 0 ) { magmablas_slaset_q( MagmaFull, lddr, n*batchCount, MAGMA_S_ZERO, MAGMA_S_ZERO, dR, lddr, queue ); magmablas_slaset_q( MagmaFull, lddt, n*batchCount, MAGMA_S_ZERO, MAGMA_S_ZERO, dT, lddt, queue ); } else { magmablas_slaset_q( MagmaFull, lddr, nb*batchCount, MAGMA_S_ZERO, MAGMA_S_ZERO, dR, lddr, queue ); magmablas_slaset_q( MagmaFull, lddt, nb*batchCount, MAGMA_S_ZERO, MAGMA_S_ZERO, dT, lddt, queue ); } */ magma_int_t streamid; const magma_int_t nbstreams=10; magma_queue_t queues[nbstreams]; for (i=0; i < nbstreams; i++) { magma_device_t cdev; magma_getdevice( &cdev ); magma_queue_create( cdev, &queues[i] ); } magma_getvector( batchCount, sizeof(float*), dA_array, 1, cpuAarray, 1, queue); magma_getvector( batchCount, sizeof(float*), dT_array, 1, cpuTarray, 1, queue); for (i=0; i < min_mn; i += nb) { ib = min(nb, min_mn-i); //=============================================== // panel factorization //=============================================== magma_sdisplace_pointers(dW0_displ, dA_array, ldda, i, i, batchCount, queue); magma_sdisplace_pointers(dW2_displ, dtau_array, 1, i, 0, batchCount, queue); if ( provide_RT > 0 ) { offset_RT = i; magma_sdisplace_pointers(dR_displ, dR_array, lddr, (provide_RT == 1 ? offset_RT:0), offset_RT, batchCount, queue); magma_sdisplace_pointers(dT_displ, dT_array, lddt, 0, offset_RT, batchCount, queue); } //dwork is used in panel factorization and trailing matrix update //dW4_displ, dW5_displ are used as workspace and configured inside magma_sgeqrf_panel_batched(m-i, ib, jb, dW0_displ, ldda, dW2_displ, dT_displ, lddt, dR_displ, lddr, dW1_displ, dW3_displ, dwork, dW4_displ, dW5_displ, info_array, batchCount, queue); //=============================================== // end of panel //=============================================== //=============================================== // update trailing matrix //=============================================== if ( (n-ib-i) > 0) { //dwork is used in panel factorization and trailing matrix update //reset dW4_displ ldw = nb; magma_sset_pointer( dW4_displ, dwork, 1, 0, 0, ldw*n, batchCount, queue ); offset = ldw*n*batchCount; magma_sset_pointer( dW5_displ, dwork + offset, 1, 0, 0, ldw*n, batchCount, queue ); // set the diagonal of v as one and the upper triangular part as zero already set inside geqrf_panel //magmablas_slaset_batched( MagmaUpper, ib, ib, MAGMA_S_ZERO, MAGMA_S_ONE, dW0_displ, ldda, batchCount, queue ); //magma_sdisplace_pointers(dW2_displ, dtau_array, 1, i, 0, batchCount, queue); // it is faster since it is using BLAS-3 GEMM routines, different from lapack implementation magma_slarft_batched(m-i, ib, 0, dW0_displ, ldda, dW2_displ, dT_displ, lddt, dW4_displ, nb*lddt, batchCount, queue); // perform C = (I-V T^H V^H) * C, C is the trailing matrix //------------------------------------------- // USE STREAM GEMM //------------------------------------------- use_stream = magma_srecommend_cublas_gemm_stream(MagmaNoTrans, MagmaNoTrans, m-i-ib, n-i-ib, ib); if ( use_stream ) { magma_queue_sync(queue); for (k=0; k < batchCount; k++) { streamid = k%nbstreams; // the queue gemm must take cpu pointer magma_slarfb_gpu_gemm( MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise, m-i, n-i-ib, ib, cpuAarray[k] + i + i * ldda, ldda, cpuTarray[k] + offset_RT*lddt, lddt, 
cpuAarray[k] + i + (i+ib) * ldda, ldda, dwork + nb * n * k, -1, dwork + nb * n * batchCount + nb * n * k, -1, queues[streamid] ); } // need to synchronise to be sure that panel does not start before // finishing the update at least of the next panel // if queue is NULL, no need to sync if ( queue != NULL ) { for (magma_int_t s=0; s < nbstreams; s++) magma_queue_sync(queues[s]); } } //------------------------------------------- // USE BATCHED GEMM //------------------------------------------- else { //direct trailing matrix in dW1_displ magma_sdisplace_pointers(dW1_displ, dA_array, ldda, i, i+ib, batchCount, queue); magma_slarfb_gemm_batched( MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise, m-i, n-i-ib, ib, (const float**)dW0_displ, ldda, (const float**)dT_displ, lddt, dW1_displ, ldda, dW4_displ, ldw, dW5_displ, ldw, batchCount, queue ); } }// update the trailing matrix //=============================================== // copy dR back to V after the trailing matrix update, // only when provide_RT=0 otherwise the nbxnb block of V is set to diag=1/0 // The upper portion of V could be set totaly to 0 here if ( provide_RT == 0 ) { magmablas_slacpy_batched( MagmaUpper, ib, ib, dR_displ, lddr, dW0_displ, ldda, batchCount, queue ); } } magma_queue_sync(queue); for (k=0; k < nbstreams; k++) { magma_queue_destroy( queues[k] ); } magma_free(dW0_displ); magma_free(dW1_displ); magma_free(dW2_displ); magma_free(dW3_displ); magma_free(dW4_displ); magma_free(dW5_displ); magma_free(dR_displ); magma_free(dT_displ); magma_free(dwork); magma_free_cpu(cpuAarray); magma_free_cpu(cpuTarray); return arginfo; }
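/*
   Illustrative caller sketch (not part of the MAGMA sources): a minimal driver showing
   one way to call magma_sgeqrf_expert_batched with provide_RT = 0, where dR and dT are
   only nb x nb scratch tiles per matrix, sized to satisfy the argument checks above.
   The function name is hypothetical; dA_array, dtau_array and info_array are assumed to
   be device arrays already set up by the caller, and error handling is omitted.
*/
static magma_int_t
sgeqrf_expert_batched_driver_sketch(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t ldda,
    float **dtau_array, magma_int_t *info_array,
    magma_int_t batchCount, magma_queue_t queue)
{
    magma_int_t nb     = magma_get_sgeqrf_batched_nb(m);
    magma_int_t min_mn = min(m, n);
    magma_int_t lddr   = min(nb, min_mn);   // leading dimension of the R scratch tile
    magma_int_t lddt   = min(nb, min_mn);   // leading dimension of the T scratch tile

    float *dR = NULL, *dT = NULL;
    float **dR_array = NULL, **dT_array = NULL;
    magma_smalloc(&dR, lddr * nb * batchCount);
    magma_smalloc(&dT, lddt * nb * batchCount);
    magma_malloc((void**)&dR_array, batchCount * sizeof(float*));
    magma_malloc((void**)&dT_array, batchCount * sizeof(float*));

    // Build the per-matrix pointer arrays over the contiguous dR/dT buffers.
    magma_sset_pointer(dR_array, dR, lddr, 0, 0, lddr*nb, batchCount, queue);
    magma_sset_pointer(dT_array, dT, lddt, 0, 0, lddt*nb, batchCount, queue);

    magma_int_t info = magma_sgeqrf_expert_batched(m, n,
                                                   dA_array, ldda,
                                                   dR_array, lddr,
                                                   dT_array, lddt,
                                                   dtau_array, 0 /* provide_RT */,
                                                   info_array, batchCount, queue);
    magma_queue_sync(queue);

    magma_free(dR);       magma_free(dT);
    magma_free(dR_array); magma_free(dT_array);
    return info;
}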
extern "C" magma_int_t magma_sgeqrf_panel_batched( magma_int_t m, magma_int_t n, magma_int_t nb, float** dA_array, magma_int_t ldda, float** tau_array, float** dT_array, magma_int_t ldt, float** dR_array, magma_int_t ldr, float** dW0_displ, float** dW1_displ, float *dwork, float** dW2_displ, float** dW3_displ, magma_int_t *info_array, magma_int_t batchCount, magma_queue_t queue) { magma_int_t j, jb; magma_int_t ldw = nb; magma_int_t minmn = min(m,n); for( j=0; j < minmn; j += nb) { jb = min(nb, minmn-j); magma_sdisplace_pointers(dW0_displ, dA_array, ldda, j, j, batchCount, queue); magma_sdisplace_pointers(dW2_displ, tau_array, 1, j, 0, batchCount, queue); magma_sdisplace_pointers(dW3_displ, dR_array, ldr, j, j, batchCount, queue); // //sub-panel factorization magma_sgeqr2_batched( m-j, jb, dW0_displ, ldda, dW2_displ, info_array, batchCount, queue); //copy th whole rectangular n,jb from of dA to dR (it's lower portion (which is V's) will be set to zero if needed at the end) magma_sdisplace_pointers(dW0_displ, dA_array, ldda, 0, j, batchCount, queue); magma_sdisplace_pointers(dW3_displ, dR_array, ldr, 0, j, batchCount, queue); magmablas_slacpy_batched( MagmaFull, minmn, jb, dW0_displ, ldda, dW3_displ, ldr, batchCount, queue ); //set the upper jbxjb portion of V dA(j,j) to 1/0s (note that the rectangular on the top of this triangular of V still non zero but has been copied to dR). magma_sdisplace_pointers(dW0_displ, dA_array, ldda, j, j, batchCount, queue); magmablas_slaset_batched( MagmaUpper, jb, jb, MAGMA_S_ZERO, MAGMA_S_ONE, dW0_displ, ldda, batchCount, queue ); if ( (n-j-jb) > 0) //update the trailing matrix inside the panel { magma_slarft_sm32x32_batched(m-j, jb, dW0_displ, ldda, dW2_displ, dT_array, ldt, batchCount, queue); magma_sdisplace_pointers( dW1_displ, dA_array, ldda, j, j + jb, batchCount, queue ); magma_sset_pointer( dW2_displ, dwork, 1, 0, 0, ldw*n, batchCount, queue ); magma_sset_pointer( dW3_displ, dwork + ldw*n*batchCount, 1, 0, 0, ldw*n, batchCount, queue ); magma_slarfb_gemm_batched( MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise, m-j, n-j-jb, jb, (const float**)dW0_displ, ldda, (const float**)dT_array, ldt, dW1_displ, ldda, dW2_displ, ldw, dW3_displ, ldw, batchCount, queue ); } } // copy the remaining portion of dR from dA in case m < n if ( m < n ) { magma_sdisplace_pointers(dW0_displ, dA_array, ldda, 0, minmn, batchCount, queue); magma_sdisplace_pointers(dW3_displ, dR_array, ldr, 0, minmn, batchCount, queue); magmablas_slacpy_batched( MagmaFull, minmn, n-minmn, dW0_displ, ldda, dW3_displ, ldr, batchCount, queue ); } // to be consistent set the whole upper nbxnb of V to 0/1s, in this case no need to set it inside sgeqrf_batched magma_sdisplace_pointers(dW0_displ, dA_array, ldda, 0, 0, batchCount, queue); magmablas_slaset_batched( MagmaUpper, minmn, n, MAGMA_S_ZERO, MAGMA_S_ONE, dW0_displ, ldda, batchCount, queue ); return MAGMA_SUCCESS; }