// Writes the simple (non-conjugating) transpose of A into out.
// out is resized to (A.n_cols x A.n_rows) before any element of A is read;
// NOTE(review): as the function name suggests, out presumably must not alias A -- confirm with callers.
arma_hot
inline
void
op_strans::apply_noalias(Mat<eT>& out, const TA& A)
  {
  arma_extra_debug_sigprint();
  
  const uword n_rows = A.n_rows;
  const uword n_cols = A.n_cols;
  
  out.set_size(n_cols, n_rows);
  
  const bool vec_shaped = (TA::is_row) || (TA::is_col) || (n_cols == 1) || (n_rows == 1);
  
  if(vec_shaped)
    {
    // vector-shaped input: out already has the swapped dimensions,
    // so the elements are copied across in their existing order
    arrayops::copy( out.memptr(), A.memptr(), A.n_elem );
    
    return;
    }
  
  if( (n_rows == n_cols) && (n_rows <= 4) )
    {
    // square matrices of size up to 4x4 are delegated to a dedicated routine
    op_strans::apply_noalias_tinysq(out, A);
    
    return;
    }
  
  // general case: column 'col' of A becomes row 'col' of out
  for(uword col=0; col < n_cols; ++col)
    {
    const eT* src = A.colptr(col);
    
    uword row = 0;
    
    // handle two elements per iteration
    for(; (row+1) < n_rows; row += 2)
      {
      const eT first  = src[row  ];
      const eT second = src[row+1];
      
      out.at(col, row  ) = first;
      out.at(col, row+1) = second;
      }
    
    // leftover element when the number of rows is odd
    if(row < n_rows)
      {
      out.at(col, row) = src[row];
      }
    }
  }
示例#2
0
  // Emulated matrix-matrix multiplication built from dot products.
  // With op() denoting an optional transpose (selected by do_trans_A / do_trans_B),
  // each element of C is set to one of:
  //   op(A)*op(B)                       (use_alpha == false, use_beta == false)
  //   alpha*op(A)*op(B)                 (use_alpha == true,  use_beta == false)
  //   op(A)*op(B) + beta*C              (use_alpha == false, use_beta == true )
  //   alpha*op(A)*op(B) + beta*C        (use_alpha == true,  use_beta == true )
  // C is only accessed through .at() and is not resized here;
  // NOTE(review): presumably the caller has already given C the correct size -- confirm.
  // NOTE(review): do_trans_A, do_trans_B, use_alpha and use_beta are not declared in this method;
  // presumably they are template parameters of the enclosing class -- confirm.
  arma_hot
  inline
  static
  void
  apply
    (
          Mat<eT>& C,
    const TA&      A,
    const TB&      B,
    const eT       alpha = eT(1),
    const eT       beta  = eT(0)
    )
    {
    arma_extra_debug_sigprint();
    
    const uword A_rows = A.n_rows;
    const uword A_cols = A.n_cols;
    
    const uword B_rows = B.n_rows;
    const uword B_cols = B.n_cols;
    
    if(do_trans_A == false)
      {
      if(do_trans_B == false)
        {
        // C = A*B: each row of A is placed in a contiguous buffer,
        // then combined with every column of B
        
        arma_aligned podarray<eT> row_buffer(A_cols);
        
        eT* A_row = row_buffer.memptr();
        
        for(uword r=0; r < A_rows; ++r)
          {
          // column 0 of B is handled separately: the helper returns the dot product,
          // and (judging by its name and the later use of A_row) also fills the row buffer
          const eT head = op_dot::dot_and_copy_row(A_row, A, r, B.colptr(0), A_cols);
          
          eT& C_head = C.at(r,0);
          
          if(use_beta)
            {
            if(use_alpha)  { C_head = alpha*head + beta*C_head; }
            else           { C_head =       head + beta*C_head; }
            }
          else
            {
            if(use_alpha)  { C_head = alpha*head; }
            else           { C_head =       head; }
            }
          
          // remaining columns of B reuse the buffered row
          for(uword c=1; c < B_cols; ++c)
            {
            const eT acc = op_dot::direct_dot_arma(B_rows, A_row, B.colptr(c));
            
            eT& C_elem = C.at(r,c);
            
            if(use_beta)
              {
              if(use_alpha)  { C_elem = alpha*acc + beta*C_elem; }
              else           { C_elem =       acc + beta*C_elem; }
              }
            else
              {
              if(use_alpha)  { C_elem = alpha*acc; }
              else           { C_elem =       acc; }
              }
            }
          }
        }
      else
        {
        // C = A*trans(B): transpose B explicitly, then use the non-transposed variant
        
        Mat<eT> B_trans;
        op_strans::apply_noalias(B_trans, B);
        
        gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, B_trans, alpha, beta);
        }
      }
    else
      {
      if(do_trans_B == false)
        {
        // C = trans(A)*B: columns of A are used directly,
        // with the column index of A acting as the row index of C
        
        for(uword a=0; a < A_cols; ++a)
          {
          const eT* A_col = A.colptr(a);
          
          for(uword c=0; c < B_cols; ++c)
            {
            const eT acc = op_dot::direct_dot_arma(B_rows, A_col, B.colptr(c));
            
            eT& C_elem = C.at(a,c);
            
            if(use_beta)
              {
              if(use_alpha)  { C_elem = alpha*acc + beta*C_elem; }
              else           { C_elem =       acc + beta*C_elem; }
              }
            else
              {
              if(use_alpha)  { C_elem = alpha*acc; }
              else           { C_elem =       acc; }
              }
            }
          }
        }
      else
        {
        // C = trans(A)*trans(B): relies on trans(A)*trans(B) = trans(B*A),
        // so no explicit transpose is formed; rows of B are buffered instead
        
        arma_aligned podarray<eT> row_buffer(B_cols);
        
        eT* B_row = row_buffer.memptr();
        
        for(uword b=0; b < B_rows; ++b)
          {
          row_buffer.copy_row(B, b);
          
          for(uword a=0; a < A_cols; ++a)
            {
            const eT acc = op_dot::direct_dot_arma(A_rows, B_row, A.colptr(a));
            
            eT& C_elem = C.at(a,b);
            
            if(use_beta)
              {
              if(use_alpha)  { C_elem = alpha*acc + beta*C_elem; }
              else           { C_elem =       acc + beta*C_elem; }
              }
            else
              {
              if(use_alpha)  { C_elem = alpha*acc; }
              else           { C_elem =       acc; }
              }
            }
          }
        }
      }
    }
// Writes val * (simple transpose of A) into out.
// out is resized to (A.n_cols x A.n_rows) before any element of A is read;
// NOTE(review): as the function name suggests, out presumably must not alias A -- confirm with callers.
arma_hot
inline
void
op_strans2::apply_noalias(Mat<eT>& out, const TA& A, const eT val)
  {
  arma_extra_debug_sigprint();
  
  const uword n_rows = A.n_rows;
  const uword n_cols = A.n_cols;
  
  out.set_size(n_cols, n_rows);
  
  const bool vec_shaped = (TA::is_col) || (TA::is_row) || (n_cols == 1) || (n_rows == 1);
  
  if(vec_shaped)
    {
    // vector-shaped input: out already has the swapped dimensions,
    // so the elements are scaled and stored in their existing order
    
    const uword n_elem = A.n_elem;
    
    const eT* src  =   A.memptr();
          eT* dest = out.memptr();
    
    uword pos = 0;
    
    // handle two elements per iteration
    for(; (pos+1) < n_elem; pos += 2)
      {
      const eT first  = src[pos  ];
      const eT second = src[pos+1];
      
      dest[pos  ] = val * first;
      dest[pos+1] = val * second;
      }
    
    // leftover element when the number of elements is odd
    if(pos < n_elem)
      {
      dest[pos] = val * src[pos];
      }
    
    return;
    }
  
  if( (n_rows == n_cols) && (n_rows <= 4) )
    {
    // square matrices of size up to 4x4 are delegated to a dedicated routine
    op_strans2::apply_noalias_tinysq(out, A, val);
    
    return;
    }
  
  // general case: column 'col' of A, scaled by val, becomes row 'col' of out
  for(uword col=0; col < n_cols; ++col)
    {
    const eT* src = A.colptr(col);
    
    uword row = 0;
    
    // handle two elements per iteration
    for(; (row+1) < n_rows; row += 2)
      {
      const eT first  = src[row  ];
      const eT second = src[row+1];
      
      out.at(col, row  ) = val * first;
      out.at(col, row+1) = val * second;
      }
    
    // leftover element when the number of rows is odd
    if(row < n_rows)
      {
      out.at(col, row) = val * src[row];
      }
    }
  }