Example #1
0
int Mjoin(PATL,NCmmJIK_c)
   (const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
    const int M, const int N, const int K, const SCALAR alpha,
    const TYPE *A, const int lda, const TYPE *B, const int ldb,
    const SCALAR beta, TYPE *C, const int ldc)
/*
 * JIK loop-ordered matmul with no matrix copy, using a temporary C block.
 *
 * Each MB x NB block of op(A)*op(B) is accumulated with alpha = 1 into an
 * aligned, MB-strided heap temporary (cp).  geadd then folds cp into the
 * matching block of C using the caller's alpha and beta, so alpha and beta
 * are applied once per block rather than inside every K-step kernel call.
 *
 * Loop order: J (NB-column blocks of C) outermost, I (MB-row blocks) inner,
 * K innermost in KB steps plus a kr-deep cleanup call.
 *
 * Remainders:
 *   - the mr leftover rows of the first N-nr columns are delegated to
 *     NCmmIJK (its return value is checked with ATL_assert);
 *   - the nr leftover columns, including the mr x nr corner, are computed
 *     here with the fixed-K cleanup kernels.
 *
 * Returns 0.  Allocation failure and a non-zero NCmmIJK result are reported
 * through ATL_assert.
 */
{
   const int Mb = M / MB, Nb = N / NB, Kb = K / KB;
   const int mr = M - Mb*MB, nr = N - Nb*NB, kr = K - Kb*KB;
   int incAk, incAm, incAn, incBk, incBm, incBn;
   /*
    * c advances MB rows per I iteration.  After a full I sweep it sits Mb*MB
    * rows into the current column block, and incCn moves it to the top of the
    * next NB-column block.  incCm is a const local rather than a #define so
    * the name does not leak into the rest of the translation unit.
    */
   const int incCm = MB;
   const int incCn = ldc*NB - M + mr;
   int i, j, k;
   const TYPE *a=A, *b=B;
   TYPE *c=C;
   void *vp;
   TYPE *cp;
   void (*geadd)(const int M, const int N, const SCALAR scalar, const TYPE *A,
                  const int lda, const SCALAR beta, TYPE *C, const int ldc);
   void (*mm_bX)(const int M, const int N, const int K, const SCALAR alpha,
                 const TYPE *A, const int lda, const TYPE *B, const int ldb,
                 const SCALAR beta, TYPE *C, const int ldc);
   void (*mm_b1)(const int M, const int N, const int K, const SCALAR alpha,
                 const TYPE *A, const int lda, const TYPE *B, const int ldb,
                 const SCALAR beta, TYPE *C, const int ldc);
   void (*mmcu) (const int M, const int N, const int K, const SCALAR alpha,
                 const TYPE *A, const int lda, const TYPE *B, const int ldb,
                 const SCALAR beta, TYPE *C, const int ldc);
   void (*mm_fixedKcu)(const int M, const int N, const int K,
                       const SCALAR alpha, const TYPE *A, const int lda,
                       const TYPE *B, const int ldb, const
                       SCALAR beta, TYPE *C, const int ldc);

/*
 * Select kernels for this TA/TB combination.  By use below:
 *   mm_bX       - first full KB step of a block, called with beta = 0;
 *   mm_b1       - later full KB steps, called with beta = 1;
 *   mmcu        - the kr-deep partial K step (also used for partial M/N);
 *   mm_fixedKcu - full KB steps on partial (mr and/or nr) blocks.
 * The A increments walk op(A): incAk steps one KB slice along K, incAm
 * rewinds that K walk and moves to the next MB-row block, and incAn rewinds
 * all Mb row blocks after an I sweep.
 */
   if (TA == AtlasNoTrans)
   {
      if (TB == AtlasNoTrans)
      {
         mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NN),0x0x0),_a1_b0);
         mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,NN),0x0x0),_a1_b1);
         mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),NN),0x0x0_aX_bX);
         mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),NN),0x0x0_aX_bX);
      }
      else
      {
         mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NT),0x0x0),_a1_b0);
         mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,NT),0x0x0),_a1_b1);
         mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),NT),0x0x0_aX_bX);
         mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),NT),0x0x0_aX_bX);
      }
      incAk = lda * KB;
      incAm = MB - Kb * incAk;
      incAn = -Mb * MB;
   }
   else
   {
      if (TB == AtlasNoTrans)
      {
         mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TN),0x0x0),_a1_b0);
         mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,TN),0x0x0),_a1_b1);
         mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),TN),0x0x0_aX_bX);
         mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),TN),0x0x0_aX_bX);
      }
      else
      {
         mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TT),0x0x0),_a1_b0);
         mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,TT),0x0x0),_a1_b1);
         mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),TT),0x0x0_aX_bX);
         mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),TT),0x0x0_aX_bX);
      }
      incAk = KB;
      incAm = lda*MB - Kb*KB;
      incAn = -lda*MB*Mb;
   }
/*
 * B increments: incBk steps one KB slice along K, incBm rewinds the K walk
 * so every row block of C reuses the same B panel, and incBn moves to the
 * next NB-column panel.
 */
   if (TB == AtlasNoTrans)
   {
      incBk = KB;
      incBm = -KB*Kb;
      incBn = ldb*NB;
   }
   else
   {
      incBk = KB*ldb;
      incBm = -Kb * incBk;
      incBn = NB;
   }

/*
 * geadd folds the temporary block into C.  Pick the variant specialized
 * for alpha in {1, other} and beta in {0, 1, other}.
 */
   if (alpha == ATL_rone)
   {
      if (beta == ATL_rzero) geadd = Mjoin(Mjoin(Mjoin(PATL,geadd),_a1),_b0);
      else if (beta == ATL_rone)
         geadd = Mjoin(Mjoin(Mjoin(PATL,geadd),_a1),_b1);
      else geadd = Mjoin(Mjoin(Mjoin(PATL,geadd),_a1),_bX);
   }
   else if (beta == ATL_rzero)
      geadd = Mjoin(Mjoin(Mjoin(PATL,geadd),_aX),_b0);
   else if (beta == ATL_rone)
      geadd = Mjoin(Mjoin(Mjoin(PATL,geadd),_aX),_b1);
   else geadd = Mjoin(Mjoin(Mjoin(PATL,geadd),_aX),_bX);
   vp = malloc(ATL_Cachelen + ATL_MulBySize(MB * NB));
   ATL_assert(vp);
   cp = ATL_AlignPtr(vp);
/*
 * Zero cp whenever some path may read it before fully writing it.  This
 * includes K == 0 (Kb == 0 && kr == 0): no kernel writes cp in the main
 * loop then, and geadd would otherwise read uninitialized heap memory.
 */
   if (mr || nr || kr || !Kb)
      for (j=MB*NB, i=0; i != j; i++) cp[i] = ATL_rzero;

/*
 * Full MB x NB blocks: J over column blocks, I over row blocks.  Each
 * block's K walk is rewound by incAm/incBm, and incAn rewinds a after an
 * I sweep.
 */
   for (j=Nb; j; j--, a += incAn, b += incBn, c += incCn)
   {
      for (i=Mb; i; i--, a += incAm, b += incBm, c += incCm)
      {
         if (Kb)
         {
            mm_bX(MB, NB, KB, ATL_rone, a, lda, b, ldb, ATL_rzero, cp, MB);
            a += incAk;  b += incBk;
            for (k=Kb-1; k; k--, a += incAk, b += incBk)
               mm_b1(MB, NB, KB, ATL_rone, a, lda, b, ldb, ATL_rone, cp, MB);
            if (kr)
               mmcu(MB, NB, kr, ATL_rone, a, lda, b, ldb, ATL_rone, cp, MB);
         }
         else if (kr)
         {
            Mjoin(PATL,zero)(MB*NB, cp, 1); /* kill NaN/INF from last time */
            mmcu(MB, NB, kr, ATL_rone, a, lda, b, ldb, ATL_rzero, cp, MB);
         }
         geadd(MB, NB, alpha, cp, MB, beta, c, ldc);
      }
   }
/* mr leftover rows of the first N-nr columns: delegate to the IJK driver */
   if (mr && N != nr)
      ATL_assert(Mjoin(PATL,NCmmIJK)(TA, TB, mr, N-nr, K, alpha,
                                     A+Mb*(incAm+Kb*incAk), lda, B, ldb,
                                     beta, C+Mb*MB, ldc) ==0);
   if (nr)
   {
/*
 * After the J loop, a is back at A while b and c sit at the start of the
 * nr-wide trailing column block.  Do its full MB-row blocks here.
 */
      for (i=Mb; i; i--, a += incAm, b += incBm, c += incCm)
      {
         Mjoin(PATL,zero)(MB*nr, cp, 1); /* kill NaN and INF from last time */
         if (Kb)
         {
            mm_fixedKcu(MB, nr, KB, ATL_rone, a, lda, b, ldb, ATL_rzero,
                        cp, MB);
            a += incAk;  b += incBk;
            for (k=Kb-1; k; k--, a += incAk, b += incBk)
               mm_fixedKcu(MB, nr, KB, ATL_rone, a, lda, b, ldb, ATL_rone,
                           cp, MB);
            if (kr)
               mmcu(MB, nr, kr, ATL_rone, a, lda, b, ldb, ATL_rone, cp, MB);
         }
         else if (kr)
            mmcu(MB, nr, kr, ATL_rone, a, lda, b, ldb, ATL_rzero, cp, MB);
         geadd(MB, nr, alpha, cp, MB, beta, c, ldc);
      }
      if (mr)  /* cleanup small mr x nr block of C */
      {
         c = C + Mb*MB + ldc*Nb*NB;
         a = A + Mb*(incAm+Kb*incAk);
         b = B + Nb*( incBn+(Mb*(incBm+Kb*incBk)) );
         Mjoin(PATL,zero)(MB*nr, cp, 1); /* kill NaN and INF from last time */
         if (Kb)
         {
            mm_fixedKcu(mr, nr, KB, ATL_rone, a, lda, b, ldb, ATL_rzero,
                        cp, MB);
            a += incAk;  b += incBk;
            for (k=Kb-1; k; k--, a += incAk, b += incBk)
               mm_fixedKcu(mr, nr, KB, ATL_rone, a, lda, b, ldb, ATL_rone,
                           cp, MB);
            if (kr)
               mmcu(mr, nr, kr, ATL_rone, a, lda, b, ldb, ATL_rone, cp, MB);
         }
         else if (kr)
            mmcu(mr, nr, kr, ATL_rone, a, lda, b, ldb, ATL_rzero, cp, MB);
         geadd(mr, nr, alpha, cp, MB, beta, c, ldc);
      }
   }
   free(vp);
   return(0);
}
Example #2
0
int Mjoin(PATL,NCmmIJK)
(const enum ATLAS_TRANS TA, const enum ATLAS_TRANS TB,
 const int M, const int N, const int K, const SCALAR alpha,
 const TYPE *A, const int lda, const TYPE *B, const int ldb,
 const SCALAR beta, TYPE *C, const int ldc)
/*
 * IJK loop-ordered matmul with no matrix copy
 *
 * Calls the NCmm kernels directly on the caller's A, B and C (no copy and
 * no temporary C).  For each C block, the first KB step passes the caller's
 * beta and every later K step passes beta = 1, so beta is applied once and
 * the alpha-scaled products of the K slices accumulate into C.
 *
 * Loop order: I (MB-row blocks of C) outermost, J (NB-column blocks) inner,
 * K innermost in KB steps plus a kr-deep cleanup call.
 *
 * Remainders:
 *   - the mr leftover rows of the first N-nr columns are handled here with
 *     the fixed-K cleanup kernels;
 *   - the nr leftover columns (all M rows) are delegated to NCmmJIK, and its
 *     return value is checked with ATL_assert.
 *
 * If alpha != 1 and either Kb >= ATL_MaxMMalpha or |alpha| < max(|beta|, 1),
 * the whole problem is forwarded to NCmmIJK_c, which per the comment below
 * uses a temporary C to handle alpha and beta safely.
 *
 * Returns 0, or NCmmIJK_c's return value when forwarded.
 *
 * NOTE(review): when K == 0 (Kb == 0 && kr == 0) the main loop calls no
 * kernel, so full MB x NB blocks of C are not scaled by beta.  Presumably
 * callers never pass K == 0 -- confirm.
 */
{
    const int Mb = M / MB, Nb = N / NB, Kb = K / KB;
    const int mr = M - Mb*MB, nr = N - Nb*NB, kr = K - Kb*KB;
    int incAk, incAm, incAn, incBk, incBm, incBn;
    /* c steps NB columns per J iteration; incCm returns to column 0 and
     * moves down MB rows after a full J sweep */
    const int incCn = ldc*NB, incCm = MB - Nb * incCn;
    /* when beta == 0, C is explicitly zeroed before the aX_bX cleanup
     * kernels, which are called with beta and presumably read C */
    const int BetaIsZero = (beta == ATL_rzero);
    int i, j, k;
    const TYPE *a=A, *b=B;
    TYPE *c=C;
    TYPE btmp;
    void (*mm_bX)(const int M, const int N, const int K, const SCALAR alpha,
                  const TYPE *A, const int lda, const TYPE *B, const int ldb,
                  const SCALAR beta, TYPE *C, const int ldc);
    void (*mm_b1)(const int M, const int N, const int K, const SCALAR alpha,
                  const TYPE *A, const int lda, const TYPE *B, const int ldb,
                  const SCALAR beta, TYPE *C, const int ldc);
    void (*mmcu) (const int M, const int N, const int K, const SCALAR alpha,
                  const TYPE *A, const int lda, const TYPE *B, const int ldb,
                  const SCALAR beta, TYPE *C, const int ldc);
    void (*mm_fixedKcu)(const int M, const int N, const int K,
                        const SCALAR alpha, const TYPE *A, const int lda,
                        const TYPE *B, const int ldb, const
                        SCALAR beta, TYPE *C, const int ldc);

    /*
     * Cleanup kernels for this TA/TB combination: mm_fixedKcu does full KB
     * steps on the mr-row remainder, and mmcu does the kr-deep partial K step.
     * A increments: incAk steps one KB slice along K, incAn rewinds the
     * Kb*incAk advance made while computing one C block, and incAm moves to
     * the next MB-row block of op(A).
     */
    if (TA == AtlasNoTrans)
    {
        if (TB == AtlasNoTrans)
        {
            mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),NN),0x0x0_aX_bX);
            mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),NN),0x0x0_aX_bX);
        }
        else
        {
            mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),NT),0x0x0_aX_bX);
            mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),NT),0x0x0_aX_bX);
        }
        incAk = lda * KB;
        incAn = -Kb * incAk;
        incAm = MB;
    }
    else
    {
        if (TB == AtlasNoTrans)
        {
            mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),TN),0x0x0_aX_bX);
            mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),TN),0x0x0_aX_bX);
        }
        else
        {
            mm_fixedKcu=Mjoin(Mjoin(Mjoin(NCmm00,Mjoin(0x0x,KB)),TT),0x0x0_aX_bX);
            mmcu = Mjoin(Mjoin(Mjoin(NCmm00,0x0x0),TT),0x0x0_aX_bX);
        }
        incAk = KB;
        incAn = -Kb * incAk;
        incAm = MB * lda;
    }
    /*
     * B increments: incBk steps one KB slice along K.  incBn undoes the
     * Kb*incBk advance and moves to the next NB-column panel.  incBm rewinds
     * all Nb panels so the next row block restarts at B.
     */
    if (TB == AtlasNoTrans)
    {
        incBk = KB;
        incBn = ldb*NB - K + kr;
        incBm = -Nb * ldb * NB;
    }
    else
    {
        incBk = KB*ldb;
        incBn = NB - Kb*incBk;
        incBm = -Nb*NB;
    }

    /*
     * Full-block kernels.  mm_b1 accumulates (called with beta = 1).  mm_bX
     * handles the first K step of each block with the caller's beta: the b1
     * kernel itself when beta == 1, a b0 variant when beta == 0, otherwise
     * a bX variant.
     */
    if (alpha == ATL_rone)
    {
        if (TA == AtlasNoTrans)
        {
            if (TB == AtlasNoTrans)
            {
                mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,NN),0x0x0),_a1_b1);
                if (beta == ATL_rone) mm_bX = mm_b1;
                else if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NN),0x0x0),_a1_b0);
                else mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NN),0x0x0),_a1_bX);
            }
            else
            {
                mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,NT),0x0x0),_a1_b1);
                if (beta == ATL_rone) mm_bX = mm_b1;
                else if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NT),0x0x0),_a1_b0);
                else mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NT),0x0x0),_a1_bX);
            }
        }
        else
        {
            if (TB == AtlasNoTrans)
            {
                mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,TN),0x0x0),_a1_b1);
                if (beta == ATL_rone) mm_bX = mm_b1;
                else if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TN),0x0x0),_a1_b0);
                else mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TN),0x0x0),_a1_bX);
            }
            else
            {
                mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,TT),0x0x0),_a1_b1);
                if (beta == ATL_rone) mm_bX = mm_b1;
                else if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TT),0x0x0),_a1_b0);
                else mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TT),0x0x0),_a1_bX);
            }
        }
    }
    else  /* non-one alpha */
    {
        /* btmp = max(|beta|, 1), used in the forwarding test below */
        btmp = Mabs(beta);
        if (btmp < ATL_rone) btmp = 1.0;
        /*
         *    If needed, call version that uses temp C to handle alpha & beta safely
         */
        if (Kb >= ATL_MaxMMalpha || Mabs(alpha) < btmp)
            return(Mjoin(PATL,NCmmIJK_c)(TA, TB, M, N, K, alpha, A, lda, B, ldb,
                                         beta, C, ldc));
        /* mm_b1 is the aX_bX kernel; mm_bX is the same except for beta == 0,
         * where an aX_b0 variant is used */
        if (TA == AtlasNoTrans)
        {
            if (TB == AtlasNoTrans)
            {
                mm_bX = mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,NN),0x0x0),_aX_bX);
                if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NN),0x0x0),_aX_b0);
            }
            else
            {
                mm_bX = mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,NT),0x0x0),_aX_bX);
                if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,NT),0x0x0),_aX_b0);
            }
        }
        else
        {
            if (TB == AtlasNoTrans)
            {
                mm_bX = mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,TN),0x0x0),_aX_bX);
                if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TN),0x0x0),_aX_b0);
            }
            else
            {
                mm_bX = mm_b1 = Mjoin(Mjoin(Mjoin(NCmm0,TT),0x0x0),_aX_bX);
                if (beta == ATL_rzero)
                    mm_bX = Mjoin(Mjoin(Mjoin(NCmm0,TT),0x0x0),_aX_b0);
            }
        }
    }

    /* Full MB x NB blocks: I over row blocks, J over column blocks */
    for (i=Mb; i; i--, a += incAm, b += incBm, c += incCm)
    {
        for (j=Nb; j; j--, a += incAn, b += incBn, c += incCn)
        {
            if (Kb)
            {
                /* first K step applies beta; later steps accumulate */
                mm_bX(MB, NB, KB, alpha, a, lda, b, ldb, beta, c, ldc);
                a += incAk;
                b += incBk;
                for (k=Kb-1; k; k--, a += incAk, b += incBk)
                    mm_b1(MB, NB, KB, alpha, a, lda, b, ldb, ATL_rone, c, ldc);
                if (kr)
                    mmcu(MB, NB, kr, alpha, a, lda, b, ldb, ATL_rone, c, ldc);
            }
            else if (kr)
            {
                /* only a partial K step: it must apply beta itself */
                if (BetaIsZero) Mjoin(PATL,gezero)(MB, NB, c, ldc);
                mmcu(MB, NB, kr, alpha, a, lda, b, ldb, beta, c, ldc);
            }
        }
    }
    if (mr)  /* M-loop remainder */
    {
        /* a now points at the mr leftover rows; b and c are back at
         * column 0 of their matrices */
        for (j=Nb; j; j--, a += incAn, b += incBn, c += incCn)
        {
            if (BetaIsZero) Mjoin(PATL,gezero)(mr, NB, c, ldc);
            if (Kb)
            {
                mm_fixedKcu(mr, NB, KB, alpha, a, lda, b, ldb, beta,
                            c, ldc);
                a += incAk;
                b += incBk;
                for (k=Kb-1; k; k--, a += incAk, b += incBk)
                    mm_fixedKcu(mr, NB, KB, alpha, a, lda, b, ldb, ATL_rone,
                                c, ldc);
                if (kr)
                    mmcu(mr, NB, kr, alpha, a, lda, b, ldb, ATL_rone, c, ldc);
            }
            else if (kr)
                mmcu(mr, NB, kr, alpha, a, lda, b, ldb, beta, c, ldc);
        }
    }
    /* N-loop remainder (all M rows of the nr trailing columns) via JIK */
    if (nr)
        ATL_assert(Mjoin(PATL,NCmmJIK)(TA, TB, M, nr, K, alpha, A, lda,
                                       B+Nb*(incBn+Kb*incBk), ldb,
                                       beta, C+Nb*NB*ldc, ldc) == 0);
    return(0);
}