blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos) { BLASLONG n, info; BLASLONG bk, i, blocking; int mode; BLASLONG lda, range_N[2]; blas_arg_t newarg; FLOAT *a; FLOAT alpha[2] = { ONE, ZERO}; FLOAT beta [2] = {-ONE, ZERO}; #ifndef COMPLEX #ifdef XDOUBLE mode = BLAS_XDOUBLE | BLAS_REAL; #elif defined(DOUBLE) mode = BLAS_DOUBLE | BLAS_REAL; #else mode = BLAS_SINGLE | BLAS_REAL; #endif #else #ifdef XDOUBLE mode = BLAS_XDOUBLE | BLAS_COMPLEX; #elif defined(DOUBLE) mode = BLAS_DOUBLE | BLAS_COMPLEX; #else mode = BLAS_SINGLE | BLAS_COMPLEX; #endif #endif n = args -> n; a = (FLOAT *)args -> a; lda = args -> lda; if (range_n) n = range_n[1] - range_n[0]; if (n <= DTB_ENTRIES) { info = TRTI2(args, NULL, range_n, sa, sb, 0); return info; } blocking = GEMM_Q; if (n < 4 * GEMM_Q) blocking = (n + 3) / 4; for (i = 0; i < n; i += blocking) { bk = n - i; if (bk > blocking) bk = blocking; range_N[0] = i; range_N[1] = i + bk; newarg.lda = lda; newarg.ldb = lda; newarg.ldc = lda; newarg.alpha = alpha; newarg.m = i; newarg.n = bk; newarg.a = a + (i + i * lda) * COMPSIZE; newarg.b = a + ( i * lda) * COMPSIZE; newarg.beta = beta; newarg.nthreads = args -> nthreads; gemm_thread_m(mode, &newarg, NULL, NULL, TRSM, sa, sb, args -> nthreads); newarg.m = bk; newarg.n = bk; newarg.a = a + (i + i * lda) * COMPSIZE; CNAME (&newarg, NULL, NULL, sa, sb, 0); newarg.m = i; newarg.n = n - i - bk; newarg.k = bk; newarg.a = a + ( i * lda) * COMPSIZE; newarg.b = a + (i + (i + bk) * lda) * COMPSIZE; newarg.c = a + ( (i + bk) * lda) * COMPSIZE; newarg.beta = NULL; gemm_thread_n(mode, &newarg, NULL, NULL, GEMM_NN, sa, sb, args -> nthreads); newarg.a = a + (i + i * lda) * COMPSIZE; newarg.b = a + (i + (i + bk) * lda) * COMPSIZE; newarg.m = bk; newarg.n = n - i - bk; gemm_thread_n(mode, &newarg, NULL, NULL, TRMM, sa, sb, args -> nthreads); } return 0; }
/*
 * Threaded GEMM dispatcher.
 *
 * Falls back to the single-threaded GEMM_LOCAL when only one thread is
 * available, or when the (sub-)problem is too small for every thread to get
 * at least SWITCH_RATIO rows/columns.  Otherwise it writes args->nthreads
 * (= divT) and hands the work to gemm_driver, or splits it across divN
 * groups with gemm_thread_n.
 *
 * args    : problem description; m, n and nthreads are read, and
 *           args->nthreads is OVERWRITTEN with divT on the threaded path.
 * range_m : optional [from, to) row range; when given, its length is the
 *           effective m.
 * range_n : optional [from, to) column range; when given, its length is
 *           the effective n.
 * sa, sb  : scratch buffers forwarded to the workers.
 * mypos   : unused.
 *
 * Always returns 0.
 */
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){

  BLASLONG m = args -> m;
  BLASLONG n = args -> n;
  BLASLONG nthreads = args -> nthreads;
  BLASLONG divN, divT;
  int mode;

  if (nthreads == 1) {
    GEMM_LOCAL(args, range_m, range_n, sa, sb, 0);
    return 0;
  }

  /* The amount of work is the requested sub-range, not the full matrix. */
  if (range_m) m = range_m[1] - range_m[0];
  if (range_n) n = range_n[1] - range_n[0];

  /* Not enough rows/columns per thread: threading would not pay off.
     Compare against the range-adjusted sizes computed above. */
  if ((m < nthreads * SWITCH_RATIO) || (n < nthreads * SWITCH_RATIO)) {
    GEMM_LOCAL(args, range_m, range_n, sa, sb, 0);
    return 0;
  }

  divT = nthreads;
  divN = 1;

  /* Disabled heuristic that would trade threads in the T direction for
     groups in the N direction.  While it is compiled out, divN stays 1 and
     the gemm_thread_n branch below is never taken. */
#if 0
  while ((GEMM_P * divT > m * SWITCH_RATIO) && (divT > 1)) {
    do {
      divT --;
      divN = 1;
      while (divT * divN < nthreads) divN ++;
    } while ((divT * divN != nthreads) && (divT > 1));
  }
#endif

  args -> nthreads = divT;

  if (divN == 1){
    gemm_driver(args, range_m, range_n, sa, sb, 0);
  } else {
#ifndef COMPLEX
#ifdef XDOUBLE
    mode = BLAS_XDOUBLE | BLAS_REAL;
#elif defined(DOUBLE)
    mode = BLAS_DOUBLE  | BLAS_REAL;
#else
    mode = BLAS_SINGLE  | BLAS_REAL;
#endif
#else
#ifdef XDOUBLE
    mode = BLAS_XDOUBLE | BLAS_COMPLEX;
#elif defined(DOUBLE)
    mode = BLAS_DOUBLE  | BLAS_COMPLEX;
#else
    mode = BLAS_SINGLE  | BLAS_COMPLEX;
#endif
#endif

    /* Transpose / conjugate variants selected at compile time. */
#if defined(TN) || defined(TT) || defined(TR) || defined(TC) || \
    defined(CN) || defined(CT) || defined(CR) || defined(CC)
    mode |= (BLAS_TRANSA_T);
#endif
#if defined(NT) || defined(TT) || defined(RT) || defined(CT) || \
    defined(NC) || defined(TC) || defined(RC) || defined(CC)
    mode |= (BLAS_TRANSB_T);
#endif

#ifdef OS_WINDOWS
    gemm_thread_n(mode, args, range_m, range_n, GEMM_LOCAL, sa, sb, divN);
#else
    gemm_thread_n(mode, args, range_m, range_n, gemm_driver, sa, sb, divN);
#endif
  }

  return 0;
}
blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG myid) { BLASLONG n, bk, i, blocking, lda; BLASLONG info; int mode; blas_arg_t newarg; FLOAT *a; FLOAT alpha[2] = { -ONE, ZERO}; #ifndef COMPLEX #ifdef XDOUBLE mode = BLAS_XDOUBLE | BLAS_REAL; #elif defined(DOUBLE) mode = BLAS_DOUBLE | BLAS_REAL; #else mode = BLAS_SINGLE | BLAS_REAL; #endif #else #ifdef XDOUBLE mode = BLAS_XDOUBLE | BLAS_COMPLEX; #elif defined(DOUBLE) mode = BLAS_DOUBLE | BLAS_COMPLEX; #else mode = BLAS_SINGLE | BLAS_COMPLEX; #endif #endif if (args -> nthreads == 1) { #ifndef LOWER info = POTRF_U_SINGLE(args, NULL, NULL, sa, sb, 0); #else info = POTRF_L_SINGLE(args, NULL, NULL, sa, sb, 0); #endif return info; } n = args -> n; a = (FLOAT *)args -> a; lda = args -> lda; if (range_n) n = range_n[1] - range_n[0]; if (n <= GEMM_UNROLL_N * 2) { #ifndef LOWER info = POTRF_U_SINGLE(args, NULL, range_n, sa, sb, 0); #else info = POTRF_L_SINGLE(args, NULL, range_n, sa, sb, 0); #endif return info; } newarg.lda = lda; newarg.ldb = lda; newarg.ldc = lda; newarg.alpha = alpha; newarg.beta = NULL; newarg.nthreads = args -> nthreads; blocking = (n / 2 + GEMM_UNROLL_N - 1) & ~(GEMM_UNROLL_N - 1); if (blocking > GEMM_Q) blocking = GEMM_Q; for (i = 0; i < n; i += blocking) { bk = n - i; if (bk > blocking) bk = blocking; newarg.m = bk; newarg.n = bk; newarg.a = a + (i + i * lda) * COMPSIZE; info = CNAME(&newarg, NULL, NULL, sa, sb, 0); if (info) return info + i; if (n - i - bk > 0) { #ifndef USE_SIMPLE_THREADED_LEVEL3 newarg.m = n - i - bk; newarg.k = bk; #ifndef LOWER newarg.b = a + ( i + (i + bk) * lda) * COMPSIZE; #else newarg.b = a + ((i + bk) + i * lda) * COMPSIZE; #endif newarg.c = a + ((i + bk) + (i + bk) * lda) * COMPSIZE; thread_driver(&newarg, sa, sb); #else #ifndef LOWER newarg.m = bk; newarg.n = n - i - bk; newarg.a = a + (i + i * lda) * COMPSIZE; newarg.b = a + (i + (i + bk) * lda) * COMPSIZE; gemm_thread_n(mode | BLAS_TRANSA_T, &newarg, NULL, NULL, (void 
*)TRSM_LCUN, sa, sb, args -> nthreads); newarg.n = n - i - bk; newarg.k = bk; newarg.a = a + ( i + (i + bk) * lda) * COMPSIZE; newarg.c = a + ((i + bk) + (i + bk) * lda) * COMPSIZE; #if 0 HERK_THREAD_UC(&newarg, NULL, NULL, sa, sb, 0); #else syrk_thread(mode | BLAS_TRANSA_N | BLAS_TRANSB_T, &newarg, NULL, NULL, (void *)HERK_UC, sa, sb, args -> nthreads); #endif #else newarg.m = n - i - bk; newarg.n = bk; newarg.a = a + (i + i * lda) * COMPSIZE; newarg.b = a + (i + bk + i * lda) * COMPSIZE; gemm_thread_m(mode | BLAS_RSIDE | BLAS_TRANSA_T | BLAS_UPLO, &newarg, NULL, NULL, (void *)TRSM_RCLN, sa, sb, args -> nthreads); newarg.n = n - i - bk; newarg.k = bk; newarg.a = a + (i + bk + i * lda) * COMPSIZE; newarg.c = a + (i + bk + (i + bk) * lda) * COMPSIZE; #if 0 HERK_THREAD_LN(&newarg, NULL, NULL, sa, sb, 0); #else syrk_thread(mode | BLAS_TRANSA_N | BLAS_TRANSB_T | BLAS_UPLO, &newarg, NULL, NULL, (void *)HERK_LN, sa, sb, args -> nthreads); #endif #endif #endif } } return 0; }