/*
 * Externally visible (non-static) wrapper around NBmm_b1: every argument is
 * handed through unchanged and nothing else is done.
 */
void Mjoin(PATL,CNBmm_b1)(const int M, const int N, const int K,
                          const TYPE alpha, const TYPE *A, const int lda,
                          const TYPE *B, const int ldb, const TYPE beta,
                          TYPE *C, const int ldc)
{
   NBmm_b1(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
static void ATL_gNBmm(const int M, const int N, const int K, SCALAR alpha,
                      const TYPE *A, const int lda, const TYPE *B,
                      const int ldb, const SCALAR beta, TYPE *C,
                      const int ldc)
/*
 * Multiply one block, choosing the kernel specialised for its shape:
 *    M, N, K all full (MB, NB, KB)   -> NBmm_b1   / NBmm_b0
 *    only M partial                  -> pMBmm_b1  / pMBmm_b0
 *    only N partial                  -> pNBmm_b1  / pNBmm_b0
 *    only K partial                  -> pKBmm_b1  / pKBmm_b0
 *    two or more dimensions partial  -> pKBmm (general)
 * BETA is known to be 0 or 1: the _b1 variant is used when beta compares
 * equal to ATL_rone, the _b0 variant otherwise.  Every argument is forwarded
 * to the selected kernel unchanged.
 */
{
   const int fullM = (M == MB), fullN = (N == NB), fullK = (K == KB);
   const int betaIsOne = (beta == ATL_rone);

   if (!fullM)   /* partial row block */
   {
      if (fullN && fullK)
      {
         if (betaIsOne)
            Mjoin(PATL,pMBmm_b1)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
         else
            Mjoin(PATL,pMBmm_b0)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      }
      else
         Mjoin(PATL,pKBmm)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      return;
   }

   if (!fullN)   /* rows full, partial column block */
   {
      if (!fullK)
         Mjoin(PATL,pKBmm)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      else if (betaIsOne)
         Mjoin(PATL,pNBmm_b1)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      else
         Mjoin(PATL,pNBmm_b0)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      return;
   }

   /* rows and columns full; only K can still be partial */
   if (fullK)
   {
      if (betaIsOne)
         NBmm_b1(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
      else
         NBmm_b0(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
   }
   else if (betaIsOne)
      Mjoin(PATL,pKBmm_b1)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
   else
      Mjoin(PATL,pKBmm_b0)(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
static void ATL_gNBmm_b1 (const int M, const int N, const int K,
                          const TYPE alpha, const TYPE *A, const int lda,
                          const TYPE *B, const int ldb, const TYPE beta,
                          TYPE *C, const int ldc)
/*
 * Multiply-accumulate one block whose operands are each stored as two
 * consecutive real planes: plane 0 at A / B / C, plane 1 right after it.
 * Four real kernel calls are issued, always in this order:
 *    C1 = A0*B0 - C1
 *    C0 = A0*B1 + C0
 *    C1 = A1*B1 - C1
 *    C0 = A1*B0 + C0
 * so the net effect is C0 += A0*B1 + A1*B0 and C1 += A1*B1 - A0*B0, where
 * "X*Y" is whatever product the kernels form with alpha = ATL_rone.  That is
 * a complex multiply-accumulate if plane 0 holds imaginary parts and plane 1
 * real parts (NOTE(review): presumably the layout written by the block-copy
 * routines -- confirm).
 *
 * ALPHA is known to be 1 (handled by the copy) and BETA is known to be 1
 * (the actual beta is applied in the putblk phase), so neither the alpha nor
 * the beta argument is read here.
 *
 * Plane-1 offsets: NBNB for a fully-sized block; otherwise M*K for A, N*K
 * for B and M*N for C, except that when only M is partial B is a full block
 * and its second plane is taken at NBNB.
 */
{
   const TYPE *A_p1, *B_p1;   /* plane 1 of A and of B */
   TYPE *C_p1;                /* plane 1 of C */

   if (M == MB && N == NB && K == KB)   /* fully-sized block */
   {
      A_p1 = A + NBNB;
      B_p1 = B + NBNB;
      C_p1 = C + NBNB;
      NBmm_bX(M, N, K, ATL_rone, A, lda, B, ldb, ATL_rnone, C_p1, ldc);
      NBmm_b1(M, N, K, ATL_rone, A, lda, B_p1, ldb, ATL_rone, C, ldc);
      NBmm_bX(M, N, K, ATL_rone, A_p1, lda, B_p1, ldb, ATL_rnone, C_p1, ldc);
      NBmm_b1(M, N, K, ATL_rone, A_p1, lda, B, ldb, ATL_rone, C, ldc);
      return;
   }

   A_p1 = A + M*K;
   C_p1 = C + M*N;

   if (M != MB && N == NB && K == KB)   /* only M is partial */
   {
      B_p1 = B + NBNB;
      Mjoin(PATLU,pMBmm_bX)(M, N, K, ATL_rone, A, lda, B, ldb,
                            ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pMBmm_b1)(M, N, K, ATL_rone, A, lda, B_p1, ldb,
                            ATL_rone, C, ldc);
      Mjoin(PATLU,pMBmm_bX)(M, N, K, ATL_rone, A_p1, lda, B_p1, ldb,
                            ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pMBmm_b1)(M, N, K, ATL_rone, A_p1, lda, B, ldb,
                            ATL_rone, C, ldc);
   }
   else if (M == MB && N != NB && K == KB)   /* only N is partial */
   {
      B_p1 = B + N*K;
      Mjoin(PATLU,pNBmm_bX)(M, N, K, ATL_rone, A, lda, B, ldb,
                            ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pNBmm_b1)(M, N, K, ATL_rone, A, lda, B_p1, ldb,
                            ATL_rone, C, ldc);
      Mjoin(PATLU,pNBmm_bX)(M, N, K, ATL_rone, A_p1, lda, B_p1, ldb,
                            ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pNBmm_b1)(M, N, K, ATL_rone, A_p1, lda, B, ldb,
                            ATL_rone, C, ldc);
   }
   else if (M == MB && N == NB)   /* only K is partial */
   {
      B_p1 = B + N*K;
      Mjoin(PATLU,pKBmm_bX)(M, N, K, ATL_rone, A, lda, B, ldb,
                            ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pKBmm_b1)(M, N, K, ATL_rone, A, lda, B_p1, ldb,
                            ATL_rone, C, ldc);
      Mjoin(PATLU,pKBmm_bX)(M, N, K, ATL_rone, A_p1, lda, B_p1, ldb,
                            ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pKBmm_b1)(M, N, K, ATL_rone, A_p1, lda, B, ldb,
                            ATL_rone, C, ldc);
   }
   else   /* two or more dimensions are partial: general kernel */
   {
      B_p1 = B + N*K;
      Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A, lda, B, ldb,
                         ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A, lda, B_p1, ldb,
                         ATL_rone, C, ldc);
      Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A_p1, lda, B_p1, ldb,
                         ATL_rnone, C_p1, ldc);
      Mjoin(PATLU,pKBmm)(M, N, K, ATL_rone, A_p1, lda, B, ldb,
                         ATL_rone, C, ldc);
   }
}
/*
 * Mjoin(PATL,mmJIK2): blocked matrix-multiply driver.  The outermost loop
 * (j) walks column panels of B and C, the middle loop (i) walks row panels
 * of A / row blocks of C, and the innermost loop walks the K blocks.
 *
 * Parameters, as used by the code below:
 *   K              full K extent; incK = ATL_MulByNB(K)SHIFT is the distance
 *                  between consecutive copied panels.
 *   nMb, nNb, nKb  number of full MB / NB / KB blocks in M, N and K.
 *   ib, jb, kb     sizes of the trailing partial block in M, N and K
 *                  (0 when there is none).
 *   alpha          only forwarded to B2blk (presumably applied during the
 *                  copy -- confirm against the B2blk implementations).
 *   pA0            A, already in block-copied form; walked, never written.
 *   B, ldb, incB   original B.  When B is non-NULL, each column panel is
 *                  copied into the pB0 buffer by B2blk and B advances by
 *                  incB.  When B is NULL, pB0 must already hold every copied
 *                  panel and is advanced by incK after each column panel.
 *   pB0            buffer receiving / holding the copied B panel(s).
 *   beta, gescal   if gescal is non-NULL it is called on every C block with
 *                  beta before that block is updated, and the kernels then
 *                  use rbeta = ATL_rone.  Otherwise the kernels receive
 *                  rbeta = *beta (the first component of beta; NOTE(review):
 *                  presumably beta is assumed real on that path -- confirm).
 *   C, ldc         output matrix, updated in place.
 *   NBmm0          kernel for the first full K block of each C block; it is
 *                  the full-K-block call that receives rbeta, later K blocks
 *                  accumulate with ATL_rone.
 *
 * NOTE(review): every increment carries a factor of two (SHIFT, <<1, NBNB2,
 * NB2), presumably because of complex storage; confirm against the macros.
 */
void Mjoin(PATL,mmJIK2) (int K, int nMb, int nNb, int nKb, int ib, int jb,
                         int kb, const SCALAR alpha, const TYPE *pA0,
                         const TYPE *B, int ldb, TYPE *pB0, int incB,
                         MAT2BLK B2blk, const SCALAR beta, TYPE *C, int ldc,
                         MATSCAL gescal, NBMM0 NBmm0)
{
   /*
    * incC: applied after the row blocks of one column panel; together with
    * the nMb steps of NB2 it is meant to bring C to the first row of the
    * next column panel (the ib block does not advance C).
    */
   const int incK = ATL_MulByNB(K)SHIFT, incC = ATL_MulByNB(ldc-nMb) SHIFT;
   /* no separate scaling pass and beta == 0: blocks that only see KBmm
    * are cleared with gezero first */
   const int ZEROC = ((gescal == NULL) && SCALAR_IS_ZERO(beta));
   int i, j = nNb;
   const TYPE *pA=pA0;
   /* if gescal already applied beta, the kernels just accumulate */
   const TYPE rbeta = ( (gescal) ? ATL_rone : *beta );
   /* stB: end of the nKb full K blocks of the current B panel */
   TYPE *pB=pB0, *stB=pB0+(ATL_MulByNBNB(nKb)SHIFT);

   if (nNb)
   {
      do /* Loop over full column panels of B */
      {
         if (B)   /* copy this column panel of B into pB */
         {
            B2blk(K, NB, B, ldb, pB, alpha);
            B += incB;
         }
         if (nMb)
         {
            i = nMb;
            do /* loop over full row panels of A */
            {
               if (gescal) gescal(NB, NB, beta, C, ldc);
               if (nKb) /* loop over full blocks in panels */
               {
                  NBmm0(MB, NB, KB, ATL_rone, pA, KB, pB, KB, rbeta, C, ldc);
                  pA += NBNB2;
                  pB += NBNB2;
                  if (nKb != 1)
                  {
                     do
                     {
                        NBmm_b1(MB, NB, KB, ATL_rone, pA, KB, pB, KB,
                                ATL_rone, C, ldc);
                        pA += NBNB2;
                        pB += NBNB2;
                     }
                     while (pB != stB);   /* stop after the last full K block */
                  }
                  if (kb)   /* trailing partial K block */
                  {
                     KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, ATL_rone,
                          C, ldc);
                     pA += ATL_MulByNB(kb)<<1;
                  }
               }
               /*
                * NOTE(review): when nKb == 0 and kb == 0 this full C block
                * gets neither a kernel call nor gezero (mmIJK2 still calls
                * gezero in that case); presumably K == 0 never reaches here.
                */
               else if (kb)   /* K consists of a partial block only */
               {
                  if (ZEROC) Mjoin(PATL,gezero)(MB, NB, C, ldc);
                  KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, rbeta, C, ldc);
                  pA += ATL_MulByNB(kb)<<1;
               }
               /* pA keeps advancing through the A row panels; the same B
                * panel is reused for the next one */
               pB = pB0;
               C += NB2;
            }
            while (--i);
         }
         if (ib)   /* trailing partial row panel of A */
         {
            if (gescal) gescal(ib, NB, beta, C, ldc);
            IBNBmm(ib, K, pA, pB, rbeta, C, ldc);
         }
         if (!B)   /* B was pre-copied: step to the next stored panel */
         {
            pB0 += incK;
            pB = pB0;
            stB += incK;
         }
         C += incC;
         pA = pA0;   /* restart A for the next column panel */
      }
      while (--j);
   }
   if (jb)   /* trailing partial column panel of B */
   {
      if (B) B2blk(K, jb, B, ldb, pB, alpha);
      for (i=nMb; i; i--)
      {
         if (gescal) gescal(NB, jb, beta, C, ldc);
         NBJBmm(jb, K, pA, pB, rbeta, C, ldc);
         pA += incK;
         C += NB2;
      }
      if (ib)   /* bottom-right corner block */
      {
         if (gescal) gescal(ib, jb, beta, C, ldc);
         IBJBmm(ib, jb, K, pA, pB, rbeta, C, ldc);
      }
   }
}
/*
 * Mjoin(PATL,mmIJK2): blocked matrix-multiply driver.  The outermost loop
 * (i) walks row panels of A and C, the middle loop (j) walks column panels
 * of B / column blocks of C, and the innermost loop walks the K blocks.
 *
 * Parameters, as used by the code below:
 *   K              full K extent; incK = ATL_MulByNB(K)<<1 is the distance
 *                  between consecutive copied panels.
 *   nMb, nNb, nKb  number of full MB / NB / KB blocks in M, N and K.
 *   ib, jb, kb     sizes of the trailing partial block in M, N and K
 *                  (0 when there is none).
 *   alpha          only forwarded to A2blk (presumably applied during the
 *                  copy -- confirm against the A2blk implementations).
 *   A, lda, incA   original A.  When A is non-NULL, each row panel is copied
 *                  into the pA0 buffer by A2blk and A advances by incA.  When
 *                  A is NULL, pA0 must already hold every copied panel and
 *                  is advanced by incK after each row panel.
 *   pA0            buffer receiving / holding the copied A panel(s).
 *   pB0            B, already in block-copied form; walked, never written.
 *                  Within one row panel pB is not rewound between column
 *                  blocks, so the copied B panels are consumed back to back.
 *   beta, gescal   if gescal is non-NULL it is called on every C block with
 *                  beta before that block is updated, and the kernels then
 *                  use rbeta = ATL_rone.  Otherwise the kernels receive
 *                  rbeta = *beta (the first component of beta; NOTE(review):
 *                  presumably beta is assumed real on that path -- confirm).
 *   C, ldc         output matrix, updated in place.
 *   NBmm0          kernel for the first full K block of each C block; it is
 *                  the full-K-block call that receives rbeta, later K blocks
 *                  accumulate with ATL_rone.
 *
 * NOTE(review): every increment carries a factor of two (<<1, NBNB2),
 * presumably because of complex storage; confirm against the macros.
 */
void Mjoin(PATL,mmIJK2) (int K, int nMb, int nNb, int nKb, int ib, int jb,
                         int kb, const SCALAR alpha, const TYPE *A,
                         const int lda, TYPE *pA0, const int incA,
                         MAT2BLK A2blk, TYPE *pB0, const SCALAR beta, TYPE *C,
                         int ldc, MATSCAL gescal, NBMM0 NBmm0)
{
   const int incK = ATL_MulByNB(K)<<1;
   /*
    * incCn: step to the next column block of C.
    * incCm: from where the j loop leaves C (the jb block does not advance
    *        it) back to column 0, MB rows further down.
    */
   const int incCn = ATL_MulByNB(ldc)<<1, incCm = (MB<<1) - nNb*incCn;
   /* no separate scaling pass and beta == 0: blocks without full K blocks
    * are cleared with gezero first */
   const int ZEROC = ((gescal == NULL) && SCALAR_IS_ZERO(beta));
   int i, j, k;
   const TYPE *pB=pB0;
   /* if gescal already applied beta, the kernels just accumulate */
   const TYPE rbeta = ( (gescal) ? ATL_rone : *beta );
   TYPE *pA=pA0;

   for (i=nMb; i; i--)   /* full row panels of A */
   {
      if (A)
      {
         A2blk(K, NB, A, lda, pA, alpha); /* get 1 row panel of A */
         A += incA;
      }
      for (j=nNb; j; j--)   /* full column panels of B */
      {
         if (gescal) gescal(MB, NB, beta, C, ldc);
         if (nKb)
         {
            NBmm0(MB, NB, KB, ATL_rone, pA, KB, pB, KB, rbeta, C, ldc);
            pA += NBNB2;
            pB += NBNB2;
            if (nKb != 1)
            {
               for (k=nKb-1; k; k--, pA += NBNB2, pB += NBNB2)
                  NBmm_b1(MB, NB, KB, ATL_rone, pA, KB, pB, KB, ATL_rone,
                          C, ldc);
            }
            if (kb)   /* trailing partial K block */
            {
               KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, ATL_rone, C, ldc);
               pB += ATL_MulByNB(kb)<<1;
            }
         }
         else   /* no full K blocks */
         {
            if (ZEROC) Mjoin(PATL,gezero)(MB, NB, C, ldc);
            if (kb)
            {
               KBmm(MB, NB, kb, ATL_rone, pA, kb, pB, kb, rbeta, C, ldc);
               pB += ATL_MulByNB(kb)<<1;
            }
         }
         pA = pA0;   /* reuse the same A row panel for the next column block */
         C += incCn;
      }
      if (jb)   /* trailing partial column panel of B */
      {
         if (gescal) gescal(MB, jb, beta, C, ldc);
         MBJBmm(jb, K, pA, pB, rbeta, C, ldc);
      }
      pB = pB0;   /* next row panel starts again at the first B panel */
      if (!A)     /* A was pre-copied: step to the next stored panel */
      {
         pA0 += incK;
         pA = pA0;
      }
      C += incCm;
   }
   if (ib)   /* trailing partial row panel of A */
   {
      if (A) A2blk(K, ib, A, lda, pA, alpha); /* get last row panel of A */
      for(j=nNb; j; j--) /* full column panels of B */
      {
         if (gescal) gescal(ib, NB, beta, C, ldc);
         IBNBmm(ib, K, pA, pB, rbeta, C, ldc);
         pB += incK;
         C += incCn;
      }
      if (jb)   /* bottom-right corner block */
      {
         if (gescal) gescal(ib, jb, beta, C, ldc);
         IBJBmm(ib, jb, K, pA, pB, rbeta, C, ldc);
      }
   }
}