Exemplo n.º 1
0
 std::vector<T> mpi_reduce(std::vector<T> const &a, communicator c, int root, bool all, MPI_Op op, std::true_type) {
  std::vector<T> b(a.size());
  if (!all)
   MPI_Reduce((void *)a.data(), b.data(), a.size(), mpi_datatype<T>(), op, root, c.get());
  else
   MPI_Allreduce((void *)a.data(), b.data(), a.size(), mpi_datatype<T>(), op, c.get());
  return b;
 }
Exemplo n.º 2
0
Arquivo: mpi.hpp Projeto: TRIQS/triqs
 /// Reduce a single basic value `a` with `op` across the ranks of communicator `c`.
 /// @param a     this rank's contribution
 /// @param c     communicator (default-constructed if omitted)
 /// @param root  rank receiving the result when `all` is false
 /// @param all   if true, use MPI_Allreduce so every rank receives the result
 /// @param op    MPI reduction operation (MPI_SUM by default)
 /// @return the reduced value on `root` (or on every rank if `all`); a value-initialized T elsewhere
 template <typename T> REQUIRES_IS_BASIC(T, T) mpi_reduce(T a, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
   // Value-initialize the receive buffer. With all == false, MPI_Reduce writes it
   // only on `root`, so other ranks would otherwise return an indeterminate value
   // (undefined behaviour for scalar types).
   T b{};
   auto d = datatype<T>();
   if (!all)
     MPI_Reduce(&a, &b, 1, d, op, root, c.get());
   else
     MPI_Allreduce(&a, &b, 1, d, op, c.get());
   return b;
 }
Exemplo n.º 3
0
 //---------
 static void broadcast(communicator c, A &a, int root) {
     // Broadcast array `a` from rank `root`: first its shape, then its elements.
     // Ranks other than `root` resize `a` to the received shape before the data arrives.
     check_is_contiguous(a);
     auto shape = a.shape();
     using extent_t = typename decltype(shape)::value_type;
     MPI_Bcast(&shape[0], shape.size(), mpi_datatype<extent_t>::invoke(), root, c.get());
     bool const is_receiver = (c.rank() != root);
     if (is_receiver) a.resize(shape);
     auto const count = a.domain().number_of_elements();
     MPI_Bcast(a.data_start(), count, D(), root, c.get());
 }
Exemplo n.º 4
0
 /// Gather the vectors of all ranks into one vector, concatenated in rank order.
 /// @param a     this rank's contribution (may have a different length on each rank)
 /// @param c     communicator
 /// @param root  rank receiving the result when `all` is false
 /// @param all   if true, use MPI_Allgatherv so every rank receives the result
 /// @return the concatenation on receiving ranks; an empty vector on the others
 template <typename T> std::vector<T> mpi_gather(std::vector<T> const &a, communicator c, int root, bool all, std::true_type) {
  int sendcount = a.size();
  auto recvcounts = std::vector<int>(c.size());
  auto displs = std::vector<int>(c.size() + 1, 0);
  auto mpi_ty = mpi::mpi_datatype<int>();

  // Step 1: collect every rank's element count (only significant on receiving ranks).
  if (!all)
   MPI_Gather(&sendcount, 1, mpi_ty, recvcounts.data(), 1, mpi_ty, root, c.get());
  else
   MPI_Allgather(&sendcount, 1, mpi_ty, recvcounts.data(), 1, mpi_ty, c.get());

  // displs[r] is the offset of rank r's block; displs.back() is the total length.
  for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];

  // Size the result from the gathered counts. A separate reduction of the sizes
  // would be a redundant extra collective.
  bool const receives = all || (c.rank() == root);
  std::vector<T> b(receives ? displs.back() : 0);

  // Drop const for MPI implementations whose send-buffer parameter is not const-qualified.
  T *sendbuf = const_cast<T *>(a.data());
  if (!all)
   MPI_Gatherv(sendbuf, sendcount, mpi_datatype<T>(), b.data(), recvcounts.data(), displs.data(), mpi_datatype<T>(), root,
               c.get());
  else
   MPI_Allgatherv(sendbuf, sendcount, mpi_datatype<T>(), b.data(), recvcounts.data(), displs.data(), mpi_datatype<T>(),
                  c.get());

  return b;
 }
Exemplo n.º 5
0
 template <typename T> std::vector<T> mpi_scatter(std::vector<T> const &a, communicator c, int root, std::true_type) {
  // Split `a` (the data is read on `root`) into c.size() contiguous chunks.
  // Chunk r goes to rank r, and chunk lengths are given by slice_length.
  // NOTE(review): chunk sizes are computed from the LOCAL a.size(), so every rank
  // presumably must pass a vector of the same length as root's — confirm with callers.
  auto const n_ranks = c.size();
  auto const last_index = a.size() - 1;

  auto sendcounts = std::vector<int>(n_ranks);
  auto displs = std::vector<int>(n_ranks + 1, 0);
  for (int r = 0; r < n_ranks; ++r) {
   sendcounts[r] = slice_length(last_index, n_ranks, r);
   displs[r + 1] = displs[r] + sendcounts[r];
  }

  int recvcount = slice_length(last_index, n_ranks, c.rank());
  std::vector<T> chunk(recvcount);

  // Drop const for MPI implementations whose send-buffer parameter is not const-qualified.
  T *sendbuf = const_cast<T *>(a.data());
  MPI_Scatterv(sendbuf, sendcounts.data(), displs.data(), mpi_datatype<T>(), chunk.data(), recvcount, mpi_datatype<T>(), root,
               c.get());
  return chunk;
 }
Exemplo n.º 6
0
Arquivo: mpi.hpp Projeto: TRIQS/triqs
 // NOTE: We keep the naming mpi_XXX for the actual implementation functions
 // so they can be defined in other namespaces and the mpi::reduce(T,...) function
 // can find them via ADL
 template <typename T> REQUIRES_IS_BASIC(T, void) mpi_broadcast(T &a, communicator c = {}, int root = 0) {
   // Overwrite `a` on every rank of `c` with root's value (single-element MPI_Bcast).
   auto const dtype = datatype<T>();
   MPI_Bcast(&a, 1, dtype, root, c.get());
 }
Exemplo n.º 7
0
 //---------
 static void all_reduce_in_place(communicator c, A &a, int root) {
     // Sum `a` element-wise over all ranks in place (MPI_Allreduce + MPI_IN_PLACE).
     // Every rank ends up holding the total.
     // `root` is unused: an all-reduce has no root rank.
     // Assumes `a` has the same number of elements on every rank; this is not checked.
     check_is_contiguous(a);
     auto const n = a.domain().number_of_elements();
     MPI_Allreduce(MPI_IN_PLACE, a.data_start(), n, D(), MPI_SUM, c.get());
 }
Exemplo n.º 8
0
 void mpi_reduce_in_place(std::vector<T> &a, communicator c, int root, bool all, MPI_Op op, std::true_type) {
  // Reduce `a` element-wise with `op`, writing the result back into `a`.
  //  - all == true : every rank's `a` receives the result.
  //  - all == false: only root's `a` is overwritten. Other ranks only send, because
  //    MPI ignores the receive buffer on non-root ranks.
  if (all) {
   MPI_Allreduce(MPI_IN_PLACE, a.data(), a.size(), mpi_datatype<T>(), op, c.get());
   return;
  }
  bool const is_root = (c.rank() == root);
  // On root, MPI_IN_PLACE makes `a` serve as both input and output.
  void *sendbuf = is_root ? MPI_IN_PLACE : a.data();
  MPI_Reduce(sendbuf, a.data(), a.size(), mpi_datatype<T>(), op, root, c.get());
 }
Exemplo n.º 9
0
 template <typename T> void mpi_broadcast(std::vector<T> &a, communicator c, int root, std::true_type) {
  size_t s = a.size();
  mpi_broadcast(s, c, root);
  if (c.rank() != root) a.resize(s);
  MPI_Bcast(a.data(), a.size(), mpi_datatype<T>(), root, c.get());
 }