/**
 * \brief NCCL implementation of \ref gpucomm_broadcast.
 *
 * Issues ncclBcast() for `count` elements of `typecode`, using the buffer
 * `array` at `offset` (added directly to the device pointer) on every rank.
 * The rank equal to `root` acts as the sender, all other ranks receive into
 * their buffer. The call is queued on the stream of the communicator's
 * context.
 *
 * \returns GA_NO_ERROR once the operation has been queued. The checking
 *          macros are expected to return early with an error code when
 *          validation, synchronisation or the NCCL call fails.
 */
static int broadcast(gpudata *array, size_t offset, size_t count, int typecode,
                     int root, gpucomm *comm) {
  // placeholder values, only there to silence "maybe uninitialised" warnings
  ncclDataType_t dtype = ncclNumTypes;
  int my_rank = 0;
  int access;
  cuda_context *ctx;

  ASSERT_BUF(array);
  ASSERT_COMM(comm);
  GA_CHECK(check_restrictions(array, offset, NULL, 0, count, typecode, 0, comm,
                              &dtype, NULL));
  GA_CHECK(get_rank(comm, &my_rank));

  // the root only reads its buffer, every other rank has its buffer written
  access = (my_rank == root) ? CUDA_WAIT_READ : CUDA_WAIT_WRITE;

  ctx = comm->ctx;
  cuda_enter(ctx);

  // order the collective after pending work on the buffer
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(array, access));

  // run the nccl op on the context's stream
  NCCL_EXIT_ON_ERROR(ctx, ncclBcast((void *)(array->ptr + offset), count,
                                    dtype, root, comm->c, ctx->s));

  // let later users of the buffer order themselves after the collective
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(array, access));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
/**
 * \brief NCCL implementation of \ref gpucomm_all_reduce.
 *
 * Queues ncclAllReduce() for `count` elements of `typecode`: the data read
 * from `src` at `offsrc` is combined with reduction `opcode` and the result
 * is stored in `dest` at `offdest`. Offsets are added directly to the device
 * pointers. The call runs on the stream of the communicator's context.
 *
 * \returns GA_NO_ERROR once the operation has been queued. The checking
 *          macros are expected to return early with an error code when
 *          validation, synchronisation or the NCCL call fails.
 */
static int all_reduce(gpudata *src, size_t offsrc, gpudata *dest,
                      size_t offdest, size_t count, int typecode, int opcode,
                      gpucomm *comm) {
  // placeholder values, only there to silence "maybe uninitialised" warnings
  ncclRedOp_t red_op = ncclNumOps;
  ncclDataType_t dtype = ncclNumTypes;
  cuda_context *ctx;
  void *sendbuf;
  void *recvbuf;

  ASSERT_BUF(src);
  ASSERT_COMM(comm);
  ASSERT_BUF(dest);
  GA_CHECK(check_restrictions(src, offsrc, dest, offdest, count, typecode,
                              opcode, comm, &dtype, &red_op));

  sendbuf = (void *)(src->ptr + offsrc);
  recvbuf = (void *)(dest->ptr + offdest);

  ctx = comm->ctx;
  cuda_enter(ctx);

  // the source is only read ...
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(src, CUDA_WAIT_READ));
  // ... while the destination is overwritten
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(dest, CUDA_WAIT_WRITE));

  // run the nccl op on the context's stream
  NCCL_EXIT_ON_ERROR(ctx, ncclAllReduce(sendbuf, recvbuf, count, dtype,
                                        red_op, comm->c, ctx->s));

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(src, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(dest, CUDA_WAIT_WRITE));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
/**
 * \brief NCCL implementation of \ref gpucomm_reduce_scatter.
 *
 * Queues ncclReduceScatter() on the stream of the communicator's context.
 * `count` is the number of elements each rank receives in `dest`; the source
 * buffer is validated for `count * ndev` elements, where `ndev` is the value
 * reported by get_count() for the communicator. Offsets are added directly
 * to the device pointers.
 *
 * \returns GA_NO_ERROR once the operation has been queued. A GA_VALUE_ERROR
 *          is set and returned when `dest` lives in a different context than
 *          `comm` or when `dest` cannot hold `count` elements starting at
 *          `offdest`. The checking macros are expected to return early with
 *          an error code for any other failure.
 */
static int reduce_scatter(gpudata *src, size_t offsrc, gpudata *dest,
                          size_t offdest, size_t count, int typecode,
                          int opcode, gpucomm *comm) {
  // need dummy init so that compiler shuts up
  ncclRedOp_t op = ncclNumOps;
  ncclDataType_t datatype = ncclNumTypes;
  int ndev = 0;
  size_t resc_size;
  cuda_context *ctx;

  ASSERT_BUF(src);
  ASSERT_COMM(comm);
  ASSERT_BUF(dest);
  GA_CHECK(get_count(comm, &ndev));
  GA_CHECK(check_restrictions(src, offsrc, NULL, 0, count * ndev, typecode,
                              opcode, comm, &datatype, &op));
  if (dest->ctx != comm->ctx)
    return error_set(comm->ctx->err, GA_VALUE_ERROR, "destination and comm context differ");
  resc_size = count * gpuarray_get_elsize(typecode);
  // Test offdest on its own first: dest->sz - offdest is unsigned and would
  // wrap to a huge value when offdest > dest->sz, letting the size check pass.
  if (offdest > dest->sz || (dest->sz - offdest) < resc_size)
    return error_set(comm->ctx->err, GA_VALUE_ERROR, "destination too small for operation");

  ctx = comm->ctx;
  cuda_enter(ctx);

  // sync: the source is only read by the collective
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(src, CUDA_WAIT_READ));
  // sync: the destination is overwritten by the collective
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(dest, CUDA_WAIT_WRITE));

  // run the nccl op on the context's stream
  NCCL_EXIT_ON_ERROR(ctx, ncclReduceScatter((void *)(src->ptr + offsrc),
                                            (void *)(dest->ptr + offdest), count,
                                            datatype, op, comm->c, ctx->s));

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(src, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(dest, CUDA_WAIT_WRITE));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
/*
 * Double precision rank-1 update through cublasDger().
 *
 * X, Y and A are device buffers; offX/offY/offA are element offsets (they
 * are added after the pointers are cast to double *). For cb_c ordering the
 * roles of X and Y (and of M and N) are exchanged before the call.
 *
 * Returns GA_NO_ERROR on success, GA_XLARGE_ERROR when LARGE_VAL() rejects a
 * dimension, GA_DEVSUP_ERROR when cuBLAS reports an architecture mismatch
 * and GA_BLAS_ERROR for any other cuBLAS failure.
 */
static int dger(cb_order order, size_t M, size_t N, double alpha, gpudata *X,
                size_t offX, int incX, gpudata *Y, size_t offY, int incY,
                gpudata *A, size_t offA, size_t lda) {
  cuda_context *ctx = X->ctx;
  blas_handle *h = (blas_handle *)ctx->blas_handle;

  ASSERT_BUF(X);
  ASSERT_BUF(Y);
  ASSERT_BUF(A);

  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
      LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
    return GA_XLARGE_ERROR;

  if (order == cb_c) {
    /* exchange the two vectors together with everything describing them */
    const size_t dim_keep = M;
    const size_t off_keep = offX;
    const size_t inc_keep = incX;
    gpudata *vec_keep = X;

    M = N;
    N = dim_keep;
    offX = offY;
    offY = off_keep;
    incX = incY;
    incY = inc_keep;
    X = Y;
    Y = vec_keep;
  }

  cuda_enter(ctx);

  /* the vectors are inputs, A is both read and written */
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(X, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(Y, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_ALL));

  h->err = cublasDger(h->h, M, N, &alpha,
                      ((double *)X->ptr) + offX, incX,
                      ((double *)Y->ptr) + offY, incY,
                      ((double *)A->ptr) + offA, lda);
  if (h->err != CUBLAS_STATUS_SUCCESS) {
    cuda_exit(ctx);
    return (h->err == CUBLAS_STATUS_ARCH_MISMATCH) ? GA_DEVSUP_ERROR
                                                   : GA_BLAS_ERROR;
  }

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(X, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(Y, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_ALL));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
/*
 * Double precision matrix-vector product through cublasDgemv().
 *
 * A, X and Y are device buffers; offA/offX/offY are element offsets (they
 * are added after the pointers are cast to double *). For cb_c ordering M
 * and N are exchanged and the transpose flag is flipped before the call.
 *
 * Returns GA_NO_ERROR on success, GA_XLARGE_ERROR when LARGE_VAL() rejects a
 * dimension, GA_DEVSUP_ERROR when cuBLAS reports an architecture mismatch
 * and GA_BLAS_ERROR for any other cuBLAS failure.
 */
static int dgemv(cb_order order, cb_transpose transA, size_t M, size_t N,
                 double alpha, gpudata *A, size_t offA, size_t lda,
                 gpudata *X, size_t offX, int incX,
                 double beta, gpudata *Y, size_t offY, int incY) {
  cuda_context *ctx = A->ctx;
  blas_handle *h = (blas_handle *)ctx->blas_handle;

  ASSERT_BUF(A);
  ASSERT_BUF(X);
  ASSERT_BUF(Y);

  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(M * N) ||
      LARGE_VAL(lda) || LARGE_VAL(incX) || LARGE_VAL(incY))
    return GA_XLARGE_ERROR;

  if (order == cb_c) {
    const size_t cols = N;

    N = M;
    M = cols;
    transA = (transA == cb_no_trans) ? cb_trans : cb_no_trans;
  }

  cuda_enter(ctx);

  /* A and X are inputs, Y is both read and written */
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(X, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(Y, CUDA_WAIT_ALL));

  h->err = cublasDgemv(h->h, convT(transA), M, N,
                       &alpha,
                       ((double *)A->ptr) + offA, lda,
                       ((double *)X->ptr) + offX, incX,
                       &beta,
                       ((double *)Y->ptr) + offY, incY);
  if (h->err != CUBLAS_STATUS_SUCCESS) {
    cuda_exit(ctx);
    return (h->err == CUBLAS_STATUS_ARCH_MISMATCH) ? GA_DEVSUP_ERROR
                                                   : GA_BLAS_ERROR;
  }

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(X, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(Y, CUDA_WAIT_ALL));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
/*
 * Batched single precision matrix-vector product using the library's own
 * sgemvBH_{N,T}_a1_b1_small kernels.
 *
 * A, x and y are arrays of batchCount device buffers with per-entry offsets
 * offA/offX/offY (added directly to the device pointers). Only alpha == 1
 * and beta == 1 are supported. The context is taken from A[0].
 *
 * Returns GA_NO_ERROR on success (including batchCount == 0),
 * GA_INVALID_ERROR when flags != 0, GA_UNSUPPORTED_ERROR for other
 * alpha/beta values, or the error reported by buffer allocation, the kernel
 * launch or the synchronisation helpers.
 */
static int sgemvBatch(cb_order order, cb_transpose transA,
                      size_t M, size_t N, float alpha,
                      gpudata **A, size_t *offA, size_t lda,
                      gpudata **x, size_t *offX, size_t incX,
                      float beta, gpudata **y, size_t *offY, size_t incY,
                      size_t batchCount, int flags) {
  /* Flags is there for possible future implementations where we might
     not use atomics or have some alternate implementation. */
  cuda_context *ctx;
  size_t t, i;
  size_t ls[2], gs[2];
  void *args[9];
  gpudata *Aa, *xa, *ya;
  int err;

  if (flags != 0) return GA_INVALID_ERROR;
  if (batchCount == 0) return GA_NO_ERROR;

  if (alpha != 1.0 || beta != 1.0) return GA_UNSUPPORTED_ERROR;

  /* Launch geometry: dimension 0 covers M, dimension 1 covers the batch. */
  if (M < 512) {
    ls[0] = 32;
    if (batchCount > 16)
      ls[1] = 16;
    else
      ls[1] = batchCount;
  } else {
    ls[0] = 512;
    ls[1] = 1;
  }
  gs[0] = (M + ls[0] - 1) / ls[0];
  gs[1] = (batchCount + ls[1] - 1) / ls[1];
  /* Shrink the batch dimension once the total grid size reaches 65535. */
  if (gs[0] * gs[1] / 65535) {
    gs[1] = (65535 / gs[0]);
  }

  if (order == cb_c) {
    t = N;
    N = M;
    M = t;
    if (transA == cb_no_trans) {
      transA = cb_trans;
    } else {
      transA = cb_no_trans;
    }
  }

  ASSERT_BUF(A[0]);

  ctx = A[0]->ctx;

  cuda_enter(ctx);

  {
    /* One host scratch area holding the three pointer tables back to back. */
    float **T_l = alloca(sizeof(float *) * batchCount * 3);
    const float **A_l = (const float **)T_l;
    const float **x_l = (const float **)T_l + batchCount;
    float **y_l = T_l + (batchCount * 2);

    for (i = 0; i < batchCount; i++) {
      ASSERT_BUF(A[i]);
      ASSERT_BUF(x[i]);
      ASSERT_BUF(y[i]);
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(x[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(y[i], CUDA_WAIT_ALL));
      A_l[i] = (float *)(A[i]->ptr + offA[i]);
      x_l[i] = (float *)(x[i]->ptr + offX[i]);
      y_l[i] = (float *)(y[i]->ptr + offY[i]);
    }

    /* Upload the pointer tables. Every failure path below must leave the
       context again, since cuda_enter() has already been called. */
    Aa = cuda_ops.buffer_alloc((gpucontext *)ctx, sizeof(float *) * batchCount, A_l,
                               GA_BUFFER_INIT, &err);
    if (Aa == NULL) {
      cuda_exit(ctx);
      return err;
    }
    xa = cuda_ops.buffer_alloc((gpucontext *)ctx, sizeof(float *) * batchCount, x_l,
                               GA_BUFFER_INIT, &err);
    if (xa == NULL) {
      cuda_ops.buffer_release(Aa);
      cuda_exit(ctx);
      return err;
    }
    ya = cuda_ops.buffer_alloc((gpucontext *)ctx, sizeof(float *) * batchCount, y_l,
                               GA_BUFFER_INIT, &err);
    if (ya == NULL) {
      cuda_ops.buffer_release(Aa);
      cuda_ops.buffer_release(xa);
      cuda_exit(ctx);
      return err;
    }
  }

  args[0] = Aa;
  args[1] = &lda;
  args[2] = xa;
  args[3] = &incX;
  args[4] = ya;
  args[5] = &incY;
  args[6] = &batchCount;
  args[7] = &M;
  args[8] = &N;

  if (transA == cb_no_trans) {
    err = GpuKernel_call(&((blas_handle *)ctx->blas_handle)->sgemvBH_N_a1_b1_small, 2, ls, gs, 0, args);
  } else {
    err = GpuKernel_call(&((blas_handle *)ctx->blas_handle)->sgemvBH_T_a1_b1_small, 2, ls, gs, 0, args);
  }

  cuda_ops.buffer_release(Aa);
  cuda_ops.buffer_release(xa);
  cuda_ops.buffer_release(ya);

  if (err != GA_NO_ERROR) {
    cuda_exit(ctx);
    return err;
  }

  /* Same access modes as the waits above: A and x were read, y was updated. */
  for (i = 0; i < batchCount; i++) {
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A[i], CUDA_WAIT_READ));
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(x[i], CUDA_WAIT_READ));
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(y[i], CUDA_WAIT_ALL));
  }

  cuda_exit(ctx);
  return GA_NO_ERROR;
}
/*
 * Batched double precision matrix-matrix product.
 *
 * A, B and C are arrays of batchCount device buffers; offA/offB/offC hold
 * per-entry element offsets (added after the pointers are cast to
 * double *). For cb_c ordering the A and B operands (with their offsets,
 * leading dimensions and transpose flags) and M/N are exchanged first.
 * Products with M * N * K above 650^3 are issued as one cublasDgemm() per
 * batch entry; smaller ones go through a single cublasDgemmBatched() call
 * fed from a temporary device-side pointer table.
 *
 * Returns GA_NO_ERROR on success (including batchCount == 0),
 * GA_XLARGE_ERROR when LARGE_VAL() rejects a dimension, GA_DEVSUP_ERROR when
 * cuBLAS reports an architecture mismatch, and GA_BLAS_ERROR for any other
 * cuBLAS failure or when the temporary pointer table cannot be allocated or
 * uploaded.
 */
static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
                      size_t M, size_t N, size_t K, double alpha,
                      gpudata **A, size_t *offA, size_t lda,
                      gpudata **B, size_t *offB, size_t ldb,
                      double beta, gpudata **C, size_t *offC, size_t ldc,
                      size_t batchCount) {
  cuda_context *ctx;
  blas_handle *h;
  size_t *lt, t;
  gpudata **T;
  size_t i;
  const size_t threshold = 650;
  cb_transpose transT;

  if (batchCount == 0) return GA_NO_ERROR;

  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
      LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
      LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
    return GA_XLARGE_ERROR;

  ASSERT_BUF(A[0]);
  ctx = A[0]->ctx;
  h = (blas_handle *)ctx->blas_handle;
  cuda_enter(ctx);

  if (order == cb_c) {
    /* swap A and B */
    t = N;
    N = M;
    M = t;
    T = A;
    A = B;
    B = T;
    t = lda;
    lda = ldb;
    ldb = t;
    transT = transA;
    transA = transB;
    transB = transT;
    lt = offA;
    offA = offB;
    offB = lt;
  }

  /* use parallel cublasDgemm calls rather than cublasDgemmBatched for
   * large products */
  if (M * N * K > threshold * threshold * threshold) {
    for (i = 0; i < batchCount; i++) {
      ASSERT_BUF(A[i]);
      ASSERT_BUF(B[i]);
      ASSERT_BUF(C[i]);
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C[i], CUDA_WAIT_ALL));

      h->err = cublasDgemm(h->h,
                           convT(transA), convT(transB),
                           M, N, K, &alpha,
                           (double*)A[i]->ptr + offA[i], lda,
                           (double*)B[i]->ptr + offB[i], ldb,
                           &beta,
                           (double*)C[i]->ptr + offC[i], ldc);
      if (h->err != CUBLAS_STATUS_SUCCESS) {
        cuda_exit(ctx);
        if (h->err == CUBLAS_STATUS_ARCH_MISMATCH)
          return GA_DEVSUP_ERROR;
        return GA_BLAS_ERROR;
      }

      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C[i], CUDA_WAIT_ALL));
    }
  } else {
    /* Size of the three pointer tables (A, B, C) stored back to back. */
    const size_t tbl_bytes = sizeof(double *) * batchCount * 3;
    double **T_l = alloca(tbl_bytes);
    const double **A_l = (const double **)T_l;
    const double **B_l = (const double **)T_l + batchCount;
    double **C_l = T_l + (batchCount * 2);
    CUdeviceptr Ta, Aa, Ba, Ca;

    for (i = 0; i < batchCount; i++) {
      ASSERT_BUF(A[i]);
      ASSERT_BUF(B[i]);
      ASSERT_BUF(C[i]);
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C[i], CUDA_WAIT_ALL));
      A_l[i] = ((double *)A[i]->ptr) + offA[i];
      B_l[i] = ((double *)B[i]->ptr) + offB[i];
      C_l[i] = ((double *)C[i]->ptr) + offC[i];
    }

    /* Driver API calls return CUDA_SUCCESS, which is 0, on success. A
     * failed allocation would leave Ta uninitialised, so bail out instead
     * of handing a garbage pointer table to cuBLAS. */
    if (cuMemAlloc(&Ta, tbl_bytes) != 0) {
      cuda_exit(ctx);
      return GA_BLAS_ERROR;
    }
    Aa = Ta;
    Ba = Ta + (batchCount * sizeof(double *));
    Ca = Ta + (batchCount * sizeof(double *) * 2);

    if (cuMemcpyHtoD(Ta, T_l, tbl_bytes) != 0) {
      cuMemFree(Ta);
      cuda_exit(ctx);
      return GA_BLAS_ERROR;
    }

    h->err = cublasDgemmBatched(h->h,
                                convT(transA), convT(transB),
                                M, N, K, &alpha,
                                (const double **)Aa, lda,
                                (const double **)Ba, ldb, &beta,
                                (double **)Ca, ldc, batchCount);
    cuMemFree(Ta);
    if (h->err != CUBLAS_STATUS_SUCCESS) {
      cuda_exit(ctx);
      if (h->err == CUBLAS_STATUS_ARCH_MISMATCH)
        return GA_DEVSUP_ERROR;
      return GA_BLAS_ERROR;
    }

    for (i = 0; i < batchCount; i++) {
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C[i], CUDA_WAIT_ALL));
    }
  }

  cuda_exit(ctx);
  return GA_NO_ERROR;
}
/*
 * Half precision matrix-matrix product through cublasSgemmEx().
 *
 * Only available when HAVE_CUBLAS_SGEMMEX is defined; otherwise the function
 * returns GA_DEVSUP_ERROR right away. A, B and C hold 16-bit elements and
 * offA/offB/offC are element offsets (added after the pointers are cast to
 * uint16_t *). alpha and beta are passed as float. For cb_c ordering the A
 * and B operands and M/N are exchanged before the call.
 *
 * Returns GA_NO_ERROR on success, GA_XLARGE_ERROR when LARGE_VAL() rejects a
 * dimension, GA_DEVSUP_ERROR when cuBLAS reports an architecture mismatch
 * and GA_BLAS_ERROR for any other cuBLAS failure.
 */
static int hgemm(cb_order order, cb_transpose transA, cb_transpose transB,
                 size_t M, size_t N, size_t K, float alpha,
                 gpudata *A, size_t offA, size_t lda,
                 gpudata *B, size_t offB, size_t ldb,
                 float beta, gpudata *C, size_t offC, size_t ldc) {
#ifdef HAVE_CUBLAS_SGEMMEX
  /* The computation itself runs in float32, the best option available for
   * now; this is meant to move to native float16 once that is supported. */
  cuda_context *ctx = A->ctx;
  blas_handle *h = (blas_handle *)ctx->blas_handle;

  ASSERT_BUF(A);
  ASSERT_BUF(B);
  ASSERT_BUF(C);

  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
      LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
      LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
    return GA_XLARGE_ERROR;

  if (order == cb_c) {
    /* exchange the A and B operands along with everything describing them */
    const size_t cols = N;
    const size_t ld_keep = lda;
    const size_t off_keep = offA;
    const cb_transpose trans_keep = transA;
    gpudata *mat_keep = A;

    N = M;
    M = cols;
    A = B;
    B = mat_keep;
    lda = ldb;
    ldb = ld_keep;
    transA = transB;
    transB = trans_keep;
    offA = offB;
    offB = off_keep;
  }

  cuda_enter(ctx);

  /* A and B are inputs, C is both read and written */
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));

  /* The name of the half-precision data type tag depends on the toolkit. */
#if CUDA_VERSION >= 8000
  h->err = cublasSgemmEx(h->h,
                         convT(transA), convT(transB), M, N, K,
                         &alpha,
                         ((uint16_t *)A->ptr) + offA, CUDA_R_16F, lda,
                         ((uint16_t *)B->ptr) + offB, CUDA_R_16F, ldb,
                         &beta,
                         ((uint16_t *)C->ptr) + offC, CUDA_R_16F, ldc);
#else
  h->err = cublasSgemmEx(h->h,
                         convT(transA), convT(transB), M, N, K,
                         &alpha,
                         ((uint16_t *)A->ptr) + offA, CUBLAS_DATA_HALF, lda,
                         ((uint16_t *)B->ptr) + offB, CUBLAS_DATA_HALF, ldb,
                         &beta,
                         ((uint16_t *)C->ptr) + offC, CUBLAS_DATA_HALF, ldc);
#endif
  if (h->err != CUBLAS_STATUS_SUCCESS) {
    cuda_exit(ctx);
    return (h->err == CUBLAS_STATUS_ARCH_MISMATCH) ? GA_DEVSUP_ERROR
                                                   : GA_BLAS_ERROR;
  }

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));

  cuda_exit(ctx);
  return GA_NO_ERROR;
#else
  return GA_DEVSUP_ERROR;
#endif
}
/*
 * Double precision matrix-matrix product through cublasDgemm().
 *
 * A, B and C are device buffers; offA/offB/offC are element offsets (they
 * are added after the pointers are cast to double *). For cb_c ordering the
 * A and B operands and M/N are exchanged before the call.
 *
 * Returns GA_NO_ERROR on success, GA_XLARGE_ERROR when LARGE_VAL() rejects a
 * dimension, GA_DEVSUP_ERROR when cuBLAS reports an architecture mismatch
 * and GA_BLAS_ERROR for any other cuBLAS failure.
 */
static int dgemm(cb_order order, cb_transpose transA, cb_transpose transB,
                 size_t M, size_t N, size_t K, double alpha,
                 gpudata *A, size_t offA, size_t lda,
                 gpudata *B, size_t offB, size_t ldb,
                 double beta, gpudata *C, size_t offC, size_t ldc) {
  cuda_context *ctx = A->ctx;
  blas_handle *h = (blas_handle *)ctx->blas_handle;

  ASSERT_BUF(A);
  ASSERT_BUF(B);
  ASSERT_BUF(C);

  if (LARGE_VAL(M) || LARGE_VAL(N) || LARGE_VAL(K) ||
      LARGE_VAL(lda) || LARGE_VAL(ldb) || LARGE_VAL(ldc) ||
      LARGE_VAL(M * N) || LARGE_VAL(M * K) || LARGE_VAL(K * N))
    return GA_XLARGE_ERROR;

  if (order == cb_c) {
    /* exchange the A and B operands along with everything describing them */
    const size_t cols = N;
    const size_t ld_keep = lda;
    const size_t off_keep = offA;
    const cb_transpose trans_keep = transA;
    gpudata *mat_keep = A;

    N = M;
    M = cols;
    A = B;
    B = mat_keep;
    lda = ldb;
    ldb = ld_keep;
    transA = transB;
    transB = trans_keep;
    offA = offB;
    offB = off_keep;
  }

  cuda_enter(ctx);

  /* A and B are inputs, C is both read and written */
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(B, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(C, CUDA_WAIT_ALL));

  h->err = cublasDgemm(h->h,
                       convT(transA), convT(transB), M, N, K,
                       &alpha,
                       ((double *)A->ptr) + offA, lda,
                       ((double *)B->ptr) + offB, ldb,
                       &beta,
                       ((double *)C->ptr) + offC, ldc);
  if (h->err != CUBLAS_STATUS_SUCCESS) {
    cuda_exit(ctx);
    return (h->err == CUBLAS_STATUS_ARCH_MISMATCH) ? GA_DEVSUP_ERROR
                                                   : GA_BLAS_ERROR;
  }

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(B, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(C, CUDA_WAIT_ALL));

  cuda_exit(ctx);
  return GA_NO_ERROR;
}
/*
 * Batched double precision rank-1 update using one of the library's own
 * kernels.
 *
 * x, y and A are arrays of batchCount device buffers with per-entry offsets
 * offX/offY/offA (added directly to the device pointers). For cb_c ordering
 * the roles of x and y (and of M and N) are exchanged. The context is taken
 * from x[0] after that exchange.
 *
 * Returns GA_NO_ERROR on success (including batchCount == 0),
 * GA_INVALID_ERROR when flags != 0, GA_VALUE_ERROR when the first two grid
 * dimensions alone exceed 65535 blocks, or the error reported by buffer
 * allocation, the kernel launch or the synchronisation helpers.
 */
static int dgerBatch(cb_order order, size_t M, size_t N, double alpha,
                     gpudata **x, size_t *offX, size_t incX,
                     gpudata **y, size_t *offY, size_t incY,
                     gpudata **A, size_t *offA, size_t lda,
                     size_t batchCount, int flags) {
  cuda_context *ctx;
  size_t t, *tp, i;
  size_t ls[3] = {M, N, 1}, gs[3] = {1, 1, batchCount};
  void *args[10];
  gpudata **T;
  gpudata *Aa, *xa, *ya;
  int err;

  if (flags != 0) return GA_INVALID_ERROR;
  if (batchCount == 0) return GA_NO_ERROR;

  /* Launch geometry: cap one block dimension at 32 and, if the block would
     still exceed 512 threads, cap the other at 16; the grid grows to
     compensate. Which dimension gets which cap depends on incX. */
  if (incX == 1) {
    if (ls[0] > 32) {
      gs[0] = (ls[0] + 31) / 32;
      ls[0] = 32;
    }
    if (ls[0] * ls[1] > 512) {
      gs[1] = (ls[1] + 15) / 16;
      ls[1] = 16;
    }
  } else {
    if (ls[1] > 32) {
      gs[1] = (ls[1] + 31) / 32;
      ls[1] = 32;
    }
    if (ls[0] * ls[1] > 512) {
      gs[0] = (ls[0] + 15) / 16;
      ls[0] = 16;
    }
  }
  /* Keep the total grid size within 65535 by shrinking the batch dimension. */
  if (gs[0] * gs[1] * gs[2] > 65535) {
    if (gs[0] * gs[1] > 65535)
      return GA_VALUE_ERROR;
    gs[2] = (65535 / (gs[0] * gs[1]));
  }

  if (order == cb_c) {
    t = M;
    M = N;
    N = t;
    tp = offX;
    offX = offY;
    offY = tp;
    t = incX;
    incX = incY;
    incY = t;
    T = x;
    x = y;
    y = T;
  }

  ASSERT_BUF(x[0]);

  ctx = x[0]->ctx;

  cuda_enter(ctx);

  {
    /* One host scratch area holding the three pointer tables back to back. */
    double **T_l = alloca(sizeof(double *) * batchCount * 3);
    const double **A_l = (const double **)T_l;
    const double **x_l = (const double **)T_l + batchCount;
    double **y_l = T_l + (batchCount * 2);

    for (i = 0; i < batchCount; i++) {
      ASSERT_BUF(A[i]);
      ASSERT_BUF(x[i]);
      ASSERT_BUF(y[i]);
      /* A is updated in place, x and y are only read */
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(A[i], CUDA_WAIT_ALL));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(x[i], CUDA_WAIT_READ));
      GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(y[i], CUDA_WAIT_READ));
      A_l[i] = (double *)(A[i]->ptr + offA[i]);
      x_l[i] = (double *)(x[i]->ptr + offX[i]);
      y_l[i] = (double *)(y[i]->ptr + offY[i]);
    }

    /* Upload the pointer tables. Every failure path below must leave the
       context again, since cuda_enter() has already been called. */
    Aa = cuda_ops.buffer_alloc((gpucontext *)ctx, sizeof(double *) * batchCount, A_l,
                               GA_BUFFER_INIT, &err);
    if (Aa == NULL) {
      cuda_exit(ctx);
      return err;
    }
    xa = cuda_ops.buffer_alloc((gpucontext *)ctx, sizeof(double *) * batchCount, x_l,
                               GA_BUFFER_INIT, &err);
    if (xa == NULL) {
      cuda_ops.buffer_release(Aa);
      cuda_exit(ctx);
      return err;
    }
    ya = cuda_ops.buffer_alloc((gpucontext *)ctx, sizeof(double *) * batchCount, y_l,
                               GA_BUFFER_INIT, &err);
    if (ya == NULL) {
      cuda_ops.buffer_release(Aa);
      cuda_ops.buffer_release(xa);
      cuda_exit(ctx);
      return err;
    }
  }

  args[0] = xa;
  args[1] = &incX;
  args[2] = ya;
  args[3] = &incY;
  args[4] = &alpha;
  args[5] = Aa;
  args[6] = &lda;
  args[7] = &batchCount;
  args[8] = &M;
  args[9] = &N;

  /* Use the double precision kernel: alpha and every buffer hold doubles,
     which a single precision kernel would misinterpret.
     NOTE(review): assumes blas_handle provides dgerBH_gen_small next to
     sgerBH_gen_small -- confirm against the handle definition. */
  err = GpuKernel_call(&((blas_handle *)ctx->blas_handle)->dgerBH_gen_small, 3, ls, gs, 0, args);

  cuda_ops.buffer_release(Aa);
  cuda_ops.buffer_release(xa);
  cuda_ops.buffer_release(ya);

  if (err != GA_NO_ERROR) {
    cuda_exit(ctx);
    return err;
  }

  /* Record with the same access modes that were waited on: A was written,
     x and y were only read. */
  for (i = 0; i < batchCount; i++) {
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(A[i], CUDA_WAIT_ALL));
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(x[i], CUDA_WAIT_READ));
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(y[i], CUDA_WAIT_READ));
  }

  cuda_exit(ctx);
  return GA_NO_ERROR;
}