Пример #1
0
/// Receives a message from `src_rank` into `data`, in place.
///
/// Uses MPI_Recv on MPI_COMM_WORLD with message tag 0; the receive status is
/// discarded (MPI_STATUS_IGNORE).
///
/// @param data      destination tensor; MPI writes straight into its storage,
///                  so it must be contiguous.
/// @param src_rank  sender's rank in MPI_COMM_WORLD.
/// @throws std::logic_error if `data` is not contiguous.
/// NOTE(review): the dtype lookup uses `.at`, so an unmapped scalar type
/// throws from the lookup (presumably std::out_of_range; confirm container).
void DataChannelMPI::receive(at::Tensor& data, rank_type src_rank) {
  const bool is_dense = data.is_contiguous();
  if (!is_dense) {
    throw std::logic_error("tensor to receive is not contiguous");
  }

  void* target = data.data_ptr();
  const auto element_count = data.numel();
  const auto element_type = mpi_datatype.at(data.type().scalarType());
  MPI_Recv(target, element_count, element_type,
           src_rank, /*tag=*/0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
Пример #2
0
/// Sends the contents of `data` to `dst_rank` with MPI_Send.
///
/// The message goes over MPI_COMM_WORLD with tag 0.
///
/// @param data      tensor to send; MPI reads its raw storage, so it must be
///                  contiguous.
/// @param dst_rank  receiver's rank in MPI_COMM_WORLD.
/// @throws std::logic_error if `data` is not contiguous.
void DataChannelMPI::send(at::Tensor& data, rank_type dst_rank) {
  const bool is_dense = data.is_contiguous();
  if (!is_dense) {
    throw std::logic_error("tensor to send is not contiguous");
  }

  void* payload = data.data_ptr();
  const auto element_count = data.numel();
  const auto element_type = mpi_datatype.at(data.type().scalarType());
  MPI_Send(payload, element_count, element_type,
           dst_rank, /*tag=*/0, MPI_COMM_WORLD);
}
Пример #3
0
/// Receives a message from any sender into `data`, in place, and reports
/// who sent it.
///
/// Uses MPI_Recv with MPI_ANY_SOURCE on MPI_COMM_WORLD and tag 0.
///
/// @param data  destination tensor; must be contiguous.
/// @return the sender's rank, taken from the MPI_Status of the receive.
/// @throws std::logic_error if `data` is not contiguous.
rank_type DataChannelMPI::receive(at::Tensor& data) {
  if (!data.is_contiguous()) {
    throw std::logic_error("tensor to receive is not contiguous");
  }

  MPI_Status recv_status;
  MPI_Recv(data.data_ptr(), data.numel(),
           mpi_datatype.at(data.type().scalarType()),
           MPI_ANY_SOURCE, /*tag=*/0, MPI_COMM_WORLD, &recv_status);

  // The source was a wildcard; the status records which rank matched.
  const rank_type sender = recv_status.MPI_SOURCE;
  return sender;
}
Пример #4
0
/// All-reduces `data` in place across the communicator of `group_id`.
///
/// Every participating rank ends up holding the reduction (by `operation`)
/// of all ranks' `data`, because MPI_IN_PLACE is passed as the send buffer.
///
/// @param data       tensor to reduce; overwritten with the result; must be
///                   contiguous.
/// @param operation  reduction operator, mapped to an MPI_Op via `mpi_op`.
/// @param group_id   key into `_groups`.
/// @throws std::runtime_error if `data` is not contiguous (only checked when
///         the group's communicator is not MPI_COMM_NULL).
void DataChannelMPI::allReduce(at::Tensor& data, THDReduceOp operation,
                               THDGroup group_id) {
  const auto& group_comm = _groups.at(group_id).first;
  // A null communicator means there is nothing to do here (presumably this
  // process is not a member of the group; confirm how _groups is built).
  if (group_comm == MPI_COMM_NULL) {
    return;
  }

  if (!data.is_contiguous()) {
    throw std::runtime_error("all_reduce input has to be contiguous");
  }

  const auto element_type = mpi_datatype.at(data.type().scalarType());
  const auto reduce_op = mpi_op.at(operation);
  MPI_Allreduce(MPI_IN_PLACE, data.data_ptr(), data.numel(),
                element_type, reduce_op, group_comm);
}
Пример #5
0
/// Broadcasts `data` from `src_rank` to every rank in the group's
/// communicator; non-root ranks have `data` overwritten in place.
///
/// @param data      buffer to send (root) or receive into (others); must be
///                  contiguous.
/// @param src_rank  global rank of the root; translated to its rank inside
///                  the group via `mustGetGroupRank`.
/// @param group_id  key into `_groups`.
/// @throws std::runtime_error if `data` is not contiguous (only checked when
///         the group's communicator is not MPI_COMM_NULL).
void DataChannelMPI::broadcast(at::Tensor& data, rank_type src_rank,
                               THDGroup group_id) {
  const auto& group_entry = _groups.at(group_id);
  const auto& group_comm = group_entry.first;
  if (group_comm == MPI_COMM_NULL) {
    return;
  }

  if (!data.is_contiguous()) {
    throw std::runtime_error("broadcast input has to be contiguous");
  }

  // MPI_Bcast expects the root as a rank within `group_comm`.
  const rank_type root_in_group = group_entry.second.mustGetGroupRank(src_rank);
  MPI_Bcast(data.data_ptr(), data.numel(),
            mpi_datatype.at(data.type().scalarType()),
            root_in_group, group_comm);
}
Пример #6
0
/// Scatters one tensor per group member from `src_rank` into each member's
/// `output`.
///
/// On the root (`_rank == src_rank`):
///   - `input` must hold exactly `group size` tensors.
///   - Each tensor is checked against `output` with assertSameSizeAndType.
///   - They are copied into a staging buffer from `_newLikeFlat(input)`
///     (indexed per input tensor) that is handed to MPI_Scatter.
///
/// On every other rank, `input` must be empty and no send buffer is passed.
///
/// @param input     root: tensors to distribute; non-root: must be empty.
/// @param output    receives this rank's chunk; must be contiguous.
/// @param src_rank  global rank of the root; translated with
///                  `mustGetGroupRank` for the MPI call.
/// @param group_id  key into `_groups`; returns early if its communicator is
///                  MPI_COMM_NULL.
/// @throws std::runtime_error if `output` is not contiguous.
/// @throws std::logic_error if `input` has the wrong number of tensors.
void DataChannelMPI::scatter(std::vector<at::Tensor>& input,
                             at::Tensor& output,
                             rank_type src_rank, THDGroup group_id) {
  const auto& group_pair = _groups.at(group_id);
  const auto& comm = group_pair.first;
  if (comm == MPI_COMM_NULL)
    return;

  if (!output.is_contiguous())
    throw std::runtime_error("scatter output has to be a contiguous tensor");

  // Owns the packed send data on the root so `sendbuf` stays valid until
  // MPI_Scatter returns.
  at::Tensor send_buffer;
  void *sendbuf = nullptr;  // only significant at the root for MPI_Scatter
  if (_rank != src_rank) {
    if (!input.empty())
      throw std::logic_error("scatter: number of input tensors should be 0 for non root");
  } else {
    if (input.size() != group_pair.second.size())
      throw std::logic_error("scatter: number of input tensors and group size does not match");

    // Bind by reference: copying an at::Tensor handle bumps its refcount on
    // every iteration for no benefit.
    for (auto& in_tensor : input)
      assertSameSizeAndType(in_tensor, output, "scatter");

    send_buffer = _newLikeFlat(input);
    for (size_t i = 0; i < input.size(); ++i)
      send_buffer[i].copy_(input[i]);
    sendbuf = send_buffer.data_ptr();
  }

  // MPI_Scatter expects the root as a rank within `comm`.
  rank_type group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
  const auto datatype = mpi_datatype.at(output.type().scalarType());

  MPI_Scatter(
    sendbuf, output.numel(), datatype,
    output.data_ptr(), output.numel(), datatype,
    group_src_rank, comm
  );
}
Пример #7
0
/// Gathers every group member's `input` onto `dst_rank`.
///
/// On the root (`_rank == dst_rank`):
///   - `output` must hold exactly `group size` tensors.
///   - Each tensor is checked against `input` with assertSameSizeAndType.
///   - MPI_Gather writes into a staging buffer from `_newLikeFlat(output)`.
///   - Each slice of that buffer is then copied into the matching
///     `output[i]`.
///
/// On every other rank, `output` must be empty and no receive buffer is
/// passed.
///
/// @param output    root: destination tensors; non-root: must be empty.
/// @param input     this rank's contribution; sent from a contiguous version
///                  of it, so it may be non-contiguous.
/// @param dst_rank  global rank of the root; translated with
///                  `mustGetGroupRank` for the MPI call.
/// @param group_id  key into `_groups`; returns early if its communicator is
///                  MPI_COMM_NULL.
/// @throws std::logic_error if `output` has the wrong number of tensors.
void DataChannelMPI::gather(std::vector<at::Tensor>& output,
                            at::Tensor& input, rank_type dst_rank,
                            THDGroup group_id) {
  const auto& group_pair = _groups.at(group_id);
  const auto& comm = group_pair.first;
  if (comm == MPI_COMM_NULL)
    return;

  // Owns the gathered data on the root until it is copied into `output`.
  at::Tensor recv_buffer;
  void *recvbuf = nullptr;  // only significant at the root for MPI_Gather
  if (_rank != dst_rank) {
    // The vector checked here is `output`, so the message must say so.
    if (!output.empty())
      throw std::logic_error("gather: number of output tensors should be 0 for non root");
  } else {
    if (output.size() != group_pair.second.size())
      throw std::logic_error("gather: number of output tensors and group size does not match");

    // Bind by reference: copying an at::Tensor handle bumps its refcount on
    // every iteration for no benefit.
    for (auto& out_tensor : output)
      assertSameSizeAndType(out_tensor, input, "gather");

    recv_buffer = _newLikeFlat(output);
    recvbuf = recv_buffer.data_ptr();
  }

  // MPI_Gather expects the root as a rank within `comm`.
  rank_type group_dst_rank = group_pair.second.mustGetGroupRank(dst_rank);
  // MPI needs a dense send buffer; `input` itself may be non-contiguous.
  auto contig_input = input.contiguous();
  const auto datatype = mpi_datatype.at(input.type().scalarType());

  MPI_Gather(
    contig_input.data_ptr(), input.numel(), datatype,
    recvbuf, input.numel(), datatype,
    group_dst_rank, comm
  );

  // Non-root ranks passed the emptiness check above, so this loop only does
  // work on dst_rank.
  for (size_t i = 0; i < output.size(); ++i)
    output[i].copy_(recv_buffer[i]);
}
Пример #8
0
/// Reduces every group member's `data` with `operation` onto `dst_rank`.
///
/// The root passes MPI_IN_PLACE, so its `data` is both its contribution and
/// the destination of the result. Other ranks send `data` and pass a null
/// receive buffer.
///
/// @param data       contribution (all ranks) / result (root); must be
///                   contiguous.
/// @param operation  reduction operator, mapped to an MPI_Op via `mpi_op`.
/// @param dst_rank   global rank of the root; translated with
///                   `mustGetGroupRank` for the MPI call.
/// @param group_id   key into `_groups`.
/// @throws std::runtime_error if `data` is not contiguous (only checked when
///         the group's communicator is not MPI_COMM_NULL).
void DataChannelMPI::reduce(at::Tensor& data, THDReduceOp operation,
                            rank_type dst_rank, THDGroup group_id) {
  const auto& group_entry = _groups.at(group_id);
  const auto& group_comm = group_entry.first;
  if (group_comm == MPI_COMM_NULL) {
    return;
  }

  if (!data.is_contiguous()) {
    throw std::runtime_error("reduce input has to be contiguous");
  }

  const auto root_in_group = group_entry.second.mustGetGroupRank(dst_rank);

  void* send_ptr = nullptr;
  void* recv_ptr = nullptr;
  if (_rank == dst_rank) {
    // Root reduces in place into its own buffer.
    send_ptr = MPI_IN_PLACE;
    recv_ptr = data.data_ptr();
  } else {
    // Non-root only contributes; the receive buffer is not used.
    send_ptr = data.data_ptr();
  }

  MPI_Reduce(send_ptr, recv_ptr, data.numel(),
             mpi_datatype.at(data.type().scalarType()),
             mpi_op.at(operation), root_in_group, group_comm);
}