Exemple #1
0
/**
    Batched single-precision panel factorization without pivoting.

    The first min(m, n) columns are processed in panels of at most BATF2_NB
    columns. Inside a panel every column is handled by
    magma_sscal_sger_batched (scale the column and apply the rank-1 update to
    the remaining columns of the panel). After each panel, the columns to its
    right are updated with magma_sgetf2trsm_batched followed by a batched GEMM
    on the trailing sub-matrix.

    @param m           Number of rows; must be >= 0.
    @param n           Number of columns; must be >= 0.
    @param dA_array    Array of matrix pointers, one per batch entry
                       (presumably device pointers -- confirm with caller).
    @param ldda        Leading dimension of each matrix; must be >= max(1, m).
    @param dW0_displ   Workspace pointer array handed to
                       magma_sdisplace_pointers; used as the A operand of the
                       trailing GEMM.
    @param dW1_displ   Same, used as the B operand of the trailing GEMM.
    @param dW2_displ   Same, used as the C operand of the trailing GEMM.
    @param info_array  Per-matrix status array, forwarded untouched to
                       magma_sscal_sger_batched.
    @param gbstep      Forwarded untouched to magma_sscal_sger_batched
                       (presumably a global column offset for info reporting
                       -- TODO confirm).
    @param batchCount  Number of matrices in the batch.
    @param queue       Queue on which all helper routines are issued.

    @return 0 on success or on the m == 0 / n == 0 quick return;
            -1, -2 or -4 when m, n or ldda is invalid (magma_xerbla is called
            first); otherwise the first non-zero value returned by
            magma_sscal_sger_batched.
*/
extern "C" magma_int_t
magma_sgetf2_nopiv_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t ldda,
    float **dW0_displ,
    float **dW1_displ,
    float **dW2_displ,
    magma_int_t *info_array,
    magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t queue)
{
    magma_int_t arginfo = 0;
    if (m < 0) {
        arginfo = -1;
    } else if (n < 0) {
        arginfo = -2;
    } else if (ldda < max(1,m)) {
        arginfo = -4;
    }

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    // Quick return if possible
    if (m == 0 || n == 0) {
        return arginfo;
    }

    float c_neg_one = MAGMA_S_NEG_ONE;
    float c_one     = MAGMA_S_ONE;
    magma_int_t nb  = BATF2_NB;

    magma_int_t min_mn = min(m, n);
    magma_int_t gbj, panelj, step, ib;

    for (panelj = 0; panelj < min_mn; panelj += nb) {
        ib = min(nb, min_mn-panelj);

        // Factor the current panel one column at a time.
        // Invariant: gbj < min_mn <= m, so at least one row (m-gbj >= 1) is
        // always left; no extra "rows remaining" guard is needed here.
        for (step = 0; step < ib; step++) {
            gbj = panelj+step;

            // Compute elements J+1:M of J-th column and update the rest of
            // the panel (ib-step columns starting at gbj).
            arginfo = magma_sscal_sger_batched( m-gbj, ib-step, gbj, dA_array, ldda, info_array, gbstep, batchCount, queue );
            if (arginfo != 0) {
                return arginfo;
            }
        }

        if ( (n-panelj-ib) > 0 ) {
            // Continue the update of the selected ib rows for columns
            // panelj+ib:n (TRSM).
            magma_sgetf2trsm_batched(ib, n-panelj-ib, dA_array, panelj, ldda, batchCount, queue);

            // Blocked rank-ib update (GEMM) of the remaining panelj+ib:n
            // columns. The displaced pointer arrays are built from dA_array
            // with offsets (ib+panelj, panelj), (panelj, ib+panelj) and
            // (ib+panelj, ib+panelj) respectively.
            magma_sdisplace_pointers(dW0_displ, dA_array, ldda, ib+panelj, panelj, batchCount, queue);
            magma_sdisplace_pointers(dW1_displ, dA_array, ldda, panelj, ib+panelj, batchCount, queue);
            magma_sdisplace_pointers(dW2_displ, dA_array, ldda, ib+panelj, ib+panelj, batchCount, queue);

            magma_sgemm_batched( MagmaNoTrans, MagmaNoTrans, m-(panelj+ib), n-(panelj+ib), ib,
                                 c_neg_one, dW0_displ, ldda,
                                            dW1_displ, ldda,
                                 c_one,     dW2_displ, ldda,
                                 batchCount, queue );
        }
    }

    return 0;
}
Exemple #2
0
/**
    Batched single-precision panel factorization with row interchanges.

    Columns 0 .. min(m, n)-1 are processed in panels of at most BATF2_NB
    columns. For every column of a panel one of two code paths is taken,
    chosen by the number of rows from the panel start to the bottom
    (m - panel start):

      - more than MAX_NTHREADS rows: pivot search (magma_isamax_batched),
        full-row swap (magma_sswap_batched), then column scaling plus rank-1
        update inside the panel (magma_sscal_sger_batched);
      - otherwise: magma_scomputecolumn_batched followed by the full-row swap.

    After each panel the columns to its right are updated with
    magma_sgetf2trsm_batched and a batched GEMM on the trailing sub-matrix.

    @param m           Number of rows; must be >= 0.
    @param n           Number of columns; must be >= 0.
    @param dA_array    Array of matrix pointers, one per batch entry
                       (presumably device pointers -- confirm with caller).
    @param lda         Leading dimension of each matrix; must be >= max(1, m).
    @param dW0_displ   Workspace pointer array handed to
                       magma_sdisplace_pointers; A operand of the trailing GEMM.
    @param dW1_displ   Same, B operand of the trailing GEMM.
    @param dW2_displ   Same, C operand of the trailing GEMM.
    @param ipiv_array  Per-matrix pivot arrays, forwarded to the pivot search,
                       column computation and swap helpers.
    @param info_array  Per-matrix status array, forwarded to the helpers.
    @param gbstep      Forwarded untouched to the helpers (presumably a global
                       column offset -- TODO confirm).
    @param batchCount  Number of matrices in the batch.
    @param myhandle    cuBLAS handle; only referenced by the disabled
                       cublasSgemmBatched variant of the trailing update.
    @param queue       Queue on which all helper routines are issued.

    @return 0 on success or on the m == 0 / n == 0 quick return;
            -1, -2 or -4 when m, n or lda is invalid (magma_xerbla is called
            first); otherwise the first non-zero value returned by one of the
            per-column helpers.
*/
extern "C" magma_int_t
magma_sgetf2_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t lda,
    float **dW0_displ,
    float **dW1_displ,
    float **dW2_displ,
    magma_int_t **ipiv_array,
    magma_int_t *info_array,
    magma_int_t gbstep,
    magma_int_t batchCount,
    cublasHandle_t myhandle, magma_queue_t queue)
{
    // ---- argument checks -------------------------------------------------
    magma_int_t status = 0;
    if (m < 0) {
        status = -1;
    }
    else if (n < 0) {
        status = -2;
    }
    else if (lda < max(1,m)) {
        status = -4;
    }

    if (status != 0) {
        magma_xerbla( __func__, -(status) );
        return status;
    }

    // Nothing to factor for an empty matrix.
    if (m == 0 || n == 0) {
        return status;
    }

    float alpha = MAGMA_S_NEG_ONE;
    float beta  = MAGMA_S_ONE;
    const magma_int_t block   = BATF2_NB;
    const magma_int_t n_pivot = min(m, n);

    for (magma_int_t first_col = 0; first_col < n_pivot; first_col += block) {
        const magma_int_t width     = min(block, n_pivot - first_col);
        const magma_int_t rows_left = m - first_col;

        // ---- factor the panel, one column per iteration ------------------
        for (magma_int_t k = 0; k < width; ++k) {
            const magma_int_t col = first_col + k;

            if (rows_left <= MAX_NTHREADS) {
                // Short panel: a single helper handles the column, then the
                // interchange is applied across the whole row.
                status = magma_scomputecolumn_batched(rows_left, first_col, k, dA_array, lda, ipiv_array, info_array, gbstep, batchCount, queue);
                if (status != 0) {
                    return status;
                }
                status = magma_sswap_batched(n, dA_array, lda, col, ipiv_array, batchCount, queue);
                if (status != 0) {
                    return status;
                }
                continue;
            }

            // Tall panel: locate the pivot in column `col` ...
            status = magma_isamax_batched(m - col, dA_array, 1, col, lda, ipiv_array, info_array, gbstep, batchCount, queue);
            if (status != 0) {
                return status;
            }
            // ... swap the whole row (all n columns) ...
            status = magma_sswap_batched(n, dA_array, lda, col, ipiv_array, batchCount, queue);
            if (status != 0) {
                return status;
            }
            // ... then scale the column below the diagonal and update the
            // remaining columns of the panel.
            if (col < m) {
                status = magma_sscal_sger_batched(m - col, width - k, col, dA_array, lda, info_array, gbstep, batchCount, queue);
                if (status != 0) {
                    return status;
                }
            }
        }

        // ---- update everything to the right of the panel -----------------
        const magma_int_t next_col = first_col + width;
        if ((n - first_col - width) <= 0) {
            continue;   // the panel reached the last column
        }

        // Triangular solve on the `width` pivot rows for columns next_col:n.
        magma_sgetf2trsm_batched(width, n - first_col - width, dA_array, first_col, lda, batchCount, queue);

        // Pointer arrays for the trailing GEMM, displaced from dA_array by
        // (next_col, first_col), (first_col, next_col) and
        // (next_col, next_col) respectively.
        magma_sdisplace_pointers(dW0_displ, dA_array, lda, next_col,  first_col, batchCount, queue);
        magma_sdisplace_pointers(dW1_displ, dA_array, lda, first_col, next_col,  batchCount, queue);
        magma_sdisplace_pointers(dW2_displ, dA_array, lda, next_col,  next_col,  batchCount, queue);

#if 1
        magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans, m - next_col, n - next_col, width,
                                 alpha, dW0_displ, lda,
                                        dW1_displ, lda,
                                 beta,  dW2_displ, lda,
                                 batchCount, queue);
#else
        cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N, m - next_col, n - next_col, width,
                           &alpha, (const float**) dW0_displ, lda,
                                   (const float**) dW1_displ, lda,
                           &beta,  dW2_displ, lda, batchCount );
#endif
    }

    return 0;
}