Example 1
0
/*
 * Blocked, multi-threaded in-place inversion of a triangular matrix.
 *
 * The matrix is addressed column-major: element (r, c) is at
 * a + (r + c * lda) * COMPSIZE.  Every access below touches rows
 * 0 .. i+bk-1 of column block i and the strips to its right, i.e. the
 * upper triangle -- presumably this is the upper-triangular variant
 * (TODO confirm against the build flags that instantiate CNAME).
 *
 * For n <= DTB_ENTRIES the unblocked TRTI2 handles the whole problem.
 * Otherwise, for each diagonal block D = A(i:i+bk, i:i+bk):
 *   1. gemm_thread_m runs TRSM with newarg.a = D, newarg.b = the column
 *      strip A(0:i, i:i+bk), alpha = {1, 0}, beta = {-1, 0};
 *   2. D itself is inverted by a recursive call to CNAME;
 *   3. gemm_thread_n runs GEMM_NN with a = A(0:i, i:i+bk),
 *      b = A(i:i+bk, i+bk:n), c = A(0:i, i+bk:n), alpha = {1, 0},
 *      beta = NULL;
 *   4. gemm_thread_n runs TRMM with a = D (now inverted) and
 *      b = the row strip A(i:i+bk, i+bk:n).
 *
 * args    : reads n, a, lda and nthreads; the matrix at args->a is
 *           overwritten.
 * range_n : optional column range; only its width (range_n[1]-range_n[0])
 *           is used here.  NOTE(review): the blocked path does not offset
 *           `a` by range_n[0], so presumably callers pass NULL (or an
 *           already-offset args->a) -- confirm.
 * range_m, mypos : unused.
 * sa, sb  : scratch buffers forwarded to every kernel.
 *
 * Returns TRTI2's result on the unblocked path, otherwise 0.
 * NOTE(review): the recursive call's return value is ignored; this
 * presumably relies on singularity being checked before this driver
 * is called -- confirm.
 */
blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos) {

  BLASLONG n, info;
  BLASLONG bk, i, blocking;
  int mode;
  BLASLONG lda;
  blas_arg_t newarg;
  FLOAT *a;
  /* {re, im} pairs so one code path serves real and complex builds
     (presumably real kernels read only element 0). */
  FLOAT alpha[2] = { ONE, ZERO};
  FLOAT beta [2] = {-ONE, ZERO};

  /* Precision / domain flags handed to the threading helpers. */
#ifndef COMPLEX
#ifdef XDOUBLE
  mode  =  BLAS_XDOUBLE | BLAS_REAL;
#elif defined(DOUBLE)
  mode  =  BLAS_DOUBLE  | BLAS_REAL;
#else
  mode  =  BLAS_SINGLE  | BLAS_REAL;
#endif
#else
#ifdef XDOUBLE
  mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
#elif defined(DOUBLE)
  mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
#else
  mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
#endif
#endif

  n  = args -> n;
  a  = (FLOAT *)args -> a;
  lda = args -> lda;

  if (range_n) n  = range_n[1] - range_n[0];

  /* Small enough for the unblocked routine. */
  if (n <= DTB_ENTRIES) {
    info = TRTI2(args, NULL, range_n, sa, sb, 0);
    return info;
  }

  /* At most GEMM_Q columns per block; for smaller n use about n/4
     (rounded up) so the loop still produces several blocks. */
  blocking = GEMM_Q;
  if (n < 4 * GEMM_Q) blocking = (n + 3) / 4;

  for (i = 0; i < n; i += blocking) {
    bk = n - i;
    if (bk > blocking) bk = blocking;

    newarg.lda = lda;
    newarg.ldb = lda;
    newarg.ldc = lda;
    newarg.alpha = alpha;

    /* Step 1: column strip above the diagonal block. */
    newarg.m = i;
    newarg.n = bk;
    newarg.a = a + (i + i * lda) * COMPSIZE;
    newarg.b = a + (    i * lda) * COMPSIZE;

    newarg.beta  = beta;
    newarg.nthreads = args -> nthreads;

    gemm_thread_m(mode, &newarg, NULL, NULL, TRSM, sa, sb, args -> nthreads);

    /* Step 2: invert the diagonal block itself. */
    newarg.m = bk;
    newarg.n = bk;

    newarg.a = a + (i + i * lda) * COMPSIZE;

    CNAME  (&newarg, NULL, NULL, sa, sb, 0);

    /* Step 3: update the block A(0:i, i+bk:n) to the right. */
    newarg.m = i;
    newarg.n = n - i - bk;
    newarg.k = bk;

    newarg.a = a + (     i       * lda) * COMPSIZE;
    newarg.b = a + (i + (i + bk) * lda) * COMPSIZE;
    newarg.c = a + (    (i + bk) * lda) * COMPSIZE;

    newarg.beta  = NULL;

    gemm_thread_n(mode, &newarg, NULL, NULL, GEMM_NN, sa, sb, args -> nthreads);

    /* Step 4: apply the inverted diagonal block to its row strip. */
    newarg.a = a + (i +  i       * lda) * COMPSIZE;
    newarg.b = a + (i + (i + bk) * lda) * COMPSIZE;

    newarg.m = bk;
    newarg.n = n - i - bk;

    gemm_thread_n(mode, &newarg, NULL, NULL, TRMM, sa, sb, args -> nthreads);

  }

  return 0;
}
Example 2
0
/*
 * Threading front end for GEMM.
 *
 * Dispatch rules, where m and n are the extents of range_m / range_n when
 * those are given and args->m / args->n otherwise:
 *   - args->nthreads == 1                    -> GEMM_LOCAL
 *   - m or n below nthreads * SWITCH_RATIO   -> GEMM_LOCAL (too little work
 *                                               per thread to split)
 *   - otherwise                              -> gemm_driver, after setting
 *                                               args->nthreads = divT
 *
 * NOTE: args->nthreads is overwritten on the threaded path.
 * The divN/divT repartitioning loop is compiled out (#if 0), so divN stays
 * 1 and the gemm_thread_n branch is currently unreachable.
 *
 * Always returns 0.
 */
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){

  BLASLONG m = args -> m;
  BLASLONG n = args -> n;
  BLASLONG nthreads = args -> nthreads;
  BLASLONG divN, divT;
  int mode;

  if (nthreads  == 1) {
    GEMM_LOCAL(args, range_m, range_n, sa, sb, 0);
    return 0;
  }

  /* Size of the requested sub-problem: ranges are {from, to}. */
  if (range_m) m = range_m[1] - range_m[0];
  if (range_n) n = range_n[1] - range_n[0];

  /* Serial fallback for small problems.  Use the range-adjusted m / n:
     the full args->m / args->n overstate the work when only a sub-block
     was requested, which would needlessly spin up threads. */
  if ((m < nthreads * SWITCH_RATIO) || (n < nthreads * SWITCH_RATIO)) {
    GEMM_LOCAL(args, range_m, range_n, sa, sb, 0);
    return 0;
  }

  divT = nthreads;
  divN = 1;

  /* Disabled heuristic that would split threads between an outer
     N-partition (divN) and the inner driver (divT). */
#if 0
  while ((GEMM_P * divT > m * SWITCH_RATIO) && (divT > 1)) {
    do {
      divT --;
      divN = 1;
      while (divT * divN < nthreads) divN ++;
    } while ((divT * divN != nthreads) && (divT > 1));
  }
#endif

  args -> nthreads = divT;

  if (divN == 1){

    gemm_driver(args, range_m, range_n, sa, sb, 0);
  } else {
#ifndef COMPLEX
#ifdef XDOUBLE
    mode  =  BLAS_XDOUBLE | BLAS_REAL;
#elif defined(DOUBLE)
    mode  =  BLAS_DOUBLE  | BLAS_REAL;
#else
    mode  =  BLAS_SINGLE  | BLAS_REAL;
#endif
#else
#ifdef XDOUBLE
    mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
#elif defined(DOUBLE)
    mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
#else
    mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
#endif
#endif

    /* Transpose flags follow the two-letter operand-variant macro. */
#if defined(TN) || defined(TT) || defined(TR) || defined(TC) || \
    defined(CN) || defined(CT) || defined(CR) || defined(CC)
    mode |= (BLAS_TRANSA_T);
#endif
#if defined(NT) || defined(TT) || defined(RT) || defined(CT) || \
    defined(NC) || defined(TC) || defined(RC) || defined(CC)
    mode |= (BLAS_TRANSB_T);
#endif

#ifdef OS_WINDOWS
    gemm_thread_n(mode, args, range_m, range_n, GEMM_LOCAL,  sa, sb, divN);
#else
    gemm_thread_n(mode, args, range_m, range_n, gemm_driver, sa, sb, divN);
#endif

  }

  return 0;
}
Example 3
0
/*
 * Blocked, multi-threaded in-place factorization driver, presumably
 * Cholesky (POTRF) given the POTRF_*_SINGLE / HERK kernels it calls.
 * LOWER selects the lower-triangular variant; otherwise upper.
 *
 * The matrix is addressed column-major: element (r, c) is at
 * a + (r + c * lda) * COMPSIZE.
 *
 * With one thread, or when n <= 2 * GEMM_UNROLL_N, the single-threaded
 * POTRF_U_SINGLE / POTRF_L_SINGLE does all the work.  Otherwise each
 * diagonal block of `blocking` columns is factored by a recursive call,
 * and the trailing part is updated by thread_driver (default) or by the
 * TRSM + syrk_thread pair when USE_SIMPLE_THREADED_LEVEL3 is defined.
 *
 * args    : reads n, a, lda and nthreads; the matrix is overwritten.
 * range_n : optional column range; only its width is used.
 *           NOTE(review): `a` is not offset by range_n[0] on the blocked
 *           path -- presumably callers pass NULL; confirm.
 * range_m, myid : unused.
 * sa, sb  : scratch buffers forwarded to the kernels.
 *
 * Returns 0 on success.  A non-zero result from a diagonal block is
 * offset by the block's start column i, so it is relative to the whole
 * matrix (presumably a 1-based failing-column index, LAPACK-style).
 */
blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG myid) {

  BLASLONG n, bk, i, blocking, lda;
  BLASLONG info;
  int mode;
  blas_arg_t newarg;
  FLOAT *a;
  FLOAT alpha[2] = { -ONE, ZERO};

  /* Precision / domain flags; only read on the USE_SIMPLE_THREADED_LEVEL3
     path below. */
#ifndef COMPLEX
#ifdef XDOUBLE
  mode  =  BLAS_XDOUBLE | BLAS_REAL;
#elif defined(DOUBLE)
  mode  =  BLAS_DOUBLE  | BLAS_REAL;
#else
  mode  =  BLAS_SINGLE  | BLAS_REAL;
#endif
#else
#ifdef XDOUBLE
  mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
#elif defined(DOUBLE)
  mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
#else
  mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
#endif
#endif

  if (args -> nthreads  == 1) {
#ifndef LOWER
    info = POTRF_U_SINGLE(args, NULL, NULL, sa, sb, 0);
#else
    info = POTRF_L_SINGLE(args, NULL, NULL, sa, sb, 0);
#endif
    return info;
  }

  n  = args -> n;
  a  = (FLOAT *)args -> a;
  lda = args -> lda;

  if (range_n) n  = range_n[1] - range_n[0];

  /* Too small to be worth blocking. */
  if (n <= GEMM_UNROLL_N * 2) {
#ifndef LOWER
    info = POTRF_U_SINGLE(args, NULL, range_n, sa, sb, 0);
#else
    info = POTRF_L_SINGLE(args, NULL, range_n, sa, sb, 0);
#endif
    return info;
  }

  newarg.lda = lda;
  newarg.ldb = lda;
  newarg.ldc = lda;
  newarg.alpha = alpha;
  newarg.beta = NULL;
  newarg.nthreads = args -> nthreads;

  /* Half of n rounded UP to a multiple of GEMM_UNROLL_N, capped at GEMM_Q.
     Rounding by division rather than the `& ~(GEMM_UNROLL_N - 1)` bitmask,
     which only rounds correctly when GEMM_UNROLL_N is a power of two. */
  blocking = ((n / 2 + GEMM_UNROLL_N - 1) / GEMM_UNROLL_N) * GEMM_UNROLL_N;
  if (blocking > GEMM_Q) blocking = GEMM_Q;

  for (i = 0; i < n; i += blocking) {
    bk = n - i;
    if (bk > blocking) bk = blocking;

    /* Factor the diagonal block recursively. */
    newarg.m = bk;
    newarg.n = bk;
    newarg.a = a + (i + i * lda) * COMPSIZE;

    info = CNAME(&newarg, NULL, NULL, sa, sb, 0);
    if (info) return info + i;

    /* Update the trailing part of the matrix, if any remains. */
    if (n - i - bk > 0) {
#ifndef USE_SIMPLE_THREADED_LEVEL3
      newarg.m = n - i - bk;
      newarg.k = bk;
#ifndef LOWER
      newarg.b = a + ( i       + (i + bk) * lda) * COMPSIZE;
#else
      newarg.b = a + ((i + bk) +  i       * lda) * COMPSIZE;
#endif
      newarg.c = a + ((i + bk) + (i + bk) * lda) * COMPSIZE;

      thread_driver(&newarg, sa, sb);
#else

#ifndef LOWER
      /* Triangular solve on the row strip right of the diagonal block. */
      newarg.m = bk;
      newarg.n = n - i - bk;
      newarg.a = a + (i +  i       * lda) * COMPSIZE;
      newarg.b = a + (i + (i + bk) * lda) * COMPSIZE;

      gemm_thread_n(mode | BLAS_TRANSA_T,
                    &newarg, NULL, NULL, (void *)TRSM_LCUN, sa, sb, args -> nthreads);

      /* Rank-bk update of the trailing diagonal block. */
      newarg.n = n - i - bk;
      newarg.k = bk;
      newarg.a = a + ( i       + (i + bk) * lda) * COMPSIZE;
      newarg.c = a + ((i + bk) + (i + bk) * lda) * COMPSIZE;

      syrk_thread(mode | BLAS_TRANSA_N | BLAS_TRANSB_T,
                  &newarg, NULL, NULL, (void *)HERK_UC, sa, sb, args -> nthreads);
#else
      /* Triangular solve on the column strip below the diagonal block. */
      newarg.m = n - i - bk;
      newarg.n = bk;
      newarg.a = a + (i      + i * lda) * COMPSIZE;
      newarg.b = a + (i + bk + i * lda) * COMPSIZE;

      gemm_thread_m(mode | BLAS_RSIDE | BLAS_TRANSA_T | BLAS_UPLO,
                    &newarg, NULL, NULL, (void *)TRSM_RCLN, sa, sb, args -> nthreads);

      /* Rank-bk update of the trailing diagonal block. */
      newarg.n = n - i - bk;
      newarg.k = bk;
      newarg.a = a + (i + bk +  i       * lda) * COMPSIZE;
      newarg.c = a + (i + bk + (i + bk) * lda) * COMPSIZE;

      syrk_thread(mode | BLAS_TRANSA_N | BLAS_TRANSB_T | BLAS_UPLO,
                  &newarg, NULL, NULL, (void *)HERK_LN, sa, sb, args -> nthreads);
#endif

#endif
    }
  }
  return 0;
}