Ejemplo n.º 1
0
arma_hot
arma_pure
inline
eT
op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B)
  {
  // Dispatcher for the complex dot product of two raw arrays of n_elem elements:
  // arrays with at most 32 elements are always handled by direct_cdot_arma();
  // longer arrays go through BLAS when ARMA_USE_BLAS is defined,
  // and through direct_cdot_arma() in every other configuration.
  
  arma_extra_debug_sigprint();
  
  if( n_elem <= 32u )  { return op_cdot::direct_cdot_arma(n_elem, A, B); }
  
  #if defined(ARMA_USE_BLAS)
    {
    arma_extra_debug_print("blas::gemv()");
    
    // gemv() stands in for cdotc() / zdotc(), which have compatibility issues.
    // A is passed as an (n_elem x 1) matrix together with the 'C' flag
    // (conjugate transpose in the standard BLAS convention) and B is the input vector,
    // so the product has a single element; n_elem > 32 here, hence the
    // leading dimension (n_elem) is never zero.
    
    const char     op       = 'C';
    
    const blas_int A_n_rows = blas_int(n_elem);
    const blas_int A_n_cols = 1;
    const blas_int stride   = 1;
    
    const eT       scale_Ax = eT(1);
    const eT       scale_y  = eT(0);
    
    eT out[2];  // only out[0] is used; the second element is a safety margin
    
    blas::gemv(&op, &A_n_rows, &A_n_cols, &scale_Ax, A, &A_n_rows, B, &stride, &scale_y, &out[0], &stride);
    
    return out[0];
    }
  #else
    {
    // reached both when only ATLAS is enabled and when no external library is enabled
    // TODO: for ATLAS, use dedicated functions cblas_cdotc_sub() and cblas_zdotc_sub() and retune threshold
    
    return op_cdot::direct_cdot_arma(n_elem, A, B);
    }
  #endif
  }
Ejemplo n.º 2
0
 inline
 eT
 dot(const uword n_elem, const eT* x, const eT* y)
   {
   // Dot product of two raw arrays of n_elem elements, computed via BLAS.
   // The routine used depends on the element type:
   //   float           -> sdot(), or gemv() when ARMA_BLAS_SDOT_BUG is defined
   //   double          -> ddot()
   //   complex types   -> gemv(), as cdotu() and zdotu() have compatibility issues
   //   any other type  -> eT(0) is returned
   // Every branch is compiled for every eT, hence the pointer casts
   // in front of the sdot() and ddot() calls.
   
   arma_type_check((is_supported_blas_type<eT>::value == false));
   
   if(is_float<eT>::value == true)
     {
     #if defined(ARMA_BLAS_SDOT_BUG)
       {
       if(n_elem == 0)  { return eT(0); }
       
       // x is treated as an (n_elem x 1) matrix and multiplied with the 'T' flag by y,
       // which yields a single-element result
       
       const char     op       = 'T';
       
       const blas_int x_n_rows = blas_int(n_elem);
       const blas_int x_n_cols = 1;
       const blas_int stride   = 1;
       
       const eT       one      = eT(1);
       const eT       zero     = eT(0);
       
       eT out[2];  // only out[0] is used; the second element is a safety margin
       
       blas::gemv(&op, &x_n_rows, &x_n_cols, &one, x, &x_n_rows, y, &stride, &zero, &out[0], &stride);
       
       return out[0];
       }
     #else
       {
       blas_int len    = blas_int(n_elem);
       blas_int stride = 1;
       
       return arma_fortran(arma_sdot)(&len, reinterpret_cast<const float*>(x), &stride, reinterpret_cast<const float*>(y), &stride);
       }
     #endif
     }
   
   if(is_double<eT>::value == true)
     {
     blas_int len    = blas_int(n_elem);
     blas_int stride = 1;
     
     return arma_fortran(arma_ddot)(&len, reinterpret_cast<const double*>(x), &stride, reinterpret_cast<const double*>(y), &stride);
     }
   
   if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
     {
     if(n_elem == 0)  { return eT(0); }
     
     // gemv() workaround: x is treated as an (n_elem x 1) matrix and multiplied
     // with the 'T' flag by y, which yields a single-element result
     
     const char     op       = 'T';
     
     const blas_int x_n_rows = blas_int(n_elem);
     const blas_int x_n_cols = 1;
     const blas_int stride   = 1;
     
     const eT       one      = eT(1);
     const eT       zero     = eT(0);
     
     eT out[2];  // only out[0] is used; the second element is a safety margin
     
     blas::gemv(&op, &x_n_rows, &x_n_cols, &one, x, &x_n_rows, y, &stride, &zero, &out[0], &stride);
     
     return out[0];
     }
   
   // element type not covered by any of the branches above
   return eT(0);
   }
Ejemplo n.º 3
0
 inline
 static
 void
 apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
   {
   // Matrix multiplication of A and B with the result stored in C.
   // Operands that are both square, of the same size, at most 4x4 and not complex
   // are handled by gemm_emul_tinysq; everything else is forwarded to ATLAS, BLAS
   // or gemm_emul, depending on which of ARMA_USE_ATLAS / ARMA_USE_BLAS is defined.
   // alpha is only forwarded to the library when use_alpha is set (otherwise eT(1) is passed);
   // beta is only forwarded when use_beta is set (otherwise eT(0) is passed).
   
   arma_extra_debug_sigprint();
   
   const bool tiny_square_real = (A.n_rows <= 4) && (A.n_rows == A.n_cols) && (A.n_rows == B.n_rows) && (B.n_rows == B.n_cols) && (is_cx<eT>::no);
   
   if(tiny_square_real)
     {
     if(do_trans_B)
       {
       // gemm_emul_tinysq is not given a transpose flag for B,
       // so it receives an explicitly transposed copy instead
       Mat<eT> B_st(B.n_rows, B.n_rows);
       
       op_strans::apply_mat_noalias_tinysq(B_st, B);
       
       gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B_st, alpha, beta);
       }
     else
       {
       gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
       }
     
     return;
     }
   
   #if defined(ARMA_USE_ATLAS)
     {
     arma_extra_debug_print("atlas::cblas_gemm()");
     
     arma_debug_assert_atlas_size(A,B);
     
     // a requested transpose becomes a conjugate transpose for complex element types
     atlas::cblas_gemm<eT>
       (
       atlas::CblasColMajor,
       (do_trans_A) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
       (do_trans_B) ? ( is_cx<eT>::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
       C.n_rows,                                                             // rows of the result
       C.n_cols,                                                             // columns of the result
       (do_trans_A) ? A.n_rows : A.n_cols,                                   // inner dimension
       (use_alpha) ? alpha : eT(1),
       A.mem,
       (do_trans_A) ? A.n_rows : C.n_rows,                                   // leading dimension of A
       B.mem,
       (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),     // leading dimension of B
       (use_beta) ? beta : eT(0),
       C.memptr(),
       C.n_rows                                                              // leading dimension of C
       );
     }
   #elif defined(ARMA_USE_BLAS)
     {
     arma_extra_debug_print("blas::gemm()");
     
     arma_debug_assert_blas_size(A,B);
     
     // a requested transpose becomes a conjugate transpose ('C') for complex element types
     const char op_A = (do_trans_A) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N';
     const char op_B = (do_trans_B) ? ( is_cx<eT>::yes ? 'C' : 'T' ) : 'N';
     
     const blas_int out_n_rows = blas_int(C.n_rows);
     const blas_int out_n_cols = blas_int(C.n_cols);
     const blas_int inner_dim  = (do_trans_A) ? blas_int(A.n_rows) : blas_int(A.n_cols);
     
     // leading dimensions of A and B as stored in memory; C uses out_n_rows
     const blas_int A_ld = (do_trans_A) ? inner_dim  : out_n_rows;
     const blas_int B_ld = (do_trans_B) ? out_n_cols : inner_dim;
     
     const eT scale_AB = (use_alpha) ? alpha : eT(1);
     const eT scale_C  = (use_beta)  ? beta  : eT(0);
     
     arma_extra_debug_print( arma_str::format("blas::gemm(): trans_A = %c") % op_A );
     arma_extra_debug_print( arma_str::format("blas::gemm(): trans_B = %c") % op_B );
     
     blas::gemm<eT>
       (
       &op_A,
       &op_B,
       &out_n_rows,
       &out_n_cols,
       &inner_dim,
       &scale_AB,
       A.mem,
       &A_ld,
       B.mem,
       &B_ld,
       &scale_C,
       C.memptr(),
       &out_n_rows
       );
     }
   #else
     {
     // no external library enabled: use the built-in emulation
     gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
     }
   #endif
   }