diff options
Diffstat (limited to 'boost/random/discrete_distribution.hpp')
-rw-r--r-- | boost/random/discrete_distribution.hpp | 285 |
1 files changed, 234 insertions, 51 deletions
diff --git a/boost/random/discrete_distribution.hpp b/boost/random/discrete_distribution.hpp
index bbdc055383..6407272093 100644
--- a/boost/random/discrete_distribution.hpp
+++ b/boost/random/discrete_distribution.hpp
@@ -7,7 +7,7 @@
  *
  * See http://www.boost.org for most recent version including documentation.
  *
- * $Id: discrete_distribution.hpp 79771 2012-07-27 18:15:55Z jewillco $
+ * $Id$
  */
 
 #ifndef BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
@@ -20,7 +20,7 @@
 #include <iterator>
 #include <boost/assert.hpp>
 #include <boost/random/uniform_01.hpp>
-#include <boost/random/uniform_int.hpp>
+#include <boost/random/uniform_int_distribution.hpp>
 #include <boost/random/detail/config.hpp>
 #include <boost/random/detail/operators.hpp>
 #include <boost/random/detail/vector_io.hpp>
@@ -36,6 +36,186 @@
 
 namespace boost {
 namespace random {
+namespace detail {
+
+template<class IntType, class WeightType>
+struct integer_alias_table {
+    WeightType get_weight(IntType bin) const {
+        WeightType result = _average;
+        if(bin < _excess) ++result;
+        return result;
+    }
+    template<class Iter>
+    WeightType init_average(Iter begin, Iter end) {
+        WeightType weight_average = 0;
+        IntType excess = 0;
+        IntType n = 0;
+        // weight_average * n + excess == current partial sum
+        // This is a bit messy, but it's guaranteed not to overflow
+        for(Iter iter = begin; iter != end; ++iter) {
+            ++n;
+            if(*iter < weight_average) {
+                WeightType diff = weight_average - *iter;
+                weight_average -= diff / n;
+                if(diff % n > excess) {
+                    --weight_average;
+                    excess += n - diff % n;
+                } else {
+                    excess -= diff % n;
+                }
+            } else {
+                WeightType diff = *iter - weight_average;
+                weight_average += diff / n;
+                if(diff % n < n - excess) {
+                    excess += diff % n;
+                } else {
+                    ++weight_average;
+                    excess -= n - diff % n;
+                }
+            }
+        }
+        _alias_table.resize(static_cast<std::size_t>(n));
+        _average = weight_average;
+        _excess = excess;
+        return weight_average;
+    }
+    void init_empty()
+    {
+        _alias_table.clear();
+        _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
+                                              static_cast<IntType>(0)));
+        _average = static_cast<WeightType>(1);
+        _excess = static_cast<IntType>(0);
+    }
+    bool operator==(const integer_alias_table& other) const
+    {
+        return _alias_table == other._alias_table &&
+            _average == other._average && _excess == other._excess;
+    }
+    static WeightType normalize(WeightType val, WeightType /*average*/)
+    {
+        return val;
+    }
+    static void normalize(std::vector<WeightType>&) {}
+    template<class URNG>
+    WeightType test(URNG &urng) const
+    {
+        return uniform_int_distribution<WeightType>(0, _average)(urng);
+    }
+    bool accept(IntType result, WeightType val) const
+    {
+        return result < _excess || val < _average;
+    }
+    static WeightType try_get_sum(const std::vector<WeightType>& weights)
+    { // Returns the sum of the weights; 0 signals that it would overflow.
+        WeightType result = static_cast<WeightType>(0);
+        for(typename std::vector<WeightType>::const_iterator
+                iter = weights.begin(), end = weights.end();
+            iter != end; ++iter)
+        {
+            if((std::numeric_limits<WeightType>::max)() - result < *iter) {
+                return static_cast<WeightType>(0); // no room left for *iter
+            }
+            result += *iter;
+        }
+        return result;
+    }
+    template<class URNG>
+    static WeightType generate_in_range(URNG &urng, WeightType max)
+    {
+        return uniform_int_distribution<WeightType>(
+            static_cast<WeightType>(0), max-1)(urng);
+    }
+    typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
+    alias_table_t _alias_table;
+    WeightType _average;
+    IntType _excess;
+};
+
+template<class IntType, class WeightType>
+struct real_alias_table {
+    WeightType get_weight(IntType) const
+    {
+        return WeightType(1.0);
+    }
+    template<class Iter>
+    WeightType init_average(Iter first, Iter last)
+    {
+        std::size_t size = std::distance(first, last);
+        WeightType weight_sum =
+            std::accumulate(first, last, static_cast<WeightType>(0));
+        _alias_table.resize(size);
+        return weight_sum / size;
+    }
+    void init_empty()
+    {
+        _alias_table.clear();
+        _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
+                                              static_cast<IntType>(0)));
+    }
+    bool operator==(const real_alias_table& other) const
+    {
+        return _alias_table == other._alias_table;
+    }
+    static WeightType normalize(WeightType val, WeightType average)
+    {
+        return val / average;
+    }
+    static void normalize(std::vector<WeightType>& weights)
+    {
+        WeightType sum =
+            std::accumulate(weights.begin(), weights.end(),
+                            static_cast<WeightType>(0));
+        for(typename std::vector<WeightType>::iterator
+                iter = weights.begin(),
+                end = weights.end();
+            iter != end; ++iter)
+        {
+            *iter /= sum;
+        }
+    }
+    template<class URNG>
+    WeightType test(URNG &urng) const
+    {
+        return uniform_01<WeightType>()(urng);
+    }
+    bool accept(IntType, WeightType) const
+    {
+        return true;
+    }
+    static WeightType try_get_sum(const std::vector<WeightType>& /*weights*/)
+    {
+        return static_cast<WeightType>(1);
+    }
+    template<class URNG>
+    static WeightType generate_in_range(URNG &urng, WeightType)
+    {
+        return uniform_01<WeightType>()(urng);
+    }
+    typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
+    alias_table_t _alias_table;
+};
+
+template<bool IsIntegral>
+struct select_alias_table;
+
+template<>
+struct select_alias_table<true> {
+    template<class IntType, class WeightType>
+    struct apply {
+        typedef integer_alias_table<IntType, WeightType> type;
+    };
+};
+
+template<>
+struct select_alias_table<false> {
+    template<class IntType, class WeightType>
+    struct apply {
+        typedef real_alias_table<IntType, WeightType> type;
+    };
+};
+
+}
 
 /**
  * The class @c discrete_distribution models a \random_distribution.
@@ -155,16 +335,7 @@ public: {} void normalize() { - WeightType sum = - std::accumulate(_probabilities.begin(), _probabilities.end(), - static_cast<WeightType>(0)); - for(typename std::vector<WeightType>::iterator - iter = _probabilities.begin(), - end = _probabilities.end(); - iter != end; ++iter) - { - *iter /= sum; - } + impl_type::normalize(_probabilities); } std::vector<WeightType> _probabilities; /// @endcond @@ -176,8 +347,7 @@ public: */ discrete_distribution() { - _alias_table.push_back(std::make_pair(static_cast<WeightType>(1), - static_cast<IntType>(0))); + _impl.init_empty(); } /** * Constructs a discrete_distribution from an iterator range. @@ -257,13 +427,17 @@ public: template<class URNG> IntType operator()(URNG& urng) const { - BOOST_ASSERT(!_alias_table.empty()); - WeightType test = uniform_01<WeightType>()(urng); - IntType result = uniform_int<IntType>((min)(), (max)())(urng); - if(test < _alias_table[result].first) { + BOOST_ASSERT(!_impl._alias_table.empty()); + IntType result; + WeightType test; + do { + result = uniform_int_distribution<IntType>((min)(), (max)())(urng); + test = _impl.test(urng); + } while(!_impl.accept(result, test)); + if(test < _impl._alias_table[result].first) { return result; } else { - return(_alias_table[result].second); + return(_impl._alias_table[result].second); } } @@ -274,13 +448,13 @@ public: template<class URNG> IntType operator()(URNG& urng, const param_type& parm) const { - while(true) { - WeightType val = uniform_01<WeightType>()(urng); + if(WeightType limit = impl_type::try_get_sum(parm._probabilities)) { + WeightType val = impl_type::generate_in_range(urng, limit); WeightType sum = 0; std::size_t result = 0; for(typename std::vector<WeightType>::const_iterator - iter = parm._probabilities.begin(), - end = parm._probabilities.end(); + iter = parm._probabilities.begin(), + end = parm._probabilities.end(); iter != end; ++iter, ++result) { sum += *iter; @@ -288,6 +462,14 @@ public: return result; } } + // This 
shouldn't be reachable, but round-off error + // can prevent any match from being found when val is + // very close to 1. + return static_cast<IntType>(parm._probabilities.size() - 1); + } else { + // WeightType is integral and sum(parm._probabilities) + // would overflow. Just use the easy solution. + return discrete_distribution(parm)(urng); } } @@ -295,7 +477,7 @@ public: result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return 0; } /** Returns the largest value that the distribution can produce. */ result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const - { return static_cast<result_type>(_alias_table.size() - 1); } + { return static_cast<result_type>(_impl._alias_table.size() - 1); } /** * Returns a vector containing the probabilities of each @@ -307,22 +489,24 @@ public: * @endcode * * the vector, p will contain {0.1, 0.4, 0.5}. + * + * If @c WeightType is integral, then the weights + * will be returned unchanged. */ std::vector<WeightType> probabilities() const { - std::vector<WeightType> result(_alias_table.size()); - const WeightType mean = - static_cast<WeightType>(1) / _alias_table.size(); + std::vector<WeightType> result(_impl._alias_table.size()); std::size_t i = 0; - for(typename alias_table_t::const_iterator - iter = _alias_table.begin(), - end = _alias_table.end(); + for(typename impl_type::alias_table_t::const_iterator + iter = _impl._alias_table.begin(), + end = _impl._alias_table.end(); iter != end; ++iter, ++i) { - WeightType val = iter->first * mean; + WeightType val = iter->first; result[i] += val; - result[iter->second] += mean - val; + result[iter->second] += _impl.get_weight(i) - val; } + impl_type::normalize(result); return(result); } @@ -366,7 +550,7 @@ public: */ BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(discrete_distribution, lhs, rhs) { - return lhs._alias_table == rhs._alias_table; + return lhs._impl == rhs._impl; } /** * Returns true if the two distributions may return different @@ -389,59 +573,58 @@ private: { 
std::vector<std::pair<WeightType, IntType> > below_average; std::vector<std::pair<WeightType, IntType> > above_average; - std::size_t size = std::distance(first, last); - WeightType weight_sum = - std::accumulate(first, last, static_cast<WeightType>(0)); - WeightType weight_average = weight_sum / size; + WeightType weight_average = _impl.init_average(first, last); + WeightType normalized_average = _impl.get_weight(0); std::size_t i = 0; for(; first != last; ++first, ++i) { - WeightType val = *first / weight_average; + WeightType val = impl_type::normalize(*first, weight_average); std::pair<WeightType, IntType> elem(val, static_cast<IntType>(i)); - if(val < static_cast<WeightType>(1)) { + if(val < normalized_average) { below_average.push_back(elem); } else { above_average.push_back(elem); } } - _alias_table.resize(size); - typename alias_table_t::iterator + typename impl_type::alias_table_t::iterator b_iter = below_average.begin(), b_end = below_average.end(), a_iter = above_average.begin(), a_end = above_average.end() ; while(b_iter != b_end && a_iter != a_end) { - _alias_table[b_iter->second] = + _impl._alias_table[b_iter->second] = std::make_pair(b_iter->first, a_iter->second); - a_iter->first -= (static_cast<WeightType>(1) - b_iter->first); - if(a_iter->first < static_cast<WeightType>(1)) { + a_iter->first -= (_impl.get_weight(b_iter->second) - b_iter->first); + if(a_iter->first < normalized_average) { *b_iter = *a_iter++; } else { ++b_iter; } } for(; b_iter != b_end; ++b_iter) { - _alias_table[b_iter->second].first = static_cast<WeightType>(1); + _impl._alias_table[b_iter->second].first = + _impl.get_weight(b_iter->second); } for(; a_iter != a_end; ++a_iter) { - _alias_table[a_iter->second].first = static_cast<WeightType>(1); + _impl._alias_table[a_iter->second].first = + _impl.get_weight(a_iter->second); } } template<class Iter> void init(Iter first, Iter last) { if(first == last) { - _alias_table.clear(); - 
_alias_table.push_back(std::make_pair(static_cast<WeightType>(1), - static_cast<IntType>(0))); + _impl.init_empty(); } else { typename std::iterator_traits<Iter>::iterator_category category; init(first, last, category); } } - typedef std::vector<std::pair<WeightType, IntType> > alias_table_t; - alias_table_t _alias_table; + typedef typename detail::select_alias_table< + (::boost::is_integral<WeightType>::value) + >::template apply<IntType, WeightType>::type impl_type; + impl_type _impl; /// @endcond }; |