// Copyright Oliver Kowalke 2017. // Distributed under the Boost Software License, Version 1.0. // (See accompanying file LICENSE_1_0.txt or copy at // http://www.boost.org/LICENSE_1_0.txt) #ifndef BOOST_FIBERS_CUDA_WAITFOR_H #define BOOST_FIBERS_CUDA_WAITFOR_H #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef BOOST_HAS_ABI_HEADERS # include BOOST_ABI_PREFIX #endif namespace boost { namespace fibers { namespace cuda { namespace detail { template< typename Rendezvous > static void trampoline( hipStream_t st, hipError_t status, void * vp) { Rendezvous * data = static_cast< Rendezvous * >( vp); data->notify( st, status); } class single_stream_rendezvous { public: single_stream_rendezvous( hipStream_t st) { unsigned int flags = 0; hipError_t status = ::hipStreamAddCallback( st, trampoline< single_stream_rendezvous >, this, flags); if ( hipSuccess != status) { st_ = st; status_ = status; done_ = true; } } void notify( hipStream_t st, hipError_t status) noexcept { std::unique_lock< mutex > lk{ mtx_ }; st_ = st; status_ = status; done_ = true; lk.unlock(); cv_.notify_one(); } std::tuple< hipStream_t, hipError_t > wait() { std::unique_lock< mutex > lk{ mtx_ }; cv_.wait( lk, [this]{ return done_; }); return std::make_tuple( st_, status_); } private: mutex mtx_{}; condition_variable cv_{}; hipStream_t st_{}; hipError_t status_{ hipErrorUnknown }; bool done_{ false }; }; class many_streams_rendezvous { public: many_streams_rendezvous( std::initializer_list< hipStream_t > l) : stx_{ l } { results_.reserve( stx_.size() ); for ( hipStream_t st : stx_) { unsigned int flags = 0; hipError_t status = ::hipStreamAddCallback( st, trampoline< many_streams_rendezvous >, this, flags); if ( hipSuccess != status) { std::unique_lock< mutex > lk{ mtx_ }; stx_.erase( st); results_.push_back( std::make_tuple( st, status) ); } } } void notify( hipStream_t st, hipError_t status) noexcept { std::unique_lock< mutex > lk{ mtx_ }; stx_.erase( st); results_.push_back( std::make_tuple( st, status) ); if ( stx_.empty() ) { lk.unlock(); cv_.notify_one(); } } std::vector< std::tuple< hipStream_t, hipError_t > > wait() { std::unique_lock< mutex > lk{ mtx_ }; cv_.wait( lk, [this]{ return stx_.empty(); }); return results_; } private: mutex mtx_{}; condition_variable cv_{}; std::set< hipStream_t > stx_; std::vector< std::tuple< hipStream_t, hipError_t > > results_; }; } void waitfor_all(); inline std::tuple< hipStream_t, hipError_t > waitfor_all( hipStream_t st) { detail::single_stream_rendezvous rendezvous( st); return rendezvous.wait(); } template< typename ... STP > std::vector< std::tuple< hipStream_t, hipError_t > > waitfor_all( hipStream_t st0, STP ... stx) { static_assert( boost::fibers::detail::is_all_same< hipStream_t, STP ...>::value, "all arguments must be of type `CUstream*`."); detail::many_streams_rendezvous rendezvous{ st0, stx ... }; return rendezvous.wait(); } }}} #ifdef BOOST_HAS_ABI_HEADERS # include BOOST_ABI_SUFFIX #endif #endif // BOOST_FIBERS_CUDA_WAITFOR_H