diff options
Diffstat (limited to 'boost/mpi/collectives/gather.hpp')
-rw-r--r-- | boost/mpi/collectives/gather.hpp | 170 |
1 files changed, 97 insertions, 73 deletions
diff --git a/boost/mpi/collectives/gather.hpp b/boost/mpi/collectives/gather.hpp index 70dfd65313..386bfdd1a1 100644 --- a/boost/mpi/collectives/gather.hpp +++ b/boost/mpi/collectives/gather.hpp @@ -8,6 +8,9 @@ #ifndef BOOST_MPI_GATHER_HPP #define BOOST_MPI_GATHER_HPP +#include <cassert> +#include <cstddef> +#include <numeric> #include <boost/mpi/exception.hpp> #include <boost/mpi/datatype.hpp> #include <vector> @@ -16,89 +19,116 @@ #include <boost/mpi/detail/point_to_point.hpp> #include <boost/mpi/communicator.hpp> #include <boost/mpi/environment.hpp> +#include <boost/mpi/detail/offsets.hpp> +#include <boost/mpi/detail/antiques.hpp> #include <boost/assert.hpp> namespace boost { namespace mpi { namespace detail { - // We're gathering at the root for a type that has an associated MPI - // datatype, so we'll use MPI_Gather to do all of the work. - template<typename T> - void - gather_impl(const communicator& comm, const T* in_values, int n, - T* out_values, int root, mpl::true_) - { - MPI_Datatype type = get_mpi_datatype<T>(*in_values); - BOOST_MPI_CHECK_RESULT(MPI_Gather, - (const_cast<T*>(in_values), n, type, - out_values, n, type, root, comm)); - } +// We're gathering at the root for a type that has an associated MPI +// datatype, so we'll use MPI_Gather to do all of the work. +template<typename T> +void +gather_impl(const communicator& comm, const T* in_values, int n, + T* out_values, int root, mpl::true_) +{ + MPI_Datatype type = get_mpi_datatype<T>(*in_values); + BOOST_MPI_CHECK_RESULT(MPI_Gather, + (const_cast<T*>(in_values), n, type, + out_values, n, type, root, comm)); +} - // We're gathering from a non-root for a type that has an associated MPI - // datatype, so we'll use MPI_Gather to do all of the work. - template<typename T> - void - gather_impl(const communicator& comm, const T* in_values, int n, int root, - mpl::true_) - { - MPI_Datatype type = get_mpi_datatype<T>(*in_values); - BOOST_MPI_CHECK_RESULT(MPI_Gather, - (const_cast<T*>(in_values), n, type, - 0, n, type, root, comm)); - } +// We're gathering from a non-root for a type that has an associated MPI +// datatype, so we'll use MPI_Gather to do all of the work. +template<typename T> +void +gather_impl(const communicator& comm, const T* in_values, int n, int root, + mpl::true_ is_mpi_type) +{ + assert(comm.rank() != root); + gather_impl(comm, in_values, n, (T*)0, root, is_mpi_type); +} - // We're gathering at the root for a type that does not have an - // associated MPI datatype, so we'll need to serialize - // it. Unfortunately, this means that we cannot use MPI_Gather, so - // we'll just have all of the non-root nodes send individual - // messages to the root. - template<typename T> - void - gather_impl(const communicator& comm, const T* in_values, int n, - T* out_values, int root, mpl::false_) - { - int tag = environment::collectives_tag(); - int size = comm.size(); - - for (int src = 0; src < size; ++src) { - if (src == root) - std::copy(in_values, in_values + n, out_values + n * src); - else - comm.recv(src, tag, out_values + n * src, n); +// We're gathering at the root for a type that does not have an +// associated MPI datatype, so we'll need to serialize +// it. +template<typename T> +void +gather_impl(const communicator& comm, const T* in_values, int n, T* out_values, + int const* nslot, int const* nskip, int root, mpl::false_) +{ + int nproc = comm.size(); + // first, gather all size, these size can be different for + // each process + packed_oarchive oa(comm); + for (int i = 0; i < n; ++i) { + oa << in_values[i]; + } + bool is_root = comm.rank() == root; + std::vector<int> oasizes(is_root ? nproc : 0); + int oasize = oa.size(); + BOOST_MPI_CHECK_RESULT(MPI_Gather, + (&oasize, 1, MPI_INTEGER, + c_data(oasizes), 1, MPI_INTEGER, + root, MPI_Comm(comm))); + // Gather the archives, which can be of different sizes, so + // we need to use gatherv. + // Everything is contiguous (in the transmitted archive), so + // the offsets can be deduced from the collected sizes. + std::vector<int> offsets; + if (is_root) sizes2offsets(oasizes, offsets); + packed_iarchive::buffer_type recv_buffer(is_root ? std::accumulate(oasizes.begin(), oasizes.end(), 0) : 0); + BOOST_MPI_CHECK_RESULT(MPI_Gatherv, + (const_cast<void*>(oa.address()), int(oa.size()), MPI_BYTE, + c_data(recv_buffer), c_data(oasizes), c_data(offsets), MPI_BYTE, + root, MPI_Comm(comm))); + if (is_root) { + for (int src = 0; src < nproc; ++src) { + // handle variadic case + int nb = nslot ? nslot[src] : n; + int skip = nskip ? nskip[src] : 0; + std::advance(out_values, skip); + if (src == root) { + BOOST_ASSERT(nb == n); + for (int i = 0; i < nb; ++i) { + *out_values++ = *in_values++; + } + } else { + packed_iarchive ia(comm, recv_buffer, boost::archive::no_header, offsets[src]); + for (int i = 0; i < nb; ++i) { + ia >> *out_values++; + } + } } } +} - // We're gathering at a non-root for a type that does not have an - // associated MPI datatype, so we'll need to serialize - // it. Unfortunately, this means that we cannot use MPI_Gather, so - // we'll just have all of the non-root nodes send individual - // messages to the root. - template<typename T> - void - gather_impl(const communicator& comm, const T* in_values, int n, int root, - mpl::false_) - { - int tag = environment::collectives_tag(); - comm.send(root, tag, in_values, n); - } +// We're gathering at a non-root for a type that does not have an +// associated MPI datatype, so we'll need to serialize +// it. +template<typename T> +void +gather_impl(const communicator& comm, const T* in_values, int n, T* out_values,int root, + mpl::false_ is_mpi_type) +{ + gather_impl(comm, in_values, n, out_values, (int const*)0, (int const*)0, root, is_mpi_type); +} } // end namespace detail template<typename T> void gather(const communicator& comm, const T& in_value, T* out_values, int root) { - if (comm.rank() == root) - detail::gather_impl(comm, &in_value, 1, out_values, root, - is_mpi_datatype<T>()); - else - detail::gather_impl(comm, &in_value, 1, root, is_mpi_datatype<T>()); + BOOST_ASSERT(out_values || (comm.rank() != root)); + detail::gather_impl(comm, &in_value, 1, out_values, root, is_mpi_datatype<T>()); } template<typename T> void gather(const communicator& comm, const T& in_value, int root) { BOOST_ASSERT(comm.rank() != root); - detail::gather_impl(comm, &in_value, 1, root, is_mpi_datatype<T>()); + detail::gather_impl(comm, &in_value, 1, (T*)0, root, is_mpi_datatype<T>()); } template<typename T> @@ -106,12 +136,11 @@ void gather(const communicator& comm, const T& in_value, std::vector<T>& out_values, int root) { + using detail::c_data; if (comm.rank() == root) { out_values.resize(comm.size()); - ::boost::mpi::gather(comm, in_value, &out_values[0], root); - } else { - ::boost::mpi::gather(comm, in_value, root); } + ::boost::mpi::gather(comm, in_value, c_data(out_values), root); } template<typename T> @@ -119,11 +148,8 @@ void gather(const communicator& comm, const T* in_values, int n, T* out_values, int root) { - if (comm.rank() == root) - detail::gather_impl(comm, in_values, n, out_values, root, - is_mpi_datatype<T>()); - else - detail::gather_impl(comm, in_values, n, root, is_mpi_datatype<T>()); + detail::gather_impl(comm, in_values, n, out_values, root, + is_mpi_datatype<T>()); } template<typename T> @@ -133,10 +159,8 @@ gather(const communicator& comm, const T* in_values, int n, { if (comm.rank() == root) { out_values.resize(comm.size() * n); - ::boost::mpi::gather(comm, in_values, n, &out_values[0], root); - } - else - ::boost::mpi::gather(comm, in_values, n, root); + } + ::boost::mpi::gather(comm, in_values, n, out_values.data(), root); } template<typename T> |