/**
 * \brief NCCL implementation of \ref gpucomm_broadcast.
 */
static int broadcast(gpudata *array, size_t offset, size_t count, int typecode,
                     int root, gpucomm *comm) {
  // need dummy init so that compiler shuts up
  ncclDataType_t datatype = ncclNumTypes;
  int rank = 0;
  cuda_context *ctx;

  ASSERT_BUF(array);
  ASSERT_COMM(comm);
  GA_CHECK(check_restrictions(array, offset, NULL, 0, count, typecode, 0, comm,
                              &datatype, NULL));
  GA_CHECK(get_rank(comm, &rank));

  ctx = comm->ctx;
  cuda_enter(ctx);

  // sync: the root only reads the buffer while the other ranks overwrite it,
  // so wait for the corresponding outstanding work (out of concurrent kernels)
  if (rank == root)
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(array, CUDA_WAIT_READ));
  else
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(array, CUDA_WAIT_WRITE));

  // change stream of nccl ops to enable concurrency
  NCCL_EXIT_ON_ERROR(ctx, ncclBcast((void *)(array->ptr + offset), count,
                                    datatype, root, comm->c, ctx->s));

  // sync: record the buffer's use on the NCCL stream so that later operations
  // on other streams wait for the collective to finish
  if (rank == root)
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(array, CUDA_WAIT_READ));
  else
    GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(array, CUDA_WAIT_WRITE));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
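/*
 * GpuArray-level wrappers: they validate the participating arrays, derive the
 * element count and dispatch to the corresponding buffer-level gpucomm_* call.
 */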
int GpuArray_all_gather(const GpuArray* src, GpuArray* dest, gpucomm* comm) {
  size_t count = 0;
  int ndev = 0;
  GA_CHECK(gpucomm_get_count(comm, &ndev));
  GA_CHECK(check_gpuarrays(ndev, src, 1, dest, &count));
  return gpucomm_all_gather(src->data, src->offset, dest->data, dest->offset,
                            count, src->typecode, comm);
}
int GpuArray_reduce_scatter(const GpuArray* src, GpuArray* dest, int opcode,
                            gpucomm* comm) {
  size_t count = 0;
  int ndev = 0;
  GA_CHECK(gpucomm_get_count(comm, &ndev));
  GA_CHECK(check_gpuarrays(1, src, ndev, dest, &count));
  return gpucomm_reduce_scatter(src->data, src->offset, dest->data,
                                dest->offset, count, src->typecode, opcode,
                                comm);
}
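/*
 * Only the root rank has a meaningful destination; the other ranks skip the
 * dest checks and dispatch through GpuArray_reduce_from instead.
 */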
int GpuArray_reduce(const GpuArray* src, GpuArray* dest, int opcode, int root,
                    gpucomm* comm) {
  int rank = 0;
  GA_CHECK(gpucomm_get_rank(comm, &rank));
  if (rank == root) {
    size_t count = 0;
    GA_CHECK(check_gpuarrays(1, src, 1, dest, &count));
    return gpucomm_reduce(src->data, src->offset, dest->data, dest->offset,
                          count, src->typecode, opcode, root, comm);
  } else {
    return GpuArray_reduce_from(src, opcode, root, comm);
  }
}
int GpuArray_all_reduce(const GpuArray* src, GpuArray* dest, int opcode,
                        gpucomm* comm) {
  size_t count = 0;
  GA_CHECK(check_gpuarrays(1, src, 1, dest, &count));
  return gpucomm_all_reduce(src->data, src->offset, dest->data, dest->offset,
                            count, src->typecode, opcode, comm);
}
/**
 * \brief NCCL implementation of \ref gpucomm_new.
 */
static int comm_new(gpucomm **comm_ptr, gpucontext *ctx,
                    gpucommCliqueId comm_id, int ndev, int rank) {
  gpucomm *comm;
  ncclResult_t err;

  ASSERT_CTX(ctx);

  GA_CHECK(setup_lib(ctx->err));

  comm = calloc(1, sizeof(*comm));  // Allocate memory
  if (comm == NULL) {
    *comm_ptr = NULL;  // Set to NULL if failed
    return error_sys(ctx->err, "calloc");
  }
  comm->ctx = (cuda_context *)ctx;  // convert to underlying cuda context
  // Hold a reference so the context is not destroyed before the communicator
  comm->ctx->refcnt++;
  cuda_enter(comm->ctx);  // Use device
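  // ncclCommInitRank is a collective call: every rank in the clique must call
  // it with the same unique id and total device count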
  err = ncclCommInitRank(&comm->c, ndev, *((ncclUniqueId *)&comm_id), rank);
  cuda_exit(comm->ctx);
  TAG_COMM(comm);
  if (err != ncclSuccess) {
    *comm_ptr = NULL;  // Set to NULL if failed
    comm_clear(comm);
    return error_nccl(ctx->err, "ncclCommInitRank", err);
  }
  *comm_ptr = comm;
  return GA_NO_ERROR;
}
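/**
 * \brief Loads the NCCL library on first use; subsequent calls are no-ops.
 */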
static int setup_lib(error *e) {
  if (setup_done)
    return GA_NO_ERROR;
  GA_CHECK(load_libnccl(e));
  setup_done = 1;
  return GA_NO_ERROR;
}
/**
 * \brief NCCL implementation of \ref gpucomm_all_reduce.
 */
static int all_reduce(gpudata *src, size_t offsrc, gpudata *dest,
                      size_t offdest, size_t count, int typecode, int opcode,
                      gpucomm *comm) {
  // need dummy init so that compiler shuts up
  ncclRedOp_t op = ncclNumOps;
  ncclDataType_t datatype = ncclNumTypes;
  cuda_context *ctx;

  ASSERT_BUF(src);
  ASSERT_COMM(comm);
  ASSERT_BUF(dest);
  GA_CHECK(check_restrictions(src, offsrc, dest, offdest, count, typecode,
                              opcode, comm, &datatype, &op));

  ctx = comm->ctx;
  cuda_enter(ctx);

  // sync: wait till a write has finished (out of concurrent kernels)
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(src, CUDA_WAIT_READ));
  // sync: wait till a read/write has finished (out of concurrent kernels)
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(dest, CUDA_WAIT_WRITE));

  // change stream of nccl ops to enable concurrency
  NCCL_EXIT_ON_ERROR(ctx, ncclAllReduce((void *)(src->ptr + offsrc),
                                        (void *)(dest->ptr + offdest), count,
                                        datatype, op, comm->c, ctx->s));

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(src, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(dest, CUDA_WAIT_WRITE));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
/**
 * \brief NCCL implementation of \ref gpucomm_reduce_scatter.
 */
static int reduce_scatter(gpudata *src, size_t offsrc, gpudata *dest,
                          size_t offdest, size_t count, int typecode,
                          int opcode, gpucomm *comm) {
  // need dummy init so that compiler shuts up
  ncclRedOp_t op = ncclNumOps;
  ncclDataType_t datatype = ncclNumTypes;
  int ndev = 0;
  size_t resc_size;
  cuda_context *ctx;

  ASSERT_BUF(src);
  ASSERT_COMM(comm);
  ASSERT_BUF(dest);
  GA_CHECK(get_count(comm, &ndev));
  GA_CHECK(check_restrictions(src, offsrc, NULL, 0, count * ndev, typecode,
                              opcode, comm, &datatype, &op));
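  // dest must live in the comm's context and be able to hold, after offdest,
  // the `count` reduced elements this rank receives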
  if (dest->ctx != comm->ctx)
    return error_set(comm->ctx->err, GA_VALUE_ERROR,
                     "destination and comm context differ");
  resc_size = count * gpuarray_get_elsize(typecode);
  if ((dest->sz - offdest) < resc_size)
    return error_set(comm->ctx->err, GA_VALUE_ERROR,
                     "destination too small for operation");
  assert(!(offdest > dest->sz));

  ctx = comm->ctx;
  cuda_enter(ctx);

  // sync: wait till a write has finished (out of concurrent kernels)
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(src, CUDA_WAIT_READ));
  // sync: wait till a read/write has finished (out of concurrent kernels)
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_wait(dest, CUDA_WAIT_WRITE));

  // change stream of nccl ops to enable concurrency
  NCCL_EXIT_ON_ERROR(ctx, ncclReduceScatter((void *)(src->ptr + offsrc),
                                            (void *)(dest->ptr + offdest), count,
                                            datatype, op, comm->c, ctx->s));

  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(src, CUDA_WAIT_READ));
  GA_CUDA_EXIT_ON_ERROR(ctx, cuda_record(dest, CUDA_WAIT_WRITE));

  cuda_exit(ctx);

  return GA_NO_ERROR;
}
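/*
 * Broadcasts the whole `array` from rank `root` to every other rank in `comm`.
 * The root only reads its buffer; the receiving ranks overwrite theirs.
 */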
int GpuArray_broadcast(GpuArray* array, int root, gpucomm* comm) {
  int rank = 0;
  size_t total_elems;
  GA_CHECK(gpucomm_get_rank(comm, &rank));
  if (rank == root) {
    // the root only reads its buffer, so alignment is enough
    if (!GpuArray_ISALIGNED(array))
      return GA_UNALIGNED_ERROR;
  } else {
    // the other ranks receive into their buffer: it must be aligned and writeable
    if (!GpuArray_CHKFLAGS(array, GA_BEHAVED))
      return GA_UNALIGNED_ERROR;
  }

  total_elems = find_total_elems(array);

  return gpucomm_broadcast(array->data, array->offset, total_elems,
                           array->typecode, root, comm);
}
/**
 * \brief NCCL implementation of \ref gpucomm_gen_clique_id.
 */
static int generate_clique_id(gpucontext *c, gpucommCliqueId *comm_id) {
  ASSERT_CTX(c);

  GA_CHECK(setup_lib(c->err));
  NCCL_CHKFAIL(c, ncclGetUniqueId((ncclUniqueId *)comm_id));
}
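/*
 * Rough usage sketch of the public API that these static functions back
 * (a hedged illustration only: the gpucomm_gen_clique_id, gpucomm_new,
 * gpucomm_free entry points and the GA_SUM opcode are assumed from the
 * public collectives header, not defined in this file):
 *
 *   gpucommCliqueId id;
 *   gpucomm *comm;
 *   gpucomm_gen_clique_id(ctx, &id);          // one rank generates the id
 *   // ... share `id` with every other rank out of band (e.g. via MPI) ...
 *   gpucomm_new(&comm, ctx, id, ndev, rank);  // all ranks, same id and ndev
 *   GpuArray_all_reduce(&src, &dest, GA_SUM, comm);
 *   gpucomm_free(comm);
 */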