summary | refs | log | tree | commit | diff
path: root/boost/mpi/collectives/all_reduce.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'boost/mpi/collectives/all_reduce.hpp')
-rw-r--r-- boost/mpi/collectives/all_reduce.hpp 31
1 files changed, 29 insertions, 2 deletions
diff --git a/boost/mpi/collectives/all_reduce.hpp b/boost/mpi/collectives/all_reduce.hpp
index 26b8fc9cd1..06e116a65e 100644
--- a/boost/mpi/collectives/all_reduce.hpp
+++ b/boost/mpi/collectives/all_reduce.hpp
@@ -12,12 +12,15 @@
#ifndef BOOST_MPI_ALL_REDUCE_HPP
#define BOOST_MPI_ALL_REDUCE_HPP
+#include <vector>
+
+#include <boost/mpi/inplace.hpp>
+
// All-reduce falls back to reduce() + broadcast() in some cases.
#include <boost/mpi/collectives/broadcast.hpp>
#include <boost/mpi/collectives/reduce.hpp>
namespace boost { namespace mpi {
-
namespace detail {
/**********************************************************************
* Simple reduction with MPI_Allreduce *
@@ -67,7 +70,17 @@ namespace detail {
T* out_values, Op op, mpl::false_ /*is_mpi_op*/,
mpl::false_ /*is_mpi_datatype*/)
{
- reduce(comm, in_values, n, out_values, op, 0);
+ if (in_values == MPI_IN_PLACE) {
+ // If in_values matches the in-place tag, then the output
+ // buffer actually contains the input data.
+ // We copy that data into a temporary so that we can fall
+ // back to the out-of-place implementation in this case.
+ // It is not clear how (or whether) we can avoid the copy.
+ std::vector<T> tmp_in( out_values, out_values + n);
+ reduce(comm, &(tmp_in[0]), n, out_values, op, 0);
+ } else {
+ reduce(comm, in_values, n, out_values, op, 0);
+ }
broadcast(comm, out_values, n, 0);
}
} // end namespace detail
@@ -83,6 +96,20 @@ all_reduce(const communicator& comm, const T* in_values, int n, T* out_values,
template<typename T, typename Op>
inline void
+all_reduce(const communicator& comm, inplace_t<T*> inout_values, int n, Op op)
+{ // In-place array form: inout_values.buffer is handed on as the output buffer for n values.
+ all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), n, inout_values.buffer, op); // MPI_IN_PLACE as the input pointer flags the in-place case
+}
+
+template<typename T, typename Op>
+inline void
+all_reduce(const communicator& comm, inplace_t<T> inout_values, Op op)
+{ // In-place single-value form: forwards to the pointer-based all_reduce with a count of 1.
+ all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), 1, &(inout_values.buffer), op); // the address of the wrapped value is the output buffer
+}
+
+template<typename T, typename Op>
+inline void
all_reduce(const communicator& comm, const T& in_value, T& out_value, Op op)
{
detail::all_reduce_impl(comm, &in_value, 1, &out_value, op,