void DataChannelMPI::receive(at::Tensor& data, rank_type src_rank) {
  if (!data.is_contiguous())
    throw std::logic_error("tensor to receive is not contiguous");

  // Blocking receive directly into the tensor's storage; the MPI datatype is
  // looked up from the tensor's scalar type.
  MPI_Recv(data.data_ptr(), data.numel(),
           mpi_datatype.at(data.type().scalarType()), src_rank, 0,
           MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
void DataChannelMPI::send(at::Tensor& data, rank_type dst_rank) {
  if (!data.is_contiguous())
    throw std::logic_error("tensor to send is not contiguous");

  // Blocking send straight from the tensor's storage.
  MPI_Send(data.data_ptr(), data.numel(),
           mpi_datatype.at(data.type().scalarType()), dst_rank, 0,
           MPI_COMM_WORLD);
}
rank_type DataChannelMPI::receive(at::Tensor& data) {
  if (!data.is_contiguous())
    throw std::logic_error("tensor to receive is not contiguous");

  // Unranked overload: accept a message from any sender and report which
  // rank it came from.
  MPI_Status status;
  MPI_Recv(data.data_ptr(), data.numel(),
           mpi_datatype.at(data.type().scalarType()), MPI_ANY_SOURCE, 0,
           MPI_COMM_WORLD, &status);
  return status.MPI_SOURCE;
}
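// Usage sketch (hypothetical helper, not part of the original API): the plain
// MPI_ANY_SOURCE pattern behind the unranked receive() above. Assumes MPI has
// already been initialized and the world holds at least two ranks.
namespace {
void exampleAnySourceReceive() {
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  float payload = static_cast<float>(rank);
  if (rank != 0) {
    MPI_Send(&payload, 1, MPI_FLOAT, /*dest=*/0, /*tag=*/0, MPI_COMM_WORLD);
  } else {
    // Drain one message from each peer; arrival order is nondeterministic,
    // so MPI_ANY_SOURCE accepts whichever sender is ready first.
    for (int i = 1; i < size; ++i) {
      MPI_Status status;
      MPI_Recv(&payload, 1, MPI_FLOAT, MPI_ANY_SOURCE, 0, MPI_COMM_WORLD,
               &status);
      // status.MPI_SOURCE identifies the sender, which is exactly what the
      // unranked receive() overload returns to its caller.
    }
  }
}
} // namespace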
void DataChannelMPI::allReduce(at::Tensor& data, THDReduceOp operation,
                               THDGroup group_id) {
  const auto& comm = _groups.at(group_id).first;
  if (comm == MPI_COMM_NULL)
    return; // this process is not a member of the group

  if (!data.is_contiguous())
    throw std::runtime_error("all_reduce input has to be contiguous");

  // In-place reduction: every rank contributes `data` and receives the
  // result in the same buffer.
  MPI_Allreduce(MPI_IN_PLACE, data.data_ptr(), data.numel(),
                mpi_datatype.at(data.type().scalarType()),
                mpi_op.at(operation), comm);
}
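// Usage sketch (hypothetical helper, not part of the original API): the
// MPI_IN_PLACE allreduce pattern used above. Every rank passes MPI_IN_PLACE
// as the send buffer, so `value` serves as both contribution and result.
namespace {
void exampleInPlaceAllReduce() {
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  float value = static_cast<float>(rank);
  // After the call, every rank holds the sum 0 + 1 + ... + (size - 1).
  MPI_Allreduce(MPI_IN_PLACE, &value, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
}
} // namespace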
void DataChannelMPI::broadcast(at::Tensor& data, rank_type src_rank,
                               THDGroup group_id) {
  const auto& group_pair = _groups.at(group_id);
  const auto& comm = group_pair.first;
  if (comm == MPI_COMM_NULL)
    return; // this process is not a member of the group

  if (!data.is_contiguous())
    throw std::runtime_error("broadcast input has to be contiguous");

  // MPI_Bcast expects the root's rank *within the group communicator*, so
  // translate the global rank first.
  rank_type group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
  MPI_Bcast(data.data_ptr(), data.numel(),
            mpi_datatype.at(data.type().scalarType()), group_src_rank, comm);
}
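// Usage sketch (hypothetical helper, not part of the original API): why the
// global src_rank must be translated before MPI_Bcast. A sub-communicator
// renumbers its members from 0, so e.g. world rank 2 may be rank 0 of the
// group. Assumes at least four world ranks; all ranks must reach
// MPI_Comm_create, which is collective over MPI_COMM_WORLD.
namespace {
void exampleGroupRankTranslation() {
  MPI_Group world_group, sub_group;
  MPI_Comm sub_comm;
  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
  const int members[2] = {2, 3}; // world ranks forming the group
  MPI_Group_incl(world_group, 2, members, &sub_group);
  MPI_Comm_create(MPI_COMM_WORLD, sub_group, &sub_comm);
  if (sub_comm != MPI_COMM_NULL) {
    float value = 0.f;
    // World rank 2 is rank 0 inside sub_comm; passing 2 here would be wrong,
    // which is what mustGetGroupRank() guards against above.
    MPI_Bcast(&value, 1, MPI_FLOAT, /*root=*/0, sub_comm);
    MPI_Comm_free(&sub_comm);
  }
  MPI_Group_free(&sub_group);
  MPI_Group_free(&world_group);
}
} // namespace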
void DataChannelMPI::scatter(std::vector<at::Tensor>& input,
                             at::Tensor& output, rank_type src_rank,
                             THDGroup group_id) {
  const auto& group_pair = _groups.at(group_id);
  const auto& comm = group_pair.first;
  if (comm == MPI_COMM_NULL)
    return; // this process is not a member of the group

  if (!output.is_contiguous())
    throw std::runtime_error("scatter output has to be a contiguous tensor");

  at::Tensor send_buffer;
  void* sendbuf = nullptr;
  if (_rank != src_rank) {
    if (input.size() > 0)
      throw std::logic_error("scatter: number of input tensors should be 0 for non root");
  } else {
    if (input.size() != group_pair.second.size())
      throw std::logic_error("scatter: number of input tensors and group size does not match");
    for (const auto& in_tensor : input)
      assertSameSizeAndType(in_tensor, output, "scatter");

    // MPI_Scatter needs one contiguous send buffer on the root, so pack the
    // per-rank tensors into a single flat tensor first.
    send_buffer = _newLikeFlat(input);
    for (size_t i = 0; i < input.size(); ++i)
      send_buffer[i].copy_(input[i]);
    sendbuf = send_buffer.data_ptr();
  }

  rank_type group_src_rank = group_pair.second.mustGetGroupRank(src_rank);
  MPI_Scatter(sendbuf, output.numel(),
              mpi_datatype.at(output.type().scalarType()), output.data_ptr(),
              output.numel(), mpi_datatype.at(output.type().scalarType()),
              group_src_rank, comm);
}
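// Usage sketch (hypothetical helper, not part of the original API): the
// root-side packing that scatter() performs. MPI_Scatter sends `count`
// consecutive elements to each rank, so the root's buffer must hold all
// per-rank chunks back to back, which is what _newLikeFlat arranges.
namespace {
void exampleFlatScatter() {
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  const int chunk = 4;
  std::vector<float> sendbuf;
  if (rank == 0) {
    // Flat layout: [chunk for rank 0][chunk for rank 1]...[chunk for size-1]
    sendbuf.assign(static_cast<size_t>(size) * chunk, 1.f);
  }
  float recv[chunk];
  // Non-root ranks may pass a null send buffer; the argument is ignored.
  MPI_Scatter(rank == 0 ? sendbuf.data() : nullptr, chunk, MPI_FLOAT, recv,
              chunk, MPI_FLOAT, /*root=*/0, MPI_COMM_WORLD);
}
} // namespace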
void DataChannelMPI::gather(std::vector<at::Tensor>& output, at::Tensor& input,
                            rank_type dst_rank, THDGroup group_id) {
  const auto& group_pair = _groups.at(group_id);
  const auto& comm = group_pair.first;
  if (comm == MPI_COMM_NULL)
    return; // this process is not a member of the group

  at::Tensor recv_buffer;
  void* recvbuf = nullptr;
  if (_rank != dst_rank) {
    if (output.size() > 0)
      throw std::logic_error("gather: number of output tensors should be 0 for non root");
  } else {
    if (output.size() != group_pair.second.size())
      throw std::logic_error("gather: number of output tensors and group size does not match");
    for (const auto& out_tensor : output)
      assertSameSizeAndType(out_tensor, input, "gather");

    // Only the root needs a receive buffer: one flat tensor holding every
    // rank's contribution.
    recv_buffer = _newLikeFlat(output);
    recvbuf = recv_buffer.data_ptr();
  }

  rank_type group_dst_rank = group_pair.second.mustGetGroupRank(dst_rank);
  auto contig_input = input.contiguous();
  MPI_Gather(contig_input.data_ptr(), input.numel(),
             mpi_datatype.at(input.type().scalarType()), recvbuf,
             input.numel(), mpi_datatype.at(input.type().scalarType()),
             group_dst_rank, comm);

  // Unpack the flat receive buffer into the user-provided output tensors.
  // NOTE: this is a no-op on all processes except dst_rank, since `output`
  // is empty everywhere else.
  for (size_t i = 0; i < output.size(); ++i)
    output[i].copy_(recv_buffer[i]);
}
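// Usage sketch (hypothetical helper, not part of the original API): the
// inverse of the scatter packing. The root gathers every rank's chunk into
// one flat buffer and then slices it per rank, mirroring the copy-back loop
// over recv_buffer[i] in gather() above.
namespace {
void exampleFlatGather() {
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  const int chunk = 4;
  float send[chunk];
  for (int i = 0; i < chunk; ++i)
    send[i] = static_cast<float>(rank);
  std::vector<float> recvbuf;
  if (rank == 0)
    recvbuf.resize(static_cast<size_t>(size) * chunk);
  MPI_Gather(send, chunk, MPI_FLOAT, rank == 0 ? recvbuf.data() : nullptr,
             chunk, MPI_FLOAT, /*root=*/0, MPI_COMM_WORLD);
  // On the root, recvbuf[i * chunk .. (i + 1) * chunk) holds rank i's data.
}
} // namespace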
void DataChannelMPI::reduce(at::Tensor& data, THDReduceOp operation,
                            rank_type dst_rank, THDGroup group_id) {
  const auto& group_pair = _groups.at(group_id);
  const auto& comm = group_pair.first;
  if (comm == MPI_COMM_NULL)
    return; // this process is not a member of the group

  if (!data.is_contiguous())
    throw std::runtime_error("reduce input has to be contiguous");

  auto group_dst_rank = group_pair.second.mustGetGroupRank(dst_rank);

  // MPI_IN_PLACE is only valid at the root for MPI_Reduce: the root reduces
  // into `data` in place, while every other rank sends `data` and ignores
  // the receive buffer.
  void* sendbuf = (_rank == dst_rank) ? MPI_IN_PLACE : data.data_ptr();
  void* recvbuf = (_rank == dst_rank) ? data.data_ptr() : nullptr;
  MPI_Reduce(sendbuf, recvbuf, data.numel(),
             mpi_datatype.at(data.type().scalarType()), mpi_op.at(operation),
             group_dst_rank, comm);
}
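// Usage sketch (hypothetical helper, not part of the original API): unlike
// MPI_Allreduce, MPI_Reduce permits MPI_IN_PLACE only on the root; non-root
// ranks pass their data as the send buffer and may pass a null receive
// buffer. This is exactly the sendbuf/recvbuf selection made in reduce().
namespace {
void exampleRootInPlaceReduce(int root) {
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  float value = 1.f;
  void* sendbuf = &value;
  void* recvbuf = nullptr;
  if (rank == root) {
    sendbuf = MPI_IN_PLACE; // contribute and receive in the same buffer
    recvbuf = &value;
  }
  MPI_Reduce(sendbuf, recvbuf, 1, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
  // Only the root's `value` now holds the sum across all ranks.
}
} // namespace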