Example #1
0
/*
    Blocked, batched factorization driver (single precision).

    For each of the batchCount column-major m-by-n matrices addressed by
    dA_array, the routine walks the first min(m,n) columns in panels of nb
    columns (nb comes from magma_get_sgeqrf_batched_nb(m)). Each step:
      1. factors the current panel with magma_sgeqrf_panel_batched,
      2. if columns remain to the right, builds the block-reflector factor T
         with magma_slarft_batched and applies (I - V T^H V^H) to the trailing
         columns, either one gemm-based call per matrix spread over a pool of
         queues, or a single batched call,
      3. when provide_RT == 0, copies the upper triangle held in dR back into
         the diagonal block of A.

    Arguments
      m, n         rows / columns of every matrix (must be >= 0).
      dA_array     device array of batchCount pointers to the matrices.
      ldda         leading dimension of each matrix (>= max(1,m)).
      dR_array     device array of pointers to the R workspaces.
      lddr         leading dimension of each R workspace
                   (>= min(min(m,n),nb); >= min(m,n) when provide_RT == 1).
      dT_array     device array of pointers to the T workspaces.
      lddt         leading dimension of each T workspace (>= min(min(m,n),nb)).
      dtau_array   device array of pointers to the tau vectors; offset by the
                   panel start before each panel factorization.
      provide_RT   0: dR/dT are reused as a single tile that is overwritten at
                      every step;
                   > 0: dR/dT are advanced by one panel of columns per step
                      (for provide_RT == 1 dR is also advanced by rows).
      info_array   device array of batchCount status values; zeroed on entry.
      batchCount   number of matrices (must be >= 0).
      queue        queue on which the batched work is issued.

    Returns 0 on success, -i if the i-th argument is invalid, or
    MAGMA_ERR_DEVICE_ALLOC if a workspace allocation failed.
*/
extern "C" magma_int_t
magma_sgeqrf_expert_batched(
    magma_int_t m, magma_int_t n, 
    float **dA_array, magma_int_t ldda, 
    float **dR_array, magma_int_t lddr,
    float **dT_array, magma_int_t lddt,
    float **dtau_array, magma_int_t provide_RT,
    magma_int_t *info_array, magma_int_t batchCount, magma_queue_t queue)
{
#define dA(i, j)  (dA + (i) + (j)*ldda)   // A(i, j) means at i row, j column
    
    /* Local Parameter */
    magma_int_t nb = magma_get_sgeqrf_batched_nb(m);
    
    magma_int_t nnb = 8;
    magma_int_t min_mn = min(m, n);

    /* Check arguments */
    // Only clear the status array when the count is usable: a negative
    // batchCount would otherwise be converted to a huge byte count.
    if (batchCount > 0)
        cudaMemset(info_array, 0, batchCount*sizeof(magma_int_t));
    magma_int_t arginfo = 0;
    if (m < 0)
        arginfo = -1;
    else if (n < 0)
        arginfo = -2;
    else if (ldda < max(1,m))
        arginfo = -4;
    else if (lddr < min_mn && provide_RT == 1)
        arginfo = -6;
    else if (lddr < min(min_mn, nb))
        arginfo = -6;
    else if (lddt < min(min_mn, nb))
        arginfo = -8;
    else if (batchCount < 0)
        arginfo = -12;

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    /* Quick return if possible */
    // m and n are known to be non-negative here, so min_mn == 0 is the same
    // test as (m == 0 || n == 0). An empty batch has nothing to do either.
    if (min_mn == 0 || batchCount == 0)
        return arginfo;

    if ( m >  2048 || n > 2048 ) {
        printf("=========================================================================================\n");
        printf("   WARNING batched routines are designed for small sizes it might be better to use the\n   Native/Hybrid classical routines if you want performance\n");
        printf("=========================================================================================\n");
    }


    magma_int_t i, k, ib=nb, jb=nnb, offset_RT=0, use_stream;
    magma_int_t ldw;

    // Elements of workspace reserved per matrix in each half of dwork.
    // Kept in size_t so that size and offset products cannot overflow a
    // 32-bit magma_int_t for large batches.
    const size_t work_stride = (size_t)nb * (size_t)n;

    float **dW0_displ = NULL;
    float **dW1_displ = NULL;
    float **dW2_displ = NULL;
    float **dW3_displ = NULL;
    float **dW4_displ = NULL;
    float **dW5_displ = NULL;
    float **dR_displ  = NULL;
    float **dT_displ  = NULL;

    float *dwork = NULL;
    float **cpuAarray = NULL;
    float **cpuTarray = NULL;

    magma_malloc((void**)&dW0_displ, batchCount * sizeof(*dW0_displ));
    magma_malloc((void**)&dW1_displ, batchCount * sizeof(*dW1_displ));
    magma_malloc((void**)&dW2_displ, batchCount * sizeof(*dW2_displ));
    magma_malloc((void**)&dW3_displ, batchCount * sizeof(*dW3_displ));
    magma_malloc((void**)&dW4_displ, batchCount * sizeof(*dW4_displ));
    magma_malloc((void**)&dW5_displ, batchCount * sizeof(*dW5_displ));
    magma_malloc((void**)&dR_displ,  batchCount * sizeof(*dR_displ));
    magma_malloc((void**)&dT_displ,  batchCount * sizeof(*dT_displ));

    // two halves of work_stride elements per matrix
    magma_smalloc(&dwork,  2 * work_stride * (size_t)batchCount);
    magma_malloc_cpu((void**) &cpuAarray, batchCount*sizeof(float*));
    magma_malloc_cpu((void**) &cpuTarray, batchCount*sizeof(float*));

    /* check allocation */
    if ( dW0_displ == NULL || dW1_displ == NULL || dW2_displ == NULL || 
         dW3_displ == NULL || dW4_displ == NULL || dW5_displ == NULL || 
         dR_displ  == NULL || dT_displ  == NULL || dwork     == NULL ||
         cpuAarray == NULL || cpuTarray == NULL ) {
        magma_free(dW0_displ);
        magma_free(dW1_displ);
        magma_free(dW2_displ);
        magma_free(dW3_displ);
        magma_free(dW4_displ);
        magma_free(dW5_displ);
        magma_free(dR_displ);
        magma_free(dT_displ);
        magma_free(dwork);
        magma_free_cpu(cpuAarray);
        magma_free_cpu(cpuTarray);
        magma_int_t info = MAGMA_ERR_DEVICE_ALLOC;
        magma_xerbla( __func__, -(info) );
        return info;
    }

    magma_sdisplace_pointers(dR_displ, dR_array, lddr, 0, 0, batchCount, queue); 
    magma_sdisplace_pointers(dT_displ, dT_array, lddt, 0, 0, batchCount, queue); 
    // set dR and dT to zero. if provide_RT == 0 only a tile of size nbxnb is used and overwritten at each step
    magmablas_slaset_batched( MagmaFull, lddr, (provide_RT > 0 ? n:min(min_mn,nb)), MAGMA_S_ZERO, MAGMA_S_ZERO, dR_displ, lddr, batchCount, queue ); 
    magmablas_slaset_batched( MagmaFull, lddt, (provide_RT > 0 ? n:min(min_mn,nb)), MAGMA_S_ZERO, MAGMA_S_ZERO, dT_displ, lddt, batchCount, queue );

    // Pool of queues used by the per-matrix (non-batched) trailing update.
    magma_int_t streamid;
    const magma_int_t nbstreams=10;
    magma_queue_t queues[nbstreams];
    magma_device_t cdev;
    magma_getdevice( &cdev );   // current device is the same for every queue
    for (i=0; i < nbstreams; i++) {
        magma_queue_create( cdev, &queues[i] );
    }
    // Host copies of the pointer arrays: the per-matrix update path needs the
    // individual matrix addresses on the CPU.
    magma_getvector( batchCount, sizeof(float*), dA_array, 1, cpuAarray, 1, queue);
    magma_getvector( batchCount, sizeof(float*), dT_array, 1, cpuTarray, 1, queue);


    for (i=0; i < min_mn; i += nb)
    {
        ib = min(nb, min_mn-i);  
        //===============================================
        // panel factorization
        //===============================================

        magma_sdisplace_pointers(dW0_displ, dA_array, ldda, i, i, batchCount, queue); 
        magma_sdisplace_pointers(dW2_displ, dtau_array, 1, i, 0, batchCount, queue);
        if ( provide_RT > 0 )
        {
            offset_RT = i;
            magma_sdisplace_pointers(dR_displ, dR_array, lddr, (provide_RT == 1 ? offset_RT:0), offset_RT, batchCount, queue); 
            magma_sdisplace_pointers(dT_displ, dT_array, lddt, 0, offset_RT, batchCount, queue); 
        }

        //dwork is used in panel factorization and trailing matrix update
        //dW4_displ, dW5_displ are used as workspace and configured inside
        magma_sgeqrf_panel_batched(m-i, ib, jb, 
                                   dW0_displ, ldda, 
                                   dW2_displ, 
                                   dT_displ, lddt, 
                                   dR_displ, lddr,
                                   dW1_displ,
                                   dW3_displ,
                                   dwork, 
                                   dW4_displ, dW5_displ,
                                   info_array,
                                   batchCount, queue);
           
        //===============================================
        // end of panel
        //===============================================

        //===============================================
        // update trailing matrix
        //===============================================
        if ( (n-ib-i) > 0)
        {
            //dwork is used in panel factorization and trailing matrix update
            //reset dW4_displ
            ldw = nb;
            magma_sset_pointer( dW4_displ, dwork, 1, 0, 0,  ldw*n, batchCount, queue );
            // second half of dwork starts after batchCount blocks of work_stride elements
            magma_sset_pointer( dW5_displ, dwork + work_stride * (size_t)batchCount, 1, 0, 0,  ldw*n, batchCount, queue );    

            // set the diagonal of v as one and the upper triangular part as zero already set inside geqrf_panel
            //magmablas_slaset_batched( MagmaUpper, ib, ib, MAGMA_S_ZERO, MAGMA_S_ONE, dW0_displ, ldda, batchCount, queue ); 
            //magma_sdisplace_pointers(dW2_displ, dtau_array, 1, i, 0, batchCount, queue); 

            // it is faster since it is using BLAS-3 GEMM routines, different from lapack implementation 
            magma_slarft_batched(m-i, ib, 0,
                             dW0_displ, ldda,
                             dW2_displ,
                             dT_displ, lddt, 
                             dW4_displ, nb*lddt,
                             batchCount, queue);

            
            // perform C = (I-V T^H V^H) * C, C is the trailing matrix
            //-------------------------------------------
            //          USE STREAM  GEMM
            //-------------------------------------------
            use_stream = magma_srecommend_cublas_gemm_stream(MagmaNoTrans, MagmaNoTrans, m-i-ib, n-i-ib, ib);
            if ( use_stream )   
            { 
                // everything issued on `queue` so far must be finished before
                // the pool queues start reading V and T
                magma_queue_sync(queue); 
                for (k=0; k < batchCount; k++)
                {
                    streamid = k%nbstreams;                                       
                    // the queue gemm must take cpu pointer 
                    magma_slarfb_gpu_gemm( MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise,
                                m-i, n-i-ib, ib,
                                cpuAarray[k] + i + i * ldda, ldda, 
                                cpuTarray[k] + offset_RT*lddt, lddt,
                                cpuAarray[k] + i + (i+ib) * ldda, ldda,
                                dwork + work_stride * (size_t)k, -1,
                                dwork + work_stride * (size_t)batchCount + work_stride * (size_t)k, -1, queues[streamid] );
                }

                // need to synchronise to be sure that panel does not start before
                // finishing the update at least of the next panel
                // if queue is NULL, no need to sync
                if ( queue != NULL ) {
                    for (magma_int_t s=0; s < nbstreams; s++)
                        magma_queue_sync(queues[s]);
                }
            }
            //-------------------------------------------
            //          USE BATCHED GEMM
            //-------------------------------------------
            else
            {
                //direct trailing matrix in dW1_displ
                magma_sdisplace_pointers(dW1_displ, dA_array, ldda, i, i+ib, batchCount, queue); 

                magma_slarfb_gemm_batched( 
                            MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise, 
                            m-i, n-i-ib, ib,
                            (const float**)dW0_displ, ldda,
                            (const float**)dT_displ, lddt,
                            dW1_displ,  ldda,
                            dW4_displ,  ldw,
                            dW5_displ, ldw,
                            batchCount, queue );
            }
        }// update the trailing matrix 
        //===============================================

        // copy dR back to V after the trailing matrix update, 
        // only when provide_RT=0 otherwise the nbxnb block of V is set to diag=1/0
        // The upper portion of V could be set totally to 0 here
        if ( provide_RT == 0 )
        {
            magmablas_slacpy_batched( MagmaUpper, ib, ib, dR_displ, lddr, dW0_displ, ldda, batchCount, queue );
        }
    }

    magma_queue_sync(queue);
    for (k=0; k < nbstreams; k++) {
        magma_queue_destroy( queues[k] );
    }
    
    magma_free(dW0_displ);
    magma_free(dW1_displ);
    magma_free(dW2_displ);
    magma_free(dW3_displ);
    magma_free(dW4_displ);
    magma_free(dW5_displ);
    magma_free(dR_displ);
    magma_free(dT_displ);
    magma_free(dwork);
    magma_free_cpu(cpuAarray);
    magma_free_cpu(cpuTarray);

    return arginfo;
}
/*
    Panel factorization step used by the batched sgeqrf driver.

    The m-by-n panel of every matrix in dA_array is processed in sub-panels of
    nb columns. For each sub-panel the routine calls magma_sgeqr2_batched,
    copies the factored columns into dR_array, overwrites the upper triangle
    of the sub-panel's diagonal block in dA with 0s and a unit diagonal, and,
    when columns remain inside the panel, forms T with
    magma_slarft_sm32x32_batched and updates those columns through
    magma_slarfb_gemm_batched.

    dW0_displ..dW3_displ are caller-provided pointer arrays that are
    overwritten as scratch; dwork supplies two workspace regions of
    nb*n elements per matrix. Always returns MAGMA_SUCCESS.
*/
extern "C" magma_int_t
magma_sgeqrf_panel_batched(
        magma_int_t m, magma_int_t n, magma_int_t nb,    
        float** dA_array,    magma_int_t ldda,
        float** tau_array, 
        float** dT_array, magma_int_t ldt, 
        float** dR_array, magma_int_t ldr,
        float** dW0_displ, 
        float** dW1_displ,
        float   *dwork,  
        float** dW2_displ, 
        float** dW3_displ,
        magma_int_t *info_array,
        magma_int_t batchCount, magma_queue_t queue)
{
    const magma_int_t work_ld     = nb;           // leading dimension of both workspaces
    const magma_int_t work_stride = work_ld * n;  // workspace elements per matrix
    const magma_int_t kmin        = min(m, n);    // number of columns to factor

    for (magma_int_t col = 0; col < kmin; col += nb)
    {
        const magma_int_t width     = min(nb, kmin - col);
        const magma_int_t remaining = n - col - width;   // panel columns right of this sub-panel

        // point the scratch arrays at A(col,col), tau(col) and R(col,col)
        magma_sdisplace_pointers(dW0_displ, dA_array,  ldda, col, col, batchCount, queue);
        magma_sdisplace_pointers(dW2_displ, tau_array, 1,    col, 0,   batchCount, queue);
        magma_sdisplace_pointers(dW3_displ, dR_array,  ldr,  col, col, batchCount, queue);

        // factor the current sub-panel
        magma_sgeqr2_batched(m - col, width,
                             dW0_displ, ldda,
                             dW2_displ,
                             info_array,
                             batchCount,
                             queue);

        // copy the full kmin-by-width column strip from dA into dR; the part
        // below the diagonal (which belongs to V) can be zeroed later if needed
        magma_sdisplace_pointers(dW0_displ, dA_array, ldda, 0, col, batchCount, queue);
        magma_sdisplace_pointers(dW3_displ, dR_array, ldr,  0, col, batchCount, queue);
        magmablas_slacpy_batched(MagmaFull, kmin, width,
                                 dW0_displ, ldda,
                                 dW3_displ, ldr,
                                 batchCount, queue);

        // overwrite the upper width-by-width triangle of V at dA(col,col) with
        // zeros and a unit diagonal; the rectangle above it is untouched here
        // but has already been saved in dR
        magma_sdisplace_pointers(dW0_displ, dA_array, ldda, col, col, batchCount, queue);
        magmablas_slaset_batched(MagmaUpper, width, width,
                                 MAGMA_S_ZERO, MAGMA_S_ONE,
                                 dW0_displ, ldda,
                                 batchCount, queue);

        // nothing to the right inside the panel: no in-panel update required
        if (remaining <= 0)
            continue;

        // build T for this sub-panel
        magma_slarft_sm32x32_batched(m - col, width,
                                     dW0_displ, ldda,
                                     dW2_displ,
                                     dT_array, ldt,
                                     batchCount, queue);

        // dW1 -> columns right of the sub-panel; dW2/dW3 -> the two halves of dwork
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, col, col + width, batchCount, queue);
        magma_sset_pointer(dW2_displ, dwork,                            1, 0, 0, work_stride, batchCount, queue);
        magma_sset_pointer(dW3_displ, dwork + work_stride * batchCount, 1, 0, 0, work_stride, batchCount, queue);

        // apply the block reflector to the remaining columns of the panel
        magma_slarfb_gemm_batched(MagmaLeft, MagmaConjTrans, MagmaForward, MagmaColumnwise,
                                  m - col, remaining, width,
                                  (const float**)dW0_displ, ldda,
                                  (const float**)dT_array,  ldt,
                                  dW1_displ, ldda,
                                  dW2_displ, work_ld,
                                  dW3_displ, work_ld,
                                  batchCount, queue);
    }

    // when the panel is wider than tall, the columns past kmin still have to
    // be copied from dA into dR
    if (m < n)
    {
        magma_sdisplace_pointers(dW0_displ, dA_array, ldda, 0, kmin, batchCount, queue);
        magma_sdisplace_pointers(dW3_displ, dR_array, ldr,  0, kmin, batchCount, queue);
        magmablas_slacpy_batched(MagmaFull, kmin, n - kmin,
                                 dW0_displ, ldda,
                                 dW3_displ, ldr,
                                 batchCount, queue);
    }

    // for consistency, set the entire upper part of V to 0s with a unit
    // diagonal here so the calling sgeqrf_batched does not have to
    magma_sdisplace_pointers(dW0_displ, dA_array, ldda, 0, 0, batchCount, queue);
    magmablas_slaset_batched(MagmaUpper, kmin, n,
                             MAGMA_S_ZERO, MAGMA_S_ONE,
                             dW0_displ, ldda,
                             batchCount, queue);

    return MAGMA_SUCCESS;
}