diff options
Diffstat (limited to 'boost/mpi/collectives/all_reduce.hpp')
-rw-r--r-- | boost/mpi/collectives/all_reduce.hpp | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/boost/mpi/collectives/all_reduce.hpp b/boost/mpi/collectives/all_reduce.hpp
index 26b8fc9cd1..06e116a65e 100644
--- a/boost/mpi/collectives/all_reduce.hpp
+++ b/boost/mpi/collectives/all_reduce.hpp
@@ -12,12 +12,15 @@
 #ifndef BOOST_MPI_ALL_REDUCE_HPP
 #define BOOST_MPI_ALL_REDUCE_HPP
 
+#include <vector>
+
+#include <boost/mpi/inplace.hpp>
+
 // All-reduce falls back to reduce() + broadcast() in some cases.
 #include <boost/mpi/collectives/broadcast.hpp>
 #include <boost/mpi/collectives/reduce.hpp>
 
 namespace boost { namespace mpi {
-
 namespace detail {
   /**********************************************************************
    * Simple reduction with MPI_Allreduce                                *
@@ -67,7 +70,17 @@ namespace detail {
                   T* out_values, Op op, mpl::false_ /*is_mpi_op*/,
                   mpl::false_ /*is_mpi_datatype*/)
   {
-    reduce(comm, in_values, n, out_values, op, 0);
+    if (in_values == MPI_IN_PLACE) {
+      // if in_values matches the in place tag, then the output
+      // buffer actually contains the input data.
+      // But we can just go back to the out of place
+      // implementation in this case.
+      // it's not clear how/if we can avoid the copy.
+      std::vector<T> tmp_in( out_values, out_values + n);
+      reduce(comm, &(tmp_in[0]), n, out_values, op, 0);
+    } else {
+      reduce(comm, in_values, n, out_values, op, 0);
+    }
     broadcast(comm, out_values, n, 0);
   }
 } // end namespace detail
@@ -83,6 +96,20 @@ all_reduce(const communicator& comm, const T* in_values, int n, T* out_values,
 
 template<typename T, typename Op>
 inline void
+all_reduce(const communicator& comm, inplace_t<T*> inout_values, int n, Op op)
+{
+  all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), n, inout_values.buffer, op);
+}
+
+template<typename T, typename Op>
+inline void
+all_reduce(const communicator& comm, inplace_t<T> inout_values, Op op)
+{
+  all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), 1, &(inout_values.buffer), op);
+}
+
+template<typename T, typename Op>
+inline void
 all_reduce(const communicator& comm, const T& in_value, T& out_value, Op op)
 {
   detail::all_reduce_impl(comm, &in_value, 1, &out_value, op,