Example #1
0
//===================================================================================================================
//===================================================================================================================
//===================================================================================================================
/**
    Builds the k-by-k matrix T for every problem in a batch, in two phases
    labelled GEMV and TRMV by the section banners below.

    Active path (use_gemm_larft_sm32 is defined a few lines down):
      1. cublasSgemmBatched with (CUBLAS_OP_C, CUBLAS_OP_N) and sizes (k, k, n)
         multiplies v_array by itself and stores the k-by-k result in T_array.
      2. magmablas_slaset_batched is called on the MagmaLower part of T with
         both fill values set to MAGMA_S_ZERO.
      3. magmablas_slarft_strmv_sm32x32_batched runs with T_array passed as
         both matrix arguments, together with tau_array.

    @param n          inner dimension of the GEMM (and row count handed to the
                      row-/column-wise GEMV kernels on the disabled path).
    @param k          order of T; the function returns immediately when k <= 0.
    @param v_array    per-batch pointers to V, leading dimension ldv.
    @param ldv        leading dimension of each V.
    @param tau_array  per-batch pointers to tau; on the active path it is only
                      consumed by the TRMV kernel.
    @param T_array    per-batch pointers to T, leading dimension ldt (output).
    @param ldt        leading dimension of each T.
    @param batchCount number of problems in the batch.
    @param myhandle   cuBLAS handle used for the batched GEMM.
    @param queue      MAGMA queue forwarded to the MAGMA kernels.

    NOTE(review): `one` and `zero` are not declared inside this function;
    presumably file-scope float scalars holding 1 and 0 - confirm.
*/
extern "C" void
magma_slarft_sm32x32_batched(magma_int_t n, magma_int_t k, float **v_array, magma_int_t ldv,
                    float **tau_array, float **T_array, magma_int_t ldt, magma_int_t batchCount, cublasHandle_t myhandle, magma_queue_t queue)
{

    // Nothing to build for an empty T.
    if( k <= 0) return;

     //==================================
     //          GEMV
     //==================================
// These switches are preprocessor directives, so although they sit inside the
// function body they remain defined for the rest of the translation unit.
#define USE_GEMV2
#define use_gemm_larft_sm32

#if defined(use_gemm_larft_sm32)
    // One batched GEMM replaces the column-by-column GEMV loop of the #else branch.
    //magmablas_sgemm_batched( MagmaConjTrans, MagmaNoTrans, k, k, n, MAGMA_S_ONE, v_array, ldv, v_array, ldv, MAGMA_S_ZERO, T_array, ldt, batchCount, queue);
    cublasSgemmBatched(myhandle, CUBLAS_OP_C, CUBLAS_OP_N, k, k, n,
                             &one, (const float**) v_array, ldv,
                                    (const float**) v_array, ldv,
                             &zero,  T_array, ldt, batchCount);

    // The GEMM fills the whole k-by-k block; reset the MagmaLower part with zeros.
    magmablas_slaset_batched(MagmaLower, k, k, MAGMA_S_ZERO, MAGMA_S_ZERO, T_array, ldt, batchCount, queue);
#else
    // Disabled while use_gemm_larft_sm32 is defined: one GEMV launch per column of T.
    #if 1
    for(int i=0; i<k; i++)
    {
        //W(1:i-1) := - tau(i) * V(i:n,1:i-1)' * V(i:n,i)
        //T( i, i ) = tau( i ) 
        //custom implementation.
        #ifdef USE_GEMV2
        magmablas_slarft_gemvrowwise_batched( n-i, i, 
                            tau_array,
                            v_array, ldv, 
                            T_array, ldt,
                            batchCount, queue);
                            
        #else       
        magmablas_slarft_gemvcolwise_batched( n-i, i, v_array, ldv, T_array, ldt, tau_array, batchCount, queue);
        #endif
    }
    #else
        //seems to be very slow when k=32 while the one by one loop above is faster
        slarft_gemv_loop_inside_kernel_batched(n, k, tau_array, v_array, ldv, T_array, ldt, batchCount, queue); 
    #endif
#endif
     //==================================
     //          TRMV
     //==================================
     //T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1) i=[1:k]
     // T_array is passed as both matrix arguments, so the kernel reads and writes the same storage.
     magmablas_slarft_strmv_sm32x32_batched(k, k, tau_array, T_array, ldt, T_array, ldt, batchCount, queue);
}
// Batched single-precision GEMM wrapper around cublasSgemmBatched.
//
// The CBLAS transpose flags are mapped to cuBLAS operations, and the operands
// are handed to cuBLAS in swapped order (B before A, N before M) because
// cuBLAS follows Fortran (column-major) ordering. C is passed with leading
// dimension N. Failures are reported through CUBLAS_CHECK.
//
// NOTE(review): an explicit specialization normally needs a preceding
// `template <>` line; none is visible in this excerpt - confirm it exists
// in the real file.
void caffe_gpu_gemm_batched<float>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const float alpha, const float** A, const float** B, const float beta,
    float** C,
    int batch_count) {
  // Defaults describe the transposed case; the plain case overrides them.
  int lda = M;
  cublasOperation_t cuTransA = CUBLAS_OP_T;
  if (TransA == CblasNoTrans) {
    lda = K;
    cuTransA = CUBLAS_OP_N;
  }

  int ldb = K;
  cublasOperation_t cuTransB = CUBLAS_OP_T;
  if (TransB == CblasNoTrans) {
    ldb = N;
    cuTransB = CUBLAS_OP_N;
  }

  CUBLAS_CHECK(cublasSgemmBatched(Caffe::get_current_cublas_handle(), cuTransB, cuTransA,
      N, M, K, &alpha, B, ldb, A, lda, &beta, C, N,
      batch_count));
}
Example #3
0
/**
    Recursive panel LU factorization over a batch of m-by-n panels.

    When n <= min_recpnb the panel is handed directly to magma_sgetf2_batched.
    Otherwise the columns are split into a left part of width n1 = n/2 and a
    right part of width n2 = n - n1, and the routine:
      1. recurses on the left part (m rows, n1 columns);
      2. builds pivot info (setup_pivinfo_batched_q), applies the row swaps to
         the right part via magma_slaswp_rowparallel_batched (output in
         dX_array), and runs magmablas_strsm_outofplace_batched;
      3. updates the trailing block with magmablas_sgemm_batched
         (alpha = MAGMA_S_NEG_ONE, beta = MAGMA_S_ONE);
      4. recurses on the trailing block (m - n1 rows, n2 columns) with
         gbstep advanced by n1;
      5. rebuilds pivot info, calls adjust_ipiv_batched_q on the second half
         of the pivots, and applies the swaps to the rows below the first
         n1 of the left part.

    @param m, n            panel dimensions; returns 0 at once if either is 0.
    @param min_recpnb      column width at or below which recursion stops.
    @param dA_array, ldda  per-batch matrix pointers and leading dimension.
    @param dipiv_array     per-batch pivot arrays.
    @param dpivinfo_array  per-batch pivot-info workspace.
    @param dX_array, dX_length        workspace forwarded to laswp / trsm.
    @param dinvA_array, dinvA_length  workspace forwarded to trsm.
    @param dW1_displ..dW5_displ       scratch pointer arrays, overwritten here.
    @param info_array, gbstep         forwarded to the panel kernels.
    @param batchCount      number of matrices.
    @param stream          queue used for the pivot helper kernels.
    @param myhandle        cuBLAS handle forwarded to magma_sgetf2_batched.

    @return 0 on success, or the nonzero status of magma_malloc when a scratch
            pointer array cannot be allocated. Statuses of the nested calls
            are not propagated (unchanged from the previous behaviour).
*/
extern "C" magma_int_t
magma_sgetrf_recpanel_batched_q(
    magma_int_t m, magma_int_t n, magma_int_t min_recpnb,    
    float** dA_array,    magma_int_t ldda,
    magma_int_t** dipiv_array, magma_int_t** dpivinfo_array,
    float** dX_array,    magma_int_t dX_length,
    float** dinvA_array, magma_int_t dinvA_length,
    float** dW1_displ, float** dW2_displ,  
    float** dW3_displ, float** dW4_displ,
    float** dW5_displ,
    magma_int_t *info_array, magma_int_t gbstep,  
    magma_int_t batchCount, magma_queue_t stream, cublasHandle_t myhandle)
{

    //magma_int_t DEBUG = 3;
    // Quick return if possible
    if (m ==0 || n == 0) {
        return 0;
    }


    // Scratch pointer arrays holding displaced views of dA_array / dipiv_array.
    // The allocation status used to be ignored, which let a failed allocation
    // flow into the displace/kernel calls below. A nonzero status is treated
    // as failure (assumes MAGMA's success code is 0 - TODO confirm).
    float **dA_displ  = NULL;
    magma_int_t **dipiv_displ = NULL;

    magma_int_t alloc_info = magma_malloc((void**)&dA_displ,   batchCount * sizeof(*dA_displ));
    if (alloc_info != 0) {
        return alloc_info;
    }
    alloc_info = magma_malloc((void**)&dipiv_displ, batchCount * sizeof(*dipiv_displ));
    if (alloc_info != 0) {
        magma_free(dA_displ);
        return alloc_info;
    }
    
    magma_int_t panel_nb = n;
    if(panel_nb <= min_recpnb){
        //if(DEBUG>0)printf("calling bottom panel recursive with m=%d nb=%d\n",m,n);
        //  panel factorization
        //magma_sdisplace_pointers(dA_displ, dA_array, ldda, 0, 0, batchCount);
        magma_sgetf2_batched(
                           m, panel_nb,
                           dA_array, ldda,
                           dW1_displ, dW2_displ, dW3_displ,
                           dipiv_array, info_array, gbstep, batchCount, myhandle);
    }
    else{
        // split A over two [A A2]
        // panel on A1, update on A2 then panel on A1    
        magma_int_t n1 = n/2;
        magma_int_t n2 = n-n1;
        magma_int_t m1 = m;
        magma_int_t m2 = m-n1;
        magma_int_t p1 = 0;
        magma_int_t p2 = n1;
        // panel on A1
        //if(DEBUG>0)printf("calling recursive panel on A1 with m=%d nb=%d min_recpnb %d\n",m1,n1,min_recpnb);
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount); 
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p1, 0, batchCount);
        magma_sgetrf_recpanel_batched_q(
                           m1, n1, min_recpnb,
                           dA_displ, ldda,
                           dipiv_displ, dpivinfo_array,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, dW5_displ,
                           info_array, gbstep, batchCount, stream, myhandle);

        // update A2
        //if(DEBUG>0)printf("calling TRSM  with             m=%d n=%d \n",m1,n2);
        
        // setup pivinfo 
        setup_pivinfo_batched_q(dpivinfo_array, dipiv_displ, m1, n1, stream, batchCount);
        magma_sdisplace_pointers(dW5_displ, dA_array, ldda, p1, p2, batchCount); 
        magma_slaswp_rowparallel_batched( n2, dW5_displ, ldda,
                           dX_array, n1,
                           0, n1,
                           dpivinfo_array, batchCount);
        magmablas_strsm_outofplace_batched(MagmaLeft, MagmaLower, MagmaNoTrans, MagmaUnit, 1,
                              n1, n2,
                              MAGMA_S_ONE,
                              dA_displ,    ldda, // dA
                              dX_array,  n1, // dB
                              dW5_displ,   ldda, // dX
                              dinvA_array, dinvA_length,
                              dW1_displ,   dW2_displ, 
                              dW3_displ,   dW4_displ,
                              0, batchCount);

        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount); 
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p2, p2, batchCount); 

        //if(DEBUG>0)printf("calling update A2(%d,%d) -= A(%d,%d)*A(%d,%d)  with             m=%d n=%d k=%d ldda %d\n",p2,p2,p2,0,p1,p2,m2,n2,n1,ldda);

#if 0
        float neg_one = MAGMA_S_NEG_ONE;
        float one  = MAGMA_S_ONE;
        cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N, m2, n2, n1,
                                     &neg_one, (const float**) dW1_displ, ldda,
                                               (const float**) dW5_displ, ldda,
                                     &one,  dA_displ, ldda, batchCount );


#else

        magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans, m2, n2, n1, 
                              MAGMA_S_NEG_ONE, dW1_displ, ldda, 
                              dW5_displ, ldda, 
                              MAGMA_S_ONE,  dA_displ, ldda, 
                              batchCount);
#endif

        // panel on A2
        //if(DEBUG>0)printf("calling recursive panel on A2 with m=%d nb=%d min_recpnb %d\n",m2,n2,min_recpnb);
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p2, 0, batchCount);
        magma_sgetrf_recpanel_batched_q(
                           m2, n2, min_recpnb,
                           dA_displ, ldda,
                           dipiv_displ, dpivinfo_array,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, dW5_displ,
                           info_array, gbstep+p2, batchCount, stream, myhandle);

        // setup pivinfo
        setup_pivinfo_batched_q(dpivinfo_array, dipiv_displ, m2, n2, stream, batchCount);
        // Launch on the caller's queue, like setup_pivinfo_batched_q just above:
        // both kernels touch dipiv_displ, and the name previously used here
        // (magma_stream) is not a parameter or local of this function.
        adjust_ipiv_batched_q(dipiv_displ, n2, n1, stream, batchCount);
        
        // dW1_displ is re-derived here because it was also handed to the
        // recursive call above as scratch space.
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount);
        magma_slaswp_rowparallel_batched( n1, dW1_displ, ldda,
                           dW1_displ, ldda,
                           n1, n,
                           dpivinfo_array, batchCount);

        
    }

    magma_free(dA_displ);
    magma_free(dipiv_displ);
    return 0;
}
Example #4
0
/**
    Forms the k-by-k matrix T for every problem in a batch, processing T in
    column blocks of width nb.

    Phase 1: one batched GEMM (MAGMA or cuBLAS, selected by RFT_MAG_GEM)
    multiplies v_array by itself with (ConjTrans, NoTrans) and sizes (k, k, n),
    writing into the step buffer; magmablas_slaset_batched is then called on
    its MagmaLower part with zero fill values.

    Phase 2: for each column block starting at column j,
      - if earlier columns exist (j > 0), a batched GEMM multiplies the leading
        j-by-j part of T by the block's rows 0..j-1 in the step buffer, and
        magmablas_slarft_recstrmv_sm32x32_batched is applied in row chunks of
        at most nb rows;
      - if stair_T == 0, magmablas_slarft_strmv_sm32x32_batched processes the
        block's diagonal part.

    The step buffer is work_array when k > nb and T_array otherwise.

    @param n           inner dimension of the phase-1 GEMM.
    @param k           order of T; returns 0 immediately when k <= 0.
    @param stair_T     0 to process all of T; when > 0 it is used as the block
                       width nb and the diagonal-block step is skipped.
                       Returns 0 immediately when 0 < k <= stair_T.
    @param v_array, ldv   per-batch V pointers and leading dimension.
    @param tau_array      per-batch tau pointers.
    @param T_array, ldt   per-batch T pointers and leading dimension.
    @param work_array     per-batch workspace pointers.
    @param lwork          workspace size; must be at least k*ldt.
    @param batchCount     number of problems.
    @param myhandle       cuBLAS handle (used when RFT_MAG_GEM is not defined).
    @param queue          MAGMA queue forwarded to the MAGMA kernels.

    @return 0 on success; -10 if lwork < k*ldt; -3 if stair_T exceeds
            max_shared_bsiz; otherwise the nonzero status of a failed
            magma_malloc.

    NOTE(review): `one` and `zero` are not declared in this function;
    presumably file-scope float scalars - confirm.
*/
extern "C" magma_int_t
magma_slarft_batched(magma_int_t n, magma_int_t k, magma_int_t stair_T, 
                float **v_array, magma_int_t ldv,
                float **tau_array, float **T_array, magma_int_t ldt, 
                float **work_array, magma_int_t lwork, magma_int_t batchCount, cublasHandle_t myhandle, 
                magma_queue_t queue)
{
    if( k <= 0) return 0;
    if( stair_T > 0 && k <= stair_T) return 0;

    magma_int_t maxnb = max_shared_bsiz;

    // magma_xerbla receives the positive argument index, matching the
    // magma_xerbla( __func__, -(arginfo) ) convention with a negative arginfo;
    // the negated literals used previously passed a negative index.
    if( lwork < k*ldt) 
    {
        magma_xerbla( __func__, 10 );
        return -10;
    }

    if( stair_T > 0 && stair_T > maxnb)
    { 
        magma_xerbla( __func__, 3 );
        return -3;
    }
    magma_int_t nb = stair_T == 0 ? min(k,maxnb) : stair_T;

    magma_int_t i, j, prev_n, mycol, rows;

    float **dW1_displ  = NULL;
    float **dW2_displ  = NULL;
    float **dW3_displ  = NULL;
    float **dTstep_array  = NULL;

    // Allocate the scratch pointer arrays one at a time so that a failure
    // releases exactly what was obtained so far. A nonzero status is treated
    // as failure (assumes MAGMA's success code is 0 - TODO confirm).
    magma_int_t alloc_info = magma_malloc((void**)&dW1_displ,  batchCount * sizeof(*dW1_displ));
    if (alloc_info != 0) {
        return alloc_info;
    }
    alloc_info = magma_malloc((void**)&dW2_displ,  batchCount * sizeof(*dW2_displ));
    if (alloc_info != 0) {
        magma_free(dW1_displ);
        return alloc_info;
    }
    alloc_info = magma_malloc((void**)&dW3_displ,  batchCount * sizeof(*dW3_displ));
    if (alloc_info != 0) {
        magma_free(dW1_displ);
        magma_free(dW2_displ);
        return alloc_info;
    }
    alloc_info = magma_malloc((void**)&dTstep_array,  batchCount * sizeof(*dTstep_array));
    if (alloc_info != 0) {
        magma_free(dW1_displ);
        magma_free(dW2_displ);
        magma_free(dW3_displ);
        return alloc_info;
    }

    //float *Tstep =  k > nb ? work : T;
    if(k > nb)
    {
        magma_sdisplace_pointers(dTstep_array, work_array, lwork, 0, 0, batchCount, queue);
    }
    else
    {
        magma_sdisplace_pointers(dTstep_array, T_array, ldt, 0, 0, batchCount, queue);
    }

    //magma_int_t ldtstep = k > nb ? k : ldt;
    magma_int_t ldtstep = ldt; // to be removed
    // stair_T = 0 meaning all T
    // stair_T > 0 meaning the triangular portion of T has been computed. 
    //                    the value of stair_T is the nb of these triangulars
   

    //GEMV compute the whole triangular upper portion of T (phase 1)
    // TODO add cublas to check perf

#ifdef RFT_MAG_GEM
    magmablas_sgemm_batched( MagmaConjTrans, MagmaNoTrans, 
            k, k, n, 
            one,  v_array, ldv, 
                  v_array, ldv, 
            zero, dTstep_array, ldtstep, 
            batchCount, queue);
#else
    cublasSgemmBatched(myhandle, CUBLAS_OP_C, CUBLAS_OP_N, k, k, n,
                             &one, (const float**) v_array, ldv,
                                    (const float**) v_array, ldv,
                             &zero, dTstep_array, ldtstep, batchCount);
#endif

    magmablas_slaset_batched(MagmaLower, k, k, MAGMA_S_ZERO, MAGMA_S_ZERO, dTstep_array, ldtstep, batchCount, queue);
    // no need for it as T is expected to be lower zero
    //if(k > nb) magmablas_slaset_batched(MagmaLower, k, k, MAGMA_S_ZERO, MAGMA_S_ZERO, dTstep_array, ldtstep, batchCount);
    

    //TRMV
    //T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1) i=[1:k]
    // TRMV is split over blocks of columns of size nb 
    // the update should be done from top to bottom so:
    // 1- a gemm using the previously computed columns
    //    of T to update the rectangular upper portion above 
    //    the triangle of my columns 
    // 2- the columns need to be updated by a serial 
    //    loop of gemv over itself. since we limit the
    //    shared memory to nb, these nb columns 
    //    are split vertically by chunks of nb rows

    for(j=0; j<k; j+=nb)
    {
        prev_n =  j;
        mycol  =  min(nb, k-j);
        // note that myrow = prev_n + mycol;
        if(prev_n>0 && mycol>0){

            // rectangular portion of size prev_n x mycol located at T(0,j)
            magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, 0, j, batchCount, queue);
            magma_sdisplace_pointers(dW2_displ, T_array,     ldt, 0, j, batchCount, queue);
#ifdef RFT_MAG_GEM
            magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans, 
                    prev_n, mycol, prev_n, 
                    one,  T_array, ldt, 
                          dW1_displ, ldtstep, 
                    zero, dW2_displ, ldt, 
                    batchCount, queue );
#else
            cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N, 
                    prev_n, mycol, prev_n,
                    &one, (const float**) T_array, ldt,
                          (const float**) dW1_displ, ldtstep,
                    &zero, dW2_displ, ldt, batchCount);
#endif

            // update my rectangular portion (prev_n,mycol) using sequence of gemv 
            magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
            magma_sdisplace_pointers(dW3_displ, tau_array,  1, j, 0, batchCount, queue);

            for(i=0; i<prev_n; i+=nb)
            {
                rows = min(nb,prev_n-i);

                // chunk of size rows x mycol located at T(i,j)
                if(rows>0 && mycol>0)
                {
                    magma_sdisplace_pointers(dW2_displ, T_array,     ldt, i, j, batchCount, queue);
                    magmablas_slarft_recstrmv_sm32x32_batched(rows, mycol, dW3_displ, dW2_displ, ldt, dW1_displ, ldtstep, batchCount, queue);
                }
            }
        }

        // the upper rectangular portion is updated, now if needed update the triangular portion
        if(stair_T == 0){
            // triangular portion of size mycol x mycol located at T(j,j)
            if(mycol>0)
            {
                magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
                magma_sdisplace_pointers(dW3_displ, tau_array,  1, j, 0, batchCount, queue);
                magma_sdisplace_pointers(dW2_displ, T_array,     ldt, j, j, batchCount, queue);
                magmablas_slarft_strmv_sm32x32_batched(mycol, mycol, dW3_displ, dW1_displ, ldtstep, dW2_displ, ldt, batchCount, queue);

            }
        }
    }// end of j

    magma_free(dW1_displ);
    magma_free(dW2_displ);
    magma_free(dW3_displ);
    magma_free(dTstep_array);

    return 0;
}
/*
 * Batched single-precision GEMM over arrays of gpudata buffers.
 *
 * For each i in [0, batchCount) the product is computed on the buffers
 * A[i], B[i], C[i] at element offsets offA[i], offB[i], offC[i].
 *
 * When order == cb_c the A/B operands are exchanged (together with M/N,
 * lda/ldb, the transpose flags and the offset arrays) before calling cuBLAS.
 *
 * Dispatch: if M*N*K exceeds 650^3, one cublasSgemm call is issued per batch
 * item; otherwise the per-item device pointers are gathered on the host,
 * copied to a temporary device array, and cublasSgemmBatched is called once.
 *
 * Returns GA_NO_ERROR on success (including batchCount == 0),
 * GA_DEVSUP_ERROR when cuBLAS reports CUBLAS_STATUS_ARCH_MISMATCH, and
 * GA_BLAS_ERROR for any other failure, including failure to allocate or
 * fill the temporary device pointer array.
 *
 * NOTE(review): the host staging array comes from alloca() and grows with
 * batchCount, so very large batches can exhaust the stack.
 */
static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
                      size_t M, size_t N, size_t K, float alpha,
                      gpudata **A, size_t *offA, size_t lda,
                      gpudata **B, size_t *offB, size_t ldb,
                      float beta, gpudata **C, size_t *offC, size_t ldc,
                      size_t batchCount) {
  cuda_context *ctx;
  size_t *lt, t;
  gpudata **T;
  size_t i;
  cb_transpose transT;
  cublasStatus_t err;

  if (batchCount == 0) return GA_NO_ERROR;

  ASSERT_BUF(A[0]);
  ctx = A[0]->ctx;
  cuda_enter(ctx);

  if (order == cb_c) {
    /* swap A and B */
    t = N;
    N = M;
    M = t;
    T = A;
    A = B;
    B = T;
    t = lda;
    lda = ldb;
    ldb = t;
    transT = transA;
    transA = transB;
    transB = transT;
    lt = offA;
    offA = offB;
    offB = lt;
  }

  // use parallel cublasSgemm calls rather than cublasSgemmBatched for large products
  const size_t threshold = 650;
  const int multiple_dispatch = M * N * K > threshold * threshold * threshold;
  if (multiple_dispatch) {
    for (i = 0; i < batchCount; i++) {
      ASSERT_BUF(A[i]);
      ASSERT_BUF(B[i]);
      ASSERT_BUF(C[i]);
      cuda_wait(A[i], CUDA_WAIT_READ);
      cuda_wait(B[i], CUDA_WAIT_READ);
      cuda_wait(C[i], CUDA_WAIT_READ|CUDA_WAIT_WRITE);

      err = cublasSgemm(((blas_handle *)ctx->blas_handle)->h,
                        convT(transA), convT(transB),
                        M, N, K, &alpha,
                        (float*)A[i]->ptr + offA[i], lda,
                        (float*)B[i]->ptr + offB[i], ldb,
                        &beta,
                        (float*)C[i]->ptr + offC[i], ldc);
      if (err != CUBLAS_STATUS_SUCCESS) {
        cuda_exit(ctx);
        if (err == CUBLAS_STATUS_ARCH_MISMATCH)
          return GA_DEVSUP_ERROR;
        return GA_BLAS_ERROR;
      }

      cuda_record(A[i], CUDA_WAIT_READ);
      cuda_record(B[i], CUDA_WAIT_READ);
      cuda_record(C[i], CUDA_WAIT_READ|CUDA_WAIT_WRITE);
    }
  } else {
    /* One staging block laid out as [A pointers | B pointers | C pointers]. */
    const size_t ptrs_bytes = sizeof(float *) * batchCount * 3;
    float **T_l = alloca(ptrs_bytes);
    const float **A_l = (const float **)T_l;
    const float **B_l = (const float **)T_l + batchCount;
    float **C_l = T_l + (batchCount * 2);
    CUdeviceptr Ta, Aa, Ba, Ca;

    for (i = 0; i < batchCount; i++) {
      ASSERT_BUF(A[i]);
      ASSERT_BUF(B[i]);
      ASSERT_BUF(C[i]);
      cuda_wait(A[i], CUDA_WAIT_READ);
      cuda_wait(B[i], CUDA_WAIT_READ);
      cuda_wait(C[i], CUDA_WAIT_READ|CUDA_WAIT_WRITE);
      A_l[i] = ((float *)A[i]->ptr) + offA[i];
      B_l[i] = ((float *)B[i]->ptr) + offB[i];
      C_l[i] = ((float *)C[i]->ptr) + offC[i];
    }

    /*
     * The driver calls below used to be unchecked: a failed allocation or
     * copy would hand cuBLAS an invalid device pointer array.
     * NOTE(review): GA_BLAS_ERROR is reused because it is the failure code
     * already returned on this path; switch to a more specific project error
     * code if one fits better.
     */
    if (cuMemAlloc(&Ta, ptrs_bytes) != CUDA_SUCCESS) {
      cuda_exit(ctx);
      return GA_BLAS_ERROR;
    }
    Aa = Ta;
    Ba = Ta + (batchCount * sizeof(float *));
    Ca = Ta + (batchCount * sizeof(float *) * 2);

    if (cuMemcpyHtoD(Ta, T_l, ptrs_bytes) != CUDA_SUCCESS) {
      cuMemFree(Ta);
      cuda_exit(ctx);
      return GA_BLAS_ERROR;
    }

    err = cublasSgemmBatched(((blas_handle *)ctx->blas_handle)->h,
                             convT(transA), convT(transB),
                             M, N, K, &alpha, (const float **)Aa, lda,
                             (const float **)Ba, ldb, &beta,
                             (float **)Ca, ldc, batchCount);
    cuMemFree(Ta);
    if (err != CUBLAS_STATUS_SUCCESS) {
      cuda_exit(ctx);
      if (err == CUBLAS_STATUS_ARCH_MISMATCH)
        return GA_DEVSUP_ERROR;
      return GA_BLAS_ERROR;
    }

    for (i = 0; i < batchCount; i++) {
      cuda_record(A[i], CUDA_WAIT_READ);
      cuda_record(B[i], CUDA_WAIT_READ);
      cuda_record(C[i], CUDA_WAIT_READ|CUDA_WAIT_WRITE);
    }
  }

  cuda_exit(ctx);
  return GA_NO_ERROR;
}
Example #6
0
/**
    Panel factorization without pivoting over a batch of m-by-n matrices.

    The first min(m, n) columns are walked in panels of at most 32 columns.
    Inside a panel, every column is processed by magma_sscal_sger_batched;
    afterwards, if columns remain to the right of the panel, they are updated
    with magma_sgetf2trsm_batched followed by a batched GEMM
    (alpha = MAGMA_S_NEG_ONE, beta = MAGMA_S_ONE) on displaced views of
    dA_array.

    @param m, n        matrix dimensions.
    @param dA_array    per-batch matrix pointers.
    @param lda         leading dimension; must be >= max(1, m).
    @param dW0_displ, dW1_displ, dW2_displ
                       scratch pointer arrays, overwritten with displaced views.
    @param info_array, gbstep   forwarded to magma_sscal_sger_batched.
    @param batchCount  number of matrices.
    @param myhandle    cuBLAS handle; only referenced by the disabled cuBLAS path.

    @return 0 on success (also for m == 0 or n == 0); -1, -2 or -4 for an
            invalid m, n or lda (reported through magma_xerbla); otherwise the
            first nonzero status returned by magma_sscal_sger_batched.
*/
extern "C" magma_int_t
magma_sgetf2_nopiv_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t lda,
    float **dW0_displ,
    float **dW1_displ,
    float **dW2_displ,
    magma_int_t *info_array,            
    magma_int_t gbstep, 
    magma_int_t batchCount,
    cublasHandle_t myhandle)

{
    // Argument checks: the first violated condition decides the code.
    magma_int_t arginfo = (m < 0)            ? -1
                        : (n < 0)            ? -2
                        : (lda < max(1,m))   ? -4
                        :                       0;

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    // Empty problem: nothing to factor.
    if (m == 0 || n == 0) {
        return arginfo;
    }

    float minus_one = MAGMA_S_NEG_ONE;
    float plus_one  = MAGMA_S_ONE;
    const magma_int_t panel_nb = 32;//BATF2_NB;
    const magma_int_t diag_len = min(m, n);

    for (magma_int_t panel_start = 0; panel_start < diag_len; panel_start += panel_nb)
    {
        const magma_int_t width = min(panel_nb, diag_len-panel_start);

        // Column-by-column work inside the current panel.
        for (magma_int_t k = 0; k < width; ++k)
        {
            const magma_int_t col = panel_start + k;
#if 0
            size_t required_shmem_size = ((m-panel_start)*width)*sizeof(float);
            const bool run_column_step = required_shmem_size >  (MAX_SHARED_ALLOWED*1024);
#else
            const bool run_column_step = (m-panel_start) > 0;
#endif
            if (!run_column_step) {
                // TODO
                continue;
            }

            // Compute elements J+1:M of J-th column.
            if (col >= m) {
                continue;
            }
            arginfo = magma_sscal_sger_batched(m-col, width-k, col, dA_array, lda, info_array, gbstep, batchCount);
            if (arginfo != 0) {
                return arginfo;
            }
        }

        const magma_int_t next_start = width + panel_start;   // first row/column past the panel
        const magma_int_t right_cols = n - panel_start - width;
        if (right_cols <= 0) {
            continue;
        }

        // continue the update of the selected rows for columns next_start:n (TRSM)
        magma_sgetf2trsm_batched(width, right_cols, dA_array, panel_start, lda, batchCount);

        // blocked GER == GEMM on the remaining columns
        magma_sdisplace_pointers(dW0_displ, dA_array, lda, next_start, panel_start, batchCount);
        magma_sdisplace_pointers(dW1_displ, dA_array, lda, panel_start, next_start, batchCount);
        magma_sdisplace_pointers(dW2_displ, dA_array, lda, next_start, next_start, batchCount);

#if 1
        magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans, m-next_start, right_cols, width,
                                  minus_one, dW0_displ, lda,
                                  dW1_displ, lda,
                                  plus_one,  dW2_displ, lda,
                                  batchCount);
#else
        cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N, m-next_start, right_cols, width,
                                 &minus_one, (const float**) dW0_displ, lda,
                                             (const float**) dW1_displ, lda,
                                 &plus_one,  dW2_displ, lda, batchCount );
#endif
    }

    return 0;
}
Example #7
0
/**
    Panel factorization with row interchanges over a batch of m-by-n matrices.

    The first min(m, n) columns are walked in panels of at most BATF2_NB
    columns. For each column of a panel one of two paths is taken:
      - when the panel height (m - panel start) exceeds MAX_NTHREADS:
        magma_isamax_batched, then magma_sswap_batched over all n columns,
        then magma_sscal_sger_batched;
      - otherwise: magma_scomputecolumn_batched followed by
        magma_sswap_batched.
    After a panel, any columns to its right are updated with
    magma_sgetf2trsm_batched and a batched GEMM (alpha = MAGMA_S_NEG_ONE,
    beta = MAGMA_S_ONE) on displaced views of dA_array.

    @param m, n        matrix dimensions.
    @param dA_array    per-batch matrix pointers.
    @param lda         leading dimension; must be >= max(1, m).
    @param dW0_displ, dW1_displ, dW2_displ
                       scratch pointer arrays, overwritten with displaced views.
    @param ipiv_array  per-batch pivot arrays passed to the pivot/swap kernels.
    @param info_array, gbstep   forwarded to the column kernels.
    @param batchCount  number of matrices.
    @param myhandle    cuBLAS handle; only referenced by the disabled cuBLAS path.
    @param queue       MAGMA queue forwarded to every MAGMA call.

    @return 0 on success (also for m == 0 or n == 0); -1, -2 or -4 for an
            invalid m, n or lda (reported through magma_xerbla); otherwise the
            first nonzero status returned by one of the column kernels.
*/
extern "C" magma_int_t
magma_sgetf2_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t lda,
    float **dW0_displ,
    float **dW1_displ,
    float **dW2_displ,
    magma_int_t **ipiv_array,
    magma_int_t *info_array,
    magma_int_t gbstep,          
    magma_int_t batchCount, 
    cublasHandle_t myhandle, magma_queue_t queue)

{
    // Argument checks: the first violated condition decides the code.
    magma_int_t arginfo = (m < 0)            ? -1
                        : (n < 0)            ? -2
                        : (lda < max(1,m))   ? -4
                        :                       0;

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    // Empty problem: nothing to factor.
    if (m == 0 || n == 0) {
        return arginfo;
    }

    float minus_one = MAGMA_S_NEG_ONE;
    float plus_one  = MAGMA_S_ONE;
    const magma_int_t panel_nb = BATF2_NB;
    const magma_int_t diag_len = min(m, n);

    for (magma_int_t panel_start = 0; panel_start < diag_len; panel_start += panel_nb)
    {
        const magma_int_t width = min(panel_nb, diag_len-panel_start);

        for (magma_int_t k = 0; k < width; ++k)
        {
            const magma_int_t col = panel_start + k;

            if ((m-panel_start) > MAX_NTHREADS)
            {
                // Tall panel: separate kernels per step.
                // find the max of the column col
                arginfo = magma_isamax_batched(m-col, dA_array, 1, col, lda, ipiv_array, info_array, gbstep, batchCount, queue);
                if (arginfo != 0) {
                    return arginfo;
                }

                // Apply the interchange to columns 1:N. swap the whole row
                arginfo = magma_sswap_batched(n, dA_array, lda, col, ipiv_array, batchCount, queue);
                if (arginfo != 0) {
                    return arginfo;
                }

                // Compute elements J+1:M of J-th column.
                if (col < m) {
                    arginfo = magma_sscal_sger_batched(m-col, width-k, col, dA_array, lda, info_array, gbstep, batchCount, queue);
                    if (arginfo != 0) {
                        return arginfo;
                    }
                }
            }
            else
            {
                // Short panel: single column kernel, then the row swap.
                arginfo = magma_scomputecolumn_batched(m-panel_start, panel_start, k, dA_array, lda, ipiv_array, info_array, gbstep, batchCount, queue);
                if (arginfo != 0) {
                    return arginfo;
                }

                // Apply the interchange to columns 1:N. swap the whole row
                arginfo = magma_sswap_batched(n, dA_array, lda, col, ipiv_array, batchCount, queue);
                if (arginfo != 0) {
                    return arginfo;
                }
            }
        }

        const magma_int_t next_start = width + panel_start;   // first row/column past the panel
        const magma_int_t right_cols = n - panel_start - width;
        if (right_cols <= 0) {
            continue;
        }

        // continue the update of the selected rows for columns next_start:n (TRSM)
        magma_sgetf2trsm_batched(width, right_cols, dA_array, panel_start, lda, batchCount, queue);

        // blocked GER == GEMM on the remaining columns
        magma_sdisplace_pointers(dW0_displ, dA_array, lda, next_start, panel_start, batchCount, queue);
        magma_sdisplace_pointers(dW1_displ, dA_array, lda, panel_start, next_start, batchCount, queue);
        magma_sdisplace_pointers(dW2_displ, dA_array, lda, next_start, next_start, batchCount, queue);

#if 1
        magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans, m-next_start, right_cols, width,
                                  minus_one, dW0_displ, lda,
                                  dW1_displ, lda,
                                  plus_one,  dW2_displ, lda,
                                  batchCount, queue);
#else
        cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N, m-next_start, right_cols, width,
                                 &minus_one, (const float**) dW0_displ, lda,
                                             (const float**) dW1_displ, lda,
                                 &plus_one,  dW2_displ, lda, batchCount );
#endif
    }

    return 0;
}