Esempio n. 1
0
/**
 * Recursive batched LU panel factorization with row pivoting (single precision).
 *
 * If n <= min_recpnb the whole m-by-n panel is handed to magma_sgetf2_batched.
 * Otherwise the panel is split column-wise into a left part of n1 = n/2
 * columns and a right part of n2 = n - n1 columns, and the routine:
 *   1. recurses on the left part (m rows, n1 columns);
 *   2. builds the pivot info for the left part, row-swaps the right part into
 *      dX_array and runs the out-of-place triangular solve whose result is
 *      written back at A(0, n1);
 *   3. updates the trailing block: A(n1:, n1:) -= A(n1:, 0:n1) * A(0:n1, n1:);
 *   4. recurses on the trailing (m - n1)-by-n2 block with gbstep advanced by n1;
 *   5. adjusts the pivots of the second half and applies them to the left
 *      columns below row n1.
 *
 * @param m, n          Panel dimensions; the routine returns 0 immediately if
 *                      either is 0.
 * @param min_recpnb    Column count at or below which recursion stops.
 * @param dA_array      Per-matrix pointers to the panel (leading dimension ldda).
 *                      NOTE(review): presumably a device array of device
 *                      pointers, as the displaced copies are obtained with
 *                      magma_malloc -- confirm against the caller.
 * @param dipiv_array   Per-matrix pivot vectors.
 * @param dpivinfo_array Per-matrix pivot-info buffers filled by
 *                      setup_pivinfo_batched_q.
 * @param dX_array, dinvA_array, dW1_displ..dW5_displ
 *                      Caller-provided workspaces; their contents are
 *                      overwritten here and by the callees.
 * @param info_array    Forwarded to the panel kernels.
 * @param gbstep        Global column offset of this panel, forwarded to the
 *                      panel kernels.
 * @param batchCount    Number of matrices in the batch.
 * @param stream        Queue/stream forwarded to the helpers that accept one.
 * @param myhandle      cuBLAS handle forwarded to the callees.
 *
 * @return 0 on success or quick return; otherwise the non-zero status of the
 *         first failing magma_malloc, magma_sgetf2_batched or recursive call.
 */
extern "C" magma_int_t
magma_sgetrf_recpanel_batched_q(
    magma_int_t m, magma_int_t n, magma_int_t min_recpnb,
    float** dA_array,    magma_int_t ldda,
    magma_int_t** dipiv_array, magma_int_t** dpivinfo_array,
    float** dX_array,    magma_int_t dX_length,
    float** dinvA_array, magma_int_t dinvA_length,
    float** dW1_displ, float** dW2_displ,
    float** dW3_displ, float** dW4_displ,
    float** dW5_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t stream, cublasHandle_t myhandle)
{
    // Quick return if possible
    if (m == 0 || n == 0) {
        return 0;
    }

    magma_int_t arginfo = 0;
    float **dA_displ = NULL;
    magma_int_t **dipiv_displ = NULL;

    // Scratch pointer arrays holding the displaced views of A and ipiv.
    arginfo = magma_malloc((void**)&dA_displ, batchCount * sizeof(*dA_displ));
    if (arginfo != 0) {
        return arginfo;
    }
    arginfo = magma_malloc((void**)&dipiv_displ, batchCount * sizeof(*dipiv_displ));
    if (arginfo != 0) {
        magma_free(dA_displ);
        return arginfo;
    }

    magma_int_t panel_nb = n;
    if (panel_nb <= min_recpnb) {
        // Bottom of the recursion: factor the whole panel directly.
        arginfo = magma_sgetf2_batched(
                           m, panel_nb,
                           dA_array, ldda,
                           dW1_displ, dW2_displ, dW3_displ,
                           dipiv_array, info_array, gbstep, batchCount, myhandle);
    }
    else {
        // Split A column-wise into [A1 A2]:
        // panel on A1, update of A2, then panel on the trailing part of A2.
        const magma_int_t n1 = n/2;
        const magma_int_t n2 = n-n1;
        const magma_int_t m1 = m;
        const magma_int_t m2 = m-n1;
        const magma_int_t p1 = 0;
        const magma_int_t p2 = n1;

        // panel on A1
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount);
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p1, 0, batchCount);
        arginfo = magma_sgetrf_recpanel_batched_q(
                           m1, n1, min_recpnb,
                           dA_displ, ldda,
                           dipiv_displ, dpivinfo_array,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, dW5_displ,
                           info_array, gbstep, batchCount, stream, myhandle);
        if (arginfo != 0) {
            goto cleanup;
        }

        // update A2: row swaps (A2 -> dX), then triangular solve (dX -> A2)
        setup_pivinfo_batched_q(dpivinfo_array, dipiv_displ, m1, n1, stream, batchCount);
        magma_sdisplace_pointers(dW5_displ, dA_array, ldda, p1, p2, batchCount);
        magma_slaswp_rowparallel_batched( n2, dW5_displ, ldda,
                           dX_array, n1,
                           0, n1,
                           dpivinfo_array, batchCount);
        magmablas_strsm_outofplace_batched(MagmaLeft, MagmaLower, MagmaNoTrans, MagmaUnit, 1,
                              n1, n2,
                              MAGMA_S_ONE,
                              dA_displ,    ldda, // dA
                              dX_array,  n1, // dB
                              dW5_displ,   ldda, // dX
                              dinvA_array, dinvA_length,
                              dW1_displ,   dW2_displ,
                              dW3_displ,   dW4_displ,
                              0, batchCount);

        // A(p2:, p2:) -= A(p2:, 0:n1) * A(0:n1, p2:)
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount);
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p2, p2, batchCount);

#if 0
        float neg_one = MAGMA_S_NEG_ONE;
        float one  = MAGMA_S_ONE;
        cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N, m2, n2, n1,
                                     &neg_one, (const float**) dW1_displ, ldda,
                                               (const float**) dW5_displ, ldda,
                                     &one,  dA_displ, ldda, batchCount );
#else
        magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans, m2, n2, n1,
                              MAGMA_S_NEG_ONE, dW1_displ, ldda,
                              dW5_displ, ldda,
                              MAGMA_S_ONE,  dA_displ, ldda,
                              batchCount);
#endif

        // panel on A2
        magma_idisplace_pointers(dipiv_displ, dipiv_array, 1, p2, 0, batchCount);
        arginfo = magma_sgetrf_recpanel_batched_q(
                           m2, n2, min_recpnb,
                           dA_displ, ldda,
                           dipiv_displ, dpivinfo_array,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, dW5_displ,
                           info_array, gbstep+p2, batchCount, stream, myhandle);
        if (arginfo != 0) {
            goto cleanup;
        }

        // setup pivinfo for the second half and shift its pivots by n1;
        // both helpers run on the caller-supplied stream, like every other
        // step of this routine.
        setup_pivinfo_batched_q(dpivinfo_array, dipiv_displ, m2, n2, stream, batchCount);
        adjust_ipiv_batched_q(dipiv_displ, n2, n1, stream, batchCount);

        // dW1_displ must be re-displaced here: it was passed to the recursive
        // call above as workspace and may no longer point at A(p2, 0).
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount);
        magma_slaswp_rowparallel_batched( n1, dW1_displ, ldda,
                           dW1_displ, ldda,
                           n1, n,
                           dpivinfo_array, batchCount);
    }

cleanup:
    magma_free(dA_displ);
    magma_free(dipiv_displ);
    return arginfo;
}
Esempio n. 2
0
/**
 * Batched computation of the k-by-k matrices T_array[i] from the n-by-k
 * matrices v_array[i] and the scalars tau_array[i].
 * NOTE(review): by its name this is the batched analogue of LAPACK slarft
 * (triangular factor of a block reflector) -- confirm against the kernels.
 *
 * Steps visible here:
 *   phase 1: Tstep = V^T * V for every matrix (one batched GEMM), followed by
 *            a magmablas_slaset_batched(MagmaLower, ..., 0, 0, ...) on Tstep;
 *   phase 2: for each block of nb columns starting at column j,
 *            - a batched GEMM T(0:j, j:j+nb) = T(0:j, 0:j) * Tstep(0:j, j:j+nb),
 *              then magmablas_slarft_recstrmv_sm32x32_batched on each chunk of
 *              at most nb rows of that rectangle;
 *            - when stair_T == 0, magmablas_slarft_strmv_sm32x32_batched on
 *              the diagonal block T(j:j+nb, j:j+nb).
 *
 * @param n           Inner dimension of the V^T * V product (rows of V).
 * @param k           Order of T; nothing is done if k <= 0.
 * @param stair_T     0: compute all of T. > 0: used as the block size nb and
 *                    the diagonal blocks are skipped (the original comment
 *                    says they are already computed); must not exceed
 *                    max_shared_bsiz, and nothing is done if k <= stair_T.
 * @param v_array     Per-matrix pointers to V (leading dimension ldv).
 * @param tau_array   Per-matrix pointers to tau.
 * @param T_array     Per-matrix pointers to T (leading dimension ldt).
 * @param work_array  Per-matrix workspaces used as Tstep when k > nb.
 * @param lwork       Size of each workspace; must be at least k*ldt.
 * @param batchCount  Number of matrices in the batch.
 * @param myhandle    cuBLAS handle (used when RFT_MAG_GEM is not defined).
 * @param queue       Queue forwarded to the MAGMA helpers.
 *
 * @return 0 on success or quick return, -10 if lwork is too small, -3 if
 *         stair_T exceeds max_shared_bsiz, or the non-zero status of a failed
 *         magma_malloc.
 */
extern "C" magma_int_t
magma_slarft_batched(magma_int_t n, magma_int_t k, magma_int_t stair_T,
                float **v_array, magma_int_t ldv,
                float **tau_array, float **T_array, magma_int_t ldt,
                float **work_array, magma_int_t lwork, magma_int_t batchCount, cublasHandle_t myhandle,
                magma_queue_t queue)
{
    if( k <= 0) return 0;
    if( stair_T > 0 && k <= stair_T) return 0;

    magma_int_t maxnb = max_shared_bsiz;

    if( lwork < k*ldt)
    {
        magma_xerbla( __func__, -(10) );
        return -10;
    }

    if( stair_T > 0 && stair_T > maxnb)
    {
        magma_xerbla( __func__, -(3) );
        return -3;
    }
    magma_int_t DEBUG=0;
    magma_int_t nb = stair_T == 0 ? min(k,maxnb) : stair_T;

    magma_int_t i, j, prev_n, mycol, rows;

    float **dW1_displ  = NULL;
    float **dW2_displ  = NULL;
    float **dW3_displ  = NULL;
    float **dTstep_array  = NULL;

    // Allocate the four scratch pointer arrays; on failure release whatever
    // was already obtained and report the allocator's status to the caller.
    magma_int_t alloc_info;
    alloc_info = magma_malloc((void**)&dW1_displ,  batchCount * sizeof(*dW1_displ));
    if (alloc_info != 0) {
        return alloc_info;
    }
    alloc_info = magma_malloc((void**)&dW2_displ,  batchCount * sizeof(*dW2_displ));
    if (alloc_info != 0) {
        magma_free(dW1_displ);
        return alloc_info;
    }
    alloc_info = magma_malloc((void**)&dW3_displ,  batchCount * sizeof(*dW3_displ));
    if (alloc_info != 0) {
        magma_free(dW2_displ);
        magma_free(dW1_displ);
        return alloc_info;
    }
    alloc_info = magma_malloc((void**)&dTstep_array,  batchCount * sizeof(*dTstep_array));
    if (alloc_info != 0) {
        magma_free(dW3_displ);
        magma_free(dW2_displ);
        magma_free(dW1_displ);
        return alloc_info;
    }

    // Tstep aliases the workspace when T needs more than one column block,
    // otherwise it aliases T itself.
    if(k > nb)
    {
        magma_sdisplace_pointers(dTstep_array, work_array, lwork, 0, 0, batchCount, queue);
    }
    else
    {
        magma_sdisplace_pointers(dTstep_array, T_array, ldt, 0, 0, batchCount, queue);
    }

    //magma_int_t ldtstep = k > nb ? k : ldt;
    magma_int_t ldtstep = ldt; // marked "to be removed" by the original author
    // stair_T = 0 meaning all T
    // stair_T > 0 meaning the triangular portion of T has been computed.
    //                    the value of stair_T is the nb of these triangulars


    // GEMM computing the whole upper portion of T (phase 1)
    // TODO add cublas to check perf

#ifdef RFT_MAG_GEM
    magmablas_sgemm_batched( MagmaConjTrans, MagmaNoTrans,
            k, k, n,
            one,  v_array, ldv,
                  v_array, ldv,
            zero, dTstep_array, ldtstep,
            batchCount, queue);
#else
    cublasSgemmBatched(myhandle, CUBLAS_OP_C, CUBLAS_OP_N, k, k, n,
                             &one, (const float**) v_array, ldv,
                                    (const float**) v_array, ldv,
                             &zero, dTstep_array, ldtstep, batchCount);
#endif

    magmablas_slaset_batched(MagmaLower, k, k, MAGMA_S_ZERO, MAGMA_S_ZERO, dTstep_array, ldtstep, batchCount, queue);
    // no need for it as T is expected to be lower zero
    //if(k > nb) magmablas_slaset_batched(MagmaLower, k, k, MAGMA_S_ZERO, MAGMA_S_ZERO, dTstep_array, ldtstep, batchCount);


    //TRMV
    //T(1:i-1,i) := T(1:i-1,1:i-1) * W(1:i-1) i=[1:k]
    // TRMV is split over blocks of columns of size nb
    // the update should be done from top to bottom so:
    // 1- a gemm using the previously computed columns
    //    of T to update the rectangular upper portion above
    //    the triangle of my columns
    // 2- the columns need to be updated by a serial
    //    loop of gemv over itself. since we limit the
    //    shared memory to nb, these nb columns
    //    are split vertically by chunks of nb rows

    for(j=0; j<k; j+=nb)
    {
        prev_n =  j;
        mycol  =  min(nb, k-j);
        // note that myrow = prev_n + mycol;
        if(prev_n>0 && mycol>0){

            // magma_int_t may be wider than int, so cast for the %d conversions
            if(DEBUG==3) printf("doing gemm on the rectangular portion of size %d %d of T(%d,%d)\n",(int)prev_n,(int)mycol,0,(int)j);

            magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, 0, j, batchCount, queue);
            magma_sdisplace_pointers(dW2_displ, T_array,     ldt, 0, j, batchCount, queue);
#ifdef RFT_MAG_GEM
            magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans,
                    prev_n, mycol, prev_n,
                    one,  T_array, ldt,
                          dW1_displ, ldtstep,
                    zero, dW2_displ, ldt,
                    batchCount, queue );
#else
            cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N,
                    prev_n, mycol, prev_n,
                    &one, (const float**) T_array, ldt,
                          (const float**) dW1_displ, ldtstep,
                    &zero, dW2_displ, ldt, batchCount);
#endif

            // update my rectangular portion (prev_n,mycol) using sequence of gemv
            magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
            magma_sdisplace_pointers(dW3_displ, tau_array,  1, j, 0, batchCount, queue);

            for(i=0; i<prev_n; i+=nb)
            {
                rows = min(nb,prev_n-i);
                if(DEBUG==3) printf("        doing recstrmv on the rectangular portion of size %d %d of T(%d,%d)\n",(int)rows,(int)mycol,(int)i,(int)j);

                if(rows>0 && mycol>0)
                {
                    magma_sdisplace_pointers(dW2_displ, T_array,     ldt, i, j, batchCount, queue);
                    magmablas_slarft_recstrmv_sm32x32_batched(rows, mycol, dW3_displ, dW2_displ, ldt, dW1_displ, ldtstep, batchCount, queue);
                }
            }
        }

        // the upper rectangular portion is updated, now if needed update the triangular portion
        if(stair_T == 0){
            if(DEBUG==3) printf("doing strmv on the triangular portion of size %d %d of T(%d,%d)\n",(int)mycol,(int)mycol,(int)j,(int)j);

            if(mycol>0)
            {
                magma_sdisplace_pointers(dW1_displ, dTstep_array, ldtstep, j, j, batchCount, queue);
                magma_sdisplace_pointers(dW3_displ, tau_array,  1, j, 0, batchCount, queue);
                magma_sdisplace_pointers(dW2_displ, T_array,     ldt, j, j, batchCount, queue);
                magmablas_slarft_strmv_sm32x32_batched(mycol, mycol, dW3_displ, dW1_displ, ldtstep, dW2_displ, ldt, batchCount, queue);

            }
        }
    }// end of j

    magma_free(dW1_displ);
    magma_free(dW2_displ);
    magma_free(dW3_displ);
    magma_free(dTstep_array);

    return 0;
}
Esempio n. 3
0
/**
 * Recursive batched LU panel factorization without pivoting (single precision).
 *
 * If n <= min_recpnb the whole m-by-n panel is handed to
 * magma_sgetrf_panel_nopiv_batched_q. Otherwise the panel is split column-wise
 * into a left part of n1 = n/2 columns and a right part of n2 = n - n1
 * columns, and the routine:
 *   1. recurses on the left part (m rows, n1 columns);
 *   2. runs magmablas_strsm_work_batched with the left part's top n1-by-n1
 *      block (flagged lower, unit diagonal) on the top n1 rows of the right
 *      part, using dX_array as the solve workspace;
 *   3. updates the trailing block: A(n1:, n1:) -= A(n1:, 0:n1) * A(0:n1, n1:);
 *   4. recurses on the trailing (m - n1)-by-n2 block with gbstep advanced by n1.
 *
 * @param m, n          Panel dimensions; the routine returns 0 immediately if
 *                      either is 0.
 * @param min_recpnb    Column count at or below which recursion stops.
 * @param dA_array      Per-matrix pointers to the panel (leading dimension ldda).
 * @param dX_array, dinvA_array, dW1_displ..dW5_displ
 *                      Caller-provided workspaces; overwritten here and by the
 *                      callees.
 * @param info_array    Forwarded to the panel kernel.
 * @param gbstep        Global column offset of this panel, forwarded to the
 *                      panel kernel.
 * @param batchCount    Number of matrices in the batch.
 * @param stream        Queue/stream forwarded to the callees.
 * @param myhandle      cuBLAS handle forwarded to the callees.
 *
 * @return 0 on success or quick return; otherwise the non-zero status of the
 *         first failing magma_malloc, panel kernel or recursive call. The
 *         scratch pointer array is released on every path.
 */
extern "C" magma_int_t
magma_sgetrf_recpanel_nopiv_batched_q(
    magma_int_t m, magma_int_t n, magma_int_t min_recpnb,
    float** dA_array,    magma_int_t ldda,
    float** dX_array,    magma_int_t dX_length,
    float** dinvA_array, magma_int_t dinvA_length,
    float** dW1_displ, float** dW2_displ,
    float** dW3_displ, float** dW4_displ,
    float** dW5_displ,
    magma_int_t *info_array, magma_int_t gbstep,
    magma_int_t batchCount, magma_queue_t stream, cublasHandle_t myhandle)
{
    // Quick return if possible
    if (m == 0 || n == 0) {
        return 0;
    }
    magma_int_t arginfo = 0;

    // Scratch pointer array holding the displaced views of A.
    float **dA_displ  = NULL;
    arginfo = magma_malloc((void**)&dA_displ,   batchCount * sizeof(*dA_displ));
    if (arginfo != 0) {
        return arginfo;
    }

    magma_int_t panel_nb = n;
    if (panel_nb <= min_recpnb) {
        // Bottom of the recursion: factor the whole panel directly.
        arginfo = magma_sgetrf_panel_nopiv_batched_q(
                           m, panel_nb,
                           dA_array, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, dW5_displ,
                           info_array, gbstep, batchCount, stream, myhandle);
    }
    else {
        // Split A column-wise into [A1 A2]:
        // panel on A1, update of A2, then panel on the trailing part of A2.
        const magma_int_t n1 = n/2;
        const magma_int_t n2 = n-n1;
        const magma_int_t m1 = m;
        const magma_int_t m2 = m-n1;
        const magma_int_t p1 = 0;
        const magma_int_t p2 = n1;

        // panel on A1
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p1, p1, batchCount);
        arginfo = magma_sgetrf_recpanel_nopiv_batched_q(
                           m1, n1, min_recpnb,
                           dA_displ, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, dW5_displ,
                           info_array, gbstep, batchCount, stream, myhandle);
        if (arginfo != 0) {
            goto cleanup;
        }

        // update A2: triangular solve on its top n1 rows
        magma_sdisplace_pointers(dW5_displ, dA_array, ldda, p1, p2, batchCount);
        magmablas_strsm_work_batched(MagmaLeft, MagmaLower, MagmaNoTrans, MagmaUnit, 1,
                              n1, n2,
                              MAGMA_S_ONE,
                              dA_displ,    ldda, // dA
                              dW5_displ,   ldda, // dB
                              dX_array,  n1, // dX
                              dinvA_array, dinvA_length,
                              dW1_displ,   dW2_displ,
                              dW3_displ,   dW4_displ,
                              1, batchCount);

        // A(p2:, p2:) -= A(p2:, 0:n1) * A(0:n1, p2:)
        magma_sdisplace_pointers(dW1_displ, dA_array, ldda, p2, 0, batchCount);
        magma_sdisplace_pointers(dA_displ, dA_array, ldda, p2, p2, batchCount);
        magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans, m2, n2, n1,
                              MAGMA_S_NEG_ONE, dW1_displ, ldda,
                              dW5_displ, ldda,
                              MAGMA_S_ONE,  dA_displ, ldda,
                              batchCount);

        // panel on A2
        arginfo = magma_sgetrf_recpanel_nopiv_batched_q(
                           m2, n2, min_recpnb,
                           dA_displ, ldda,
                           dX_array, dX_length,
                           dinvA_array, dinvA_length,
                           dW1_displ, dW2_displ,
                           dW3_displ, dW4_displ, dW5_displ,
                           info_array, gbstep+p2, batchCount, stream, myhandle);
    }

cleanup:
    // Single exit so the scratch array is released on error paths too.
    magma_free(dA_displ);
    return arginfo;
}
Esempio n. 4
0
/**
 * Batched unblocked-style LU factorization without pivoting (single precision),
 * processed in column blocks of 32.
 *
 * For each column block the routine calls magma_sscal_sger_batched once per
 * column, then (if columns remain to the right) magma_sgetf2trsm_batched on
 * the block row and a batched GEMM that applies
 *   A(next:, next:) -= A(next:, start:next) * A(start:next, next:)
 * where start is the first column of the block and next = start + width.
 *
 * @param m, n        Matrix dimensions (m < 0 -> -1, n < 0 -> -2).
 * @param dA_array    Per-matrix pointers to A.
 * @param lda         Leading dimension; lda < max(1, m) -> -4.
 * @param dW0_displ, dW1_displ, dW2_displ
 *                    Caller-provided pointer arrays overwritten with the
 *                    displaced GEMM operands.
 * @param info_array  Forwarded to magma_sscal_sger_batched.
 * @param gbstep      Forwarded to magma_sscal_sger_batched.
 * @param batchCount  Number of matrices in the batch.
 * @param myhandle    cuBLAS handle; only referenced by the disabled cuBLAS
 *                    GEMM alternative.
 *
 * @return 0 on success or when m or n is 0, a negative argument index on
 *         invalid input, or the non-zero status of magma_sscal_sger_batched.
 */
extern "C" magma_int_t
magma_sgetf2_nopiv_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t lda,
    float **dW0_displ,
    float **dW1_displ,
    float **dW2_displ,
    magma_int_t *info_array,
    magma_int_t gbstep,
    magma_int_t batchCount,
    cublasHandle_t myhandle)
{
    // Argument validation.
    magma_int_t arginfo = 0;
    if (m < 0)
        arginfo = -1;
    else if (n < 0)
        arginfo = -2;
    else if (lda < max(1,m))
        arginfo = -4;

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    // Nothing to factor.
    if (m == 0 || n == 0)
        return arginfo;

    float c_neg_one = MAGMA_S_NEG_ONE;
    float c_one     = MAGMA_S_ONE;

    const magma_int_t block_width = 32;  // BATF2_NB in the pivoted variant
    const magma_int_t diag_len    = min(m, n);

    for (magma_int_t start = 0; start < diag_len; start += block_width)
    {
        const magma_int_t width = min(block_width, diag_len - start);

        // Column-by-column scale + rank-1 update inside the block.
        for (magma_int_t s = 0; s < width; ++s)
        {
            const magma_int_t col = start + s;

            // TODO (from the original): the (m - start) <= 0 case has no
            // implementation; it is a no-op.
            if ((m - start) > 0 && col < m) {
                // Compute elements J+1:M of J-th column.
                arginfo = magma_sscal_sger_batched(m - col, width - s, col,
                                                   dA_array, lda, info_array,
                                                   gbstep, batchCount);
                if (arginfo != 0)
                    return arginfo;
            }
        }

        const magma_int_t next       = start + width;
        const magma_int_t cols_right = n - start - width;

        if (cols_right > 0)
        {
            // Finish the block row for columns next:n (TRSM).
            magma_sgetf2trsm_batched(width, cols_right, dA_array, start, lda, batchCount);

            // Blocked rank-width update of the trailing matrix (GEMM).
            magma_sdisplace_pointers(dW0_displ, dA_array, lda, next,  start, batchCount);
            magma_sdisplace_pointers(dW1_displ, dA_array, lda, start, next,  batchCount);
            magma_sdisplace_pointers(dW2_displ, dA_array, lda, next,  next,  batchCount);

#if 1
            magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans,
                                     m - next, n - next, width,
                                     c_neg_one, dW0_displ, lda,
                                                dW1_displ, lda,
                                     c_one,     dW2_displ, lda,
                                     batchCount);
#else
            cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N,
                               m - next, n - next, width,
                               &c_neg_one, (const float**) dW0_displ, lda,
                                           (const float**) dW1_displ, lda,
                               &c_one,     dW2_displ, lda, batchCount );
#endif
        }
    }

    return 0;
}
Esempio n. 5
0
/**
 * Batched LU factorization with row pivoting (single precision), processed in
 * column blocks of BATF2_NB.
 *
 * Inside a block each column is handled by one of two paths, chosen by
 * comparing the remaining row count (m - block start) with MAX_NTHREADS:
 *   - above the limit: magma_isamax_batched, magma_sswap_batched on the full
 *     row, then magma_sscal_sger_batched (the original comments call this the
 *     "non shared version");
 *   - otherwise: magma_scomputecolumn_batched followed by magma_sswap_batched
 *     (called the "shared version" in the original comments).
 * If columns remain to the right of the block, magma_sgetf2trsm_batched is
 * applied to the block row and a batched GEMM performs
 *   A(next:, next:) -= A(next:, start:next) * A(start:next, next:).
 *
 * @param m, n        Matrix dimensions (m < 0 -> -1, n < 0 -> -2).
 * @param dA_array    Per-matrix pointers to A.
 * @param lda         Leading dimension; lda < max(1, m) -> -4.
 * @param dW0_displ, dW1_displ, dW2_displ
 *                    Caller-provided pointer arrays overwritten with the
 *                    displaced GEMM operands.
 * @param ipiv_array  Per-matrix pivot vectors, forwarded to the pivot kernels.
 * @param info_array  Forwarded to the column kernels.
 * @param gbstep      Forwarded to the column kernels.
 * @param batchCount  Number of matrices in the batch.
 * @param myhandle    cuBLAS handle; only referenced by the disabled cuBLAS
 *                    GEMM alternative.
 * @param queue       Queue forwarded to every MAGMA helper.
 *
 * @return 0 on success or when m or n is 0, a negative argument index on
 *         invalid input, or the non-zero status of the first failing kernel
 *         wrapper.
 */
extern "C" magma_int_t
magma_sgetf2_batched(
    magma_int_t m, magma_int_t n,
    float **dA_array, magma_int_t lda,
    float **dW0_displ,
    float **dW1_displ,
    float **dW2_displ,
    magma_int_t **ipiv_array,
    magma_int_t *info_array,
    magma_int_t gbstep,
    magma_int_t batchCount,
    cublasHandle_t myhandle, magma_queue_t queue)
{
    // Argument validation.
    magma_int_t arginfo = 0;
    if (m < 0)
        arginfo = -1;
    else if (n < 0)
        arginfo = -2;
    else if (lda < max(1,m))
        arginfo = -4;

    if (arginfo != 0) {
        magma_xerbla( __func__, -(arginfo) );
        return arginfo;
    }

    // Nothing to factor.
    if (m == 0 || n == 0)
        return arginfo;

    float c_neg_one = MAGMA_S_NEG_ONE;
    float c_one     = MAGMA_S_ONE;

    const magma_int_t block_width = BATF2_NB;
    const magma_int_t diag_len    = min(m, n);

    for (magma_int_t start = 0; start < diag_len; start += block_width)
    {
        const magma_int_t width = min(block_width, diag_len - start);

        for (magma_int_t s = 0; s < width; ++s)
        {
            const magma_int_t col = start + s;

            if ((m - start) > MAX_NTHREADS)
            {
                // Find the pivot of column col.
                arginfo = magma_isamax_batched(m - col, dA_array, 1, col, lda,
                                               ipiv_array, info_array, gbstep,
                                               batchCount, queue);
                if (arginfo != 0)
                    return arginfo;

                // Apply the interchange to columns 1:N (swap the whole row).
                arginfo = magma_sswap_batched(n, dA_array, lda, col, ipiv_array,
                                              batchCount, queue);
                if (arginfo != 0)
                    return arginfo;

                // Compute elements J+1:M of J-th column.
                if (col < m) {
                    arginfo = magma_sscal_sger_batched(m - col, width - s, col,
                                                       dA_array, lda, info_array,
                                                       gbstep, batchCount, queue);
                    if (arginfo != 0)
                        return arginfo;
                }
            }
            else
            {
                arginfo = magma_scomputecolumn_batched(m - start, start, s,
                                                       dA_array, lda, ipiv_array,
                                                       info_array, gbstep,
                                                       batchCount, queue);
                if (arginfo != 0)
                    return arginfo;

                // Apply the interchange to columns 1:N (swap the whole row).
                arginfo = magma_sswap_batched(n, dA_array, lda, col, ipiv_array,
                                              batchCount, queue);
                if (arginfo != 0)
                    return arginfo;
            }
        }

        const magma_int_t next       = start + width;
        const magma_int_t cols_right = n - start - width;

        if (cols_right > 0)
        {
            // Finish the block row for columns next:n (TRSM).
            magma_sgetf2trsm_batched(width, cols_right, dA_array, start, lda, batchCount, queue);

            // Blocked rank-width update of the trailing matrix (GEMM).
            magma_sdisplace_pointers(dW0_displ, dA_array, lda, next,  start, batchCount, queue);
            magma_sdisplace_pointers(dW1_displ, dA_array, lda, start, next,  batchCount, queue);
            magma_sdisplace_pointers(dW2_displ, dA_array, lda, next,  next,  batchCount, queue);

#if 1
            magmablas_sgemm_batched( MagmaNoTrans, MagmaNoTrans,
                                     m - next, n - next, width,
                                     c_neg_one, dW0_displ, lda,
                                                dW1_displ, lda,
                                     c_one,     dW2_displ, lda,
                                     batchCount, queue);
#else
            cublasSgemmBatched(myhandle, CUBLAS_OP_N, CUBLAS_OP_N,
                               m - next, n - next, width,
                               &c_neg_one, (const float**) dW0_displ, lda,
                                           (const float**) dW1_displ, lda,
                               &c_one,     dW2_displ, lda, batchCount );
#endif
        }
    }

    return 0;
}