示例#1
0
/*
 * Non-static entry point that forwards every argument, unchanged and in the
 * same order, to NBmm_b1.
 */
void Mjoin(PATL,CNBmm_b1)
   (const int M, const int N, const int K, const TYPE alpha,
    const TYPE *A, const int lda, const TYPE *B, const int ldb,
    const TYPE beta, TYPE *C, const int ldc)
{
   NBmm_b1(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
示例#2
0
/*
 * Picks the multiply kernel that matches the shape of the (M x N x K) block
 * and forwards all arguments to it.
 *
 * BETA is known to be 0 or 1: the *_b1 kernels are selected when beta
 * compares equal to ATL_rone, the *_b0 kernels otherwise.  The generic
 * pKBmm kernel is used when more than one dimension (or K together with a
 * partial M or N) differs from the full block size.
 */
static void ATL_gNBmm(const int M, const int N, const int K, SCALAR alpha,
                      const TYPE *A, const int lda, const TYPE *B,
                      const int ldb, const SCALAR beta, TYPE *C,
                      const int ldc)
{
   const int fullM = (M == MB), fullN = (N == NB), fullK = (K == KB);
   const int betaIsOne = (beta == ATL_rone);

   if (!fullM)  /* partial row block */
   {
      if (!(fullN && fullK))
      {
         Mjoin(PATL,pKBmm)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      else if (betaIsOne)
      {
         Mjoin(PATL,pMBmm_b1)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      else
      {
         Mjoin(PATL,pMBmm_b0)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      return;
   }

   if (!fullN)  /* M is full, partial column block */
   {
      if (!fullK)
      {
         Mjoin(PATL,pKBmm)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      else if (betaIsOne)
      {
         Mjoin(PATL,pNBmm_b1)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      else
      {
         Mjoin(PATL,pNBmm_b0)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      return;
   }

   if (!fullK)  /* M and N are full, only K is partial */
   {
      if (betaIsOne)
      {
         Mjoin(PATL,pKBmm_b1)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      else
      {
         Mjoin(PATL,pKBmm_b0)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      return;
   }

   /* every dimension matches the full block size */
   if (betaIsOne)
   {
      NBmm_b1(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
   }
   else
   {
      NBmm_b0(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
   }
}
示例#3
0
/*
 * Performs one block update as four kernel calls on two-part operands.
 *
 * ALPHA is known to be 1 (handled by copy)
 * BETA is known to be 1; we handle actual BETA in putblk phase
 *
 * The alpha and beta arguments are therefore not read; every kernel call
 * passes ATL_rone as alpha and either ATL_rnone or ATL_rone as beta.
 * Each operand is addressed as a first part (A, B, C) and a second part at
 * a fixed offset (A_hi, B_hi, C_hi).  In call order:
 *    C_hi = A    * B    - C_hi
 *    C    = A    * B_hi + C
 *    C_hi = A_hi * B_hi - C_hi
 *    C    = A_hi * B    + C
 * NOTE(review): the two parts are presumably the separated components of a
 * complex block -- confirm against the copy routines.
 */
static void ATL_gNBmm_b1
   (const int M, const int N, const int K, const TYPE alpha,
    const TYPE *A, const int lda, const TYPE *B, const int ldb,
    const TYPE beta, TYPE *C, const int ldc)
{
   const int fullM = (M == MB), fullN = (N == NB), fullK = (K == KB);

   if (fullM && fullN && fullK)
   {
      /* full block: second parts sit NBNB elements past the first */
      const TYPE *A_hi = A + NBNB;
      const TYPE *B_hi = B + NBNB;
      TYPE *C_hi = C + NBNB;

      NBmm_bX(M, N, K, ATL_rone, A, lda, B, ldb, ATL_rnone, C_hi, ldc);
      NBmm_b1(M, N, K, ATL_rone, A, lda, B_hi, ldb, ATL_rone, C, ldc);
      NBmm_bX(M, N, K, ATL_rone, A_hi, lda, B_hi, ldb, ATL_rnone, C_hi, ldc);
      NBmm_b1(M, N, K, ATL_rone, A_hi, lda, B, ldb, ATL_rone, C, ldc);
   }
   else
   {
      /* at least one partial dimension: A and C offsets use actual sizes */
      const TYPE *A_hi = A + M*K;
      TYPE *C_hi = C + M*N;

      if (fullN && fullK)  /* only M is partial */
      {
         const TYPE *B_hi = B + NBNB;

         Mjoin(PATLU,pMBmm_bX)(M, N, K, ATL_rone, A, lda, B, ldb,
                               ATL_rnone, C_hi, ldc);
         Mjoin(PATLU,pMBmm_b1)(M, N, K, ATL_rone, A, lda, B_hi, ldb,
                               ATL_rone, C, ldc);
         Mjoin(PATLU,pMBmm_bX)(M, N, K, ATL_rone, A_hi, lda, B_hi, ldb,
                               ATL_rnone, C_hi, ldc);
         Mjoin(PATLU,pMBmm_b1)(M, N, K, ATL_rone, A_hi, lda, B, ldb,
                               ATL_rone, C, ldc);
      }
      else
      {
         const TYPE *B_hi = B + N*K;

         if (fullM && fullK)  /* only N is partial */
         {
            Mjoin(PATLU,pNBmm_bX)(M, N, K, ATL_rone, A, lda, B, ldb,
                                  ATL_rnone, C_hi, ldc);
            Mjoin(PATLU,pNBmm_b1)(M, N, K, ATL_rone, A, lda, B_hi, ldb,
                                  ATL_rone, C, ldc);
            Mjoin(PATLU,pNBmm_bX)(M, N, K, ATL_rone, A_hi, lda, B_hi, ldb,
                                  ATL_rnone, C_hi, ldc);
            Mjoin(PATLU,pNBmm_b1)(M, N, K, ATL_rone, A_hi, lda, B, ldb,
                                  ATL_rone, C, ldc);
         }
         else if (fullM && fullN)  /* only K is partial */
         {
            Mjoin(PATLU,pKBmm_bX)(M, N, K, ATL_rone, A, lda, B, ldb,
                                  ATL_rnone, C_hi, ldc);
            Mjoin(PATLU,pKBmm_b1)(M, N, K, ATL_rone, A, lda, B_hi, ldb,
                                  ATL_rone, C, ldc);
            Mjoin(PATLU,pKBmm_bX)(M, N, K, ATL_rone, A_hi, lda, B_hi, ldb,
                                  ATL_rnone, C_hi, ldc);
            Mjoin(PATLU,pKBmm_b1)(M, N, K, ATL_rone, A_hi, lda, B, ldb,
                                  ATL_rone, C, ldc);
         }
         else  /* two or more partial dimensions: generic kernel */
         {
            Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A, lda, B, ldb,
                               ATL_rnone, C_hi, ldc);
            Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A, lda, B_hi, ldb,
                               ATL_rone, C, ldc);
            Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A_hi, lda, B_hi, ldb,
                               ATL_rnone, C_hi, ldc);
            Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A_hi, lda, B, ldb,
                               ATL_rone, C, ldc);
         }
      }
   }
}
示例#4
0
/*
 * Blocked multiply driver with the column-panel (J) loop outermost and the
 * row-panel (I) loop inside it; the K dimension is walked innermost.
 *
 * K             : passed to B2blk and to the IBNBmm/NBJBmm/IBJBmm cleanup
 *                 routines; also used to form the panel stride incK
 * nMb, nNb, nKb : number of full row panels, full column panels and full
 *                 K blocks that are looped over
 * ib, jb, kb    : sizes handed to the cleanup kernels for the partial row
 *                 panel, partial column panel and partial K block; each
 *                 cleanup step is skipped when its size is 0
 * alpha         : forwarded to B2blk only; kernels are called with ATL_rone
 * pA0           : start of the A workspace; pA is reset to it after every
 *                 column panel
 * B, ldb, incB  : if B is non-NULL, B2blk(K, NB|jb, B, ldb, pB, alpha) fills
 *                 pB for each column panel and B advances by incB; if B is
 *                 NULL no copy is made and pB0 advances by incK per panel
 *                 (presumably B was copied by the caller -- confirm)
 * pB0           : start of the B workspace
 * B2blk         : copy routine invoked as described above
 * beta          : dereferenced as *beta when gescal is NULL; otherwise
 *                 passed to gescal
 * C, ldc        : output matrix, updated in place
 * gescal        : if non-NULL, called on each C block with beta before the
 *                 multiply, and the kernels then receive ATL_rone as beta
 * NBmm0         : kernel used for the first full K block of each C block
 */
void Mjoin(PATL,mmJIK2)
(int K, int nMb, int nNb, int nKb, int ib, int jb, int kb,
 const SCALAR alpha, const TYPE *pA0, const TYPE *B, int ldb,
 TYPE *pB0, int incB, MAT2BLK B2blk, const SCALAR beta,
 TYPE *C, int ldc, MATSCAL gescal, NBMM0 NBmm0)
{
    /* incK: step between consecutive panels in a workspace;
       incC: step applied to C after finishing one column panel */
    const int incK = ATL_MulByNB(K)SHIFT, incC = ATL_MulByNB(ldc-nMb) SHIFT;
    /* true when nobody pre-scales C and beta is zero */
    const int ZEROC = ((gescal == NULL) && SCALAR_IS_ZERO(beta));
    int i, j = nNb;
    const TYPE *pA=pA0;
    /* beta seen by the kernels: 1 if gescal already applied the real beta */
    const TYPE rbeta = ( (gescal) ? ATL_rone : *beta );
    /* stB marks where pB stops in the loop over full K blocks */
    TYPE *pB=pB0, *stB=pB0+(ATL_MulByNBNB(nKb)SHIFT);

    if (nNb)
    {
        do  /* Loop over full column panels of B */
        {
            if (B)
            {
                B2blk(K, NB, B, ldb, pB, alpha);
                B += incB;
            }
            if (nMb)
            {
                i = nMb;
                do /* loop over full row panels of A */
                {
                    if (gescal) gescal(NB, NB, beta, C, ldc);
                    if (nKb) /* loop over full blocks in panels */
                    {
                        /* first K block applies rbeta; the rest accumulate */
                        NBmm0(MB, NB, KB, ATL_rone, pA, KB, pB, KB, rbeta, C, ldc);
                        pA += NBNB2;
                        pB += NBNB2;
                        if (nKb != 1)
                        {
                            do
                            {
                                NBmm_b1(MB, NB, KB, ATL_rone, pA, KB, pB, KB, ATL_rone,
                                        C, ldc);
                                pA += NBNB2;
                                pB += NBNB2;
                            }
                            while (pB != stB);
                        }
                        if (kb)
                        {
                            /* partial K block, accumulated with beta = 1 */
                            KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, ATL_rone,
                                 C, ldc);
                            pA += ATL_MulByNB(kb)<<1;
                        }
                    }
                    else if (kb)
                    {
                        /* no full K block: the partial block is the only
                           contribution, so it is the one that applies rbeta */
                        if (ZEROC) Mjoin(PATL,gezero)(MB, NB, C, ldc);
                        KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, rbeta, C, ldc);
                        pA += ATL_MulByNB(kb)<<1;
                    }
                    /* rewind B panel; pA keeps advancing to the next row panel */
                    pB = pB0;
                    C += NB2;
                }
                while (--i);
            }
            if (ib)
            {
                /* partial row panel against the current full column panel */
                if (gescal) gescal(ib, NB, beta, C, ldc);
                IBNBmm(ib, K, pA, pB, rbeta, C, ldc);
            }
            if (!B)
            {
                /* no per-panel copy: step the workspace to the next panel */
                pB0 += incK;
                pB = pB0;
                stB += incK;
            }
            C += incC;
            pA = pA0;
        }
        while (--j);
    }
    if (jb)
    {
        /* partial column panel: cleanup kernels over every row panel */
        if (B) B2blk(K, jb, B, ldb, pB, alpha);
        for (i=nMb; i; i--)
        {
            if (gescal) gescal(NB, jb, beta, C, ldc);
            NBJBmm(jb, K, pA, pB, rbeta, C, ldc);
            pA += incK;
            C += NB2;
        }
        if (ib)
        {
            /* corner block: partial in both M and N */
            if (gescal) gescal(ib, jb, beta, C, ldc);
            IBJBmm(ib, jb, K, pA, pB, rbeta, C, ldc);
        }
    }
}
示例#5
0
/*
 * Blocked multiply driver with the row-panel (I) loop outermost and the
 * column-panel (J) loop inside it; the K dimension is walked innermost.
 *
 * K             : passed to A2blk and to the MBJBmm/IBNBmm/IBJBmm cleanup
 *                 routines; also used to form the panel stride incK
 * nMb, nNb, nKb : number of full row panels, full column panels and full
 *                 K blocks that are looped over
 * ib, jb, kb    : sizes handed to the cleanup kernels for the partial row
 *                 panel, partial column panel and partial K block; each
 *                 cleanup step is skipped when its size is 0
 * alpha         : forwarded to A2blk only; kernels are called with ATL_rone
 * A, lda, incA  : if A is non-NULL, A2blk(K, NB|ib, A, lda, pA, alpha) fills
 *                 pA for each row panel and A advances by incA; if A is NULL
 *                 no copy is made and pA0 advances by incK per panel
 *                 (presumably A was copied by the caller -- confirm)
 * pA0           : start of the A workspace
 * A2blk         : copy routine invoked as described above
 * pB0           : start of the B workspace; pB is reset to it after every
 *                 row panel
 * beta          : dereferenced as *beta when gescal is NULL; otherwise
 *                 passed to gescal
 * C, ldc        : output matrix, updated in place
 * gescal        : if non-NULL, called on each C block with beta before the
 *                 multiply, and the kernels then receive ATL_rone as beta
 * NBmm0         : kernel used for the first full K block of each C block
 */
void Mjoin(PATL,mmIJK2)
   (int K, int nMb, int nNb, int nKb, int ib, int jb, int kb,
    const SCALAR alpha, const TYPE *A, const int lda, TYPE *pA0, const int incA,
    MAT2BLK A2blk, TYPE *pB0, const SCALAR beta, TYPE *C, int ldc,
    MATSCAL gescal, NBMM0 NBmm0)
{
   /* incK: step between consecutive panels in a workspace */
   const int incK = ATL_MulByNB(K)<<1;
   /* incCn: step of C to the next column panel;
      incCm: step of C from the end of one row panel to the start of the next */
   const int incCn = ATL_MulByNB(ldc)<<1, incCm = (MB<<1) - nNb*incCn;
   /* true when nobody pre-scales C and beta is zero */
   const int ZEROC = ((gescal == NULL) && SCALAR_IS_ZERO(beta));
   int i, j, k;
   const TYPE *pB=pB0;
   /* beta seen by the kernels: 1 if gescal already applied the real beta */
   const TYPE rbeta = ( (gescal) ? ATL_rone : *beta );
   TYPE *pA=pA0;

   for (i=nMb; i; i--)
   {
      if (A)
      {
         A2blk(K, NB, A, lda, pA, alpha);  /* get 1 row panel of A */
         A += incA;
      }
      for (j=nNb; j; j--)
      {
         if (gescal) gescal(MB, NB, beta, C, ldc);
         if (nKb)
         {
            /* first K block applies rbeta; the remaining ones accumulate */
            NBmm0(MB, NB, KB, ATL_rone, pA, KB, pB, KB, rbeta, C, ldc);
            pA += NBNB2;
            pB += NBNB2;
            if (nKb != 1)
            {
               for (k=nKb-1; k; k--, pA += NBNB2, pB += NBNB2)
                  NBmm_b1(MB, NB, KB, ATL_rone, pA, KB, pB, KB,
                          ATL_rone, C, ldc);
            }
            if (kb)
            {
               /* partial K block, accumulated with beta = 1 */
               KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, ATL_rone, C, ldc);
               pB += ATL_MulByNB(kb)<<1;
            }
         }
         else
         {
            /* no full K block: C is zeroed here whenever ZEROC holds, even
               if kb is 0 and no kernel follows */
            if (ZEROC) Mjoin(PATL,gezero)(MB, NB, C, ldc);
            if (kb)
            {
               KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, rbeta, C, ldc);
               pB += ATL_MulByNB(kb)<<1;
            }
         }
         /* rewind A panel; pB keeps advancing to the next column panel */
         pA = pA0;
         C += incCn;
      }
      if (jb)
      {
         /* current full row panel against the partial column panel */
         if (gescal) gescal(MB, jb, beta, C, ldc);
         MBJBmm(jb, K, pA, pB, rbeta, C, ldc);
      }
      pB = pB0;
      if (!A)
      {
         /* no per-panel copy: step the workspace to the next row panel */
         pA0 += incK;
         pA = pA0;
      }
      C += incCm;
   }
   if (ib)
   {
      if (A) A2blk(K, ib, A, lda, pA, alpha);   /* get last row panel of A */
      for(j=nNb; j; j--) /* full column panels of B */
      {
         if (gescal) gescal(ib, NB, beta, C, ldc);
         IBNBmm(ib, K, pA, pB, rbeta, C, ldc);
         pB += incK;
         C += incCn;
      }
      if (jb)
      {
         /* corner block: partial in both M and N */
         if (gescal) gescal(ib, jb, beta, C, ldc);
         IBJBmm(ib, jb, K, pA, pB, rbeta, C, ldc);
      }
   }
}