summary | refs | log | tree | commit | diff
path: root/boost/random/discrete_distribution.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'boost/random/discrete_distribution.hpp')
-rw-r--r-- boost/random/discrete_distribution.hpp 285
1 files changed, 234 insertions, 51 deletions
diff --git a/boost/random/discrete_distribution.hpp b/boost/random/discrete_distribution.hpp
index bbdc055383..6407272093 100644
--- a/boost/random/discrete_distribution.hpp
+++ b/boost/random/discrete_distribution.hpp
@@ -7,7 +7,7 @@
*
* See http://www.boost.org for most recent version including documentation.
*
- * $Id: discrete_distribution.hpp 79771 2012-07-27 18:15:55Z jewillco $
+ * $Id$
*/
#ifndef BOOST_RANDOM_DISCRETE_DISTRIBUTION_HPP_INCLUDED
@@ -20,7 +20,7 @@
#include <iterator>
#include <boost/assert.hpp>
#include <boost/random/uniform_01.hpp>
-#include <boost/random/uniform_int.hpp>
+#include <boost/random/uniform_int_distribution.hpp>
#include <boost/random/detail/config.hpp>
#include <boost/random/detail/operators.hpp>
#include <boost/random/detail/vector_io.hpp>
@@ -36,6 +36,186 @@
namespace boost {
namespace random {
+namespace detail {
+
+/**
+ * Alias-table policy used when @c WeightType is integral.
+ *
+ * Weights are kept as exact integers rather than being divided by their
+ * mean.  The mean is stored as a quotient (@c _average) and a remainder
+ * (@c _excess): bins whose index is below @c _excess carry a total weight
+ * of @c _average + 1, every other bin carries @c _average.
+ */
+template<class IntType, class WeightType>
+struct integer_alias_table {
+    /**
+     * Returns the total weight held by @c bin: @c _average, plus one
+     * for the first @c _excess bins.
+     */
+    WeightType get_weight(IntType bin) const {
+        WeightType result = _average;
+        if(bin < _excess) ++result;
+        return result;
+    }
+    /**
+     * Computes the integer mean of the weights in [begin, end) in a
+     * single pass, resizes the alias table to the number of weights and
+     * records the quotient and remainder in @c _average and @c _excess.
+     *
+     * @returns the quotient of the mean (the new @c _average).
+     */
+    template<class Iter>
+    WeightType init_average(Iter begin, Iter end) {
+        WeightType weight_average = 0;
+        IntType excess = 0;
+        IntType n = 0;
+        // weight_average * n + excess == current partial sum
+        // This is a bit messy, but it's guaranteed not to overflow
+        for(Iter iter = begin; iter != end; ++iter) {
+            ++n;
+            if(*iter < weight_average) {
+                // New weight is below the running mean: lower the mean,
+                // borrowing from the quotient when the remainder is too small.
+                WeightType diff = weight_average - *iter;
+                weight_average -= diff / n;
+                if(diff % n > excess) {
+                    --weight_average;
+                    excess += n - diff % n;
+                } else {
+                    excess -= diff % n;
+                }
+            } else {
+                // New weight is at or above the running mean: raise the
+                // mean, carrying into the quotient when the remainder
+                // would reach n.
+                WeightType diff = *iter - weight_average;
+                weight_average += diff / n;
+                if(diff % n < n - excess) {
+                    excess += diff % n;
+                } else {
+                    ++weight_average;
+                    excess -= n - diff % n;
+                }
+            }
+        }
+        _alias_table.resize(static_cast<std::size_t>(n));
+        _average = weight_average;
+        _excess = excess;
+        return weight_average;
+    }
+    /**
+     * Resets the table to a single bin of weight 1 that aliases itself.
+     */
+    void init_empty()
+    {
+        _alias_table.clear();
+        _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
+                                              static_cast<IntType>(0)));
+        _average = static_cast<WeightType>(1);
+        _excess = static_cast<IntType>(0);
+    }
+    /** Two tables are equal when entries, quotient and remainder all match. */
+    bool operator==(const integer_alias_table& other) const
+    {
+        return _alias_table == other._alias_table &&
+            _average == other._average && _excess == other._excess;
+    }
+    /** Integral weights are used as-is; @c average is ignored. */
+    static WeightType normalize(WeightType val, WeightType average)
+    {
+        return val;
+    }
+    /** Integral weights are reported unchanged, so there is nothing to do. */
+    static void normalize(std::vector<WeightType>&) {}
+    /** Draws a threshold uniformly from the closed range [0, _average]. */
+    template<class URNG>
+    WeightType test(URNG &urng) const
+    {
+        return uniform_int_distribution<WeightType>(0, _average)(urng);
+    }
+    /**
+     * Returns true if the pair (@c result, @c val) may be used.  The
+     * threshold value @c _average is only meaningful for bins that carry
+     * the extra unit of weight (index below @c _excess); for any other
+     * bin the draw has to be repeated.
+     */
+    bool accept(IntType result, WeightType val) const
+    {
+        return result < _excess || val < _average;
+    }
+    /**
+     * Returns the sum of @c weights, or 0 if that sum cannot be
+     * represented in @c WeightType.  A sum of 0 is indistinguishable
+     * from the overflow case.
+     */
+    static WeightType try_get_sum(const std::vector<WeightType>& weights)
+    {
+        WeightType result = static_cast<WeightType>(0);
+        for(typename std::vector<WeightType>::const_iterator
+                iter = weights.begin(), end = weights.end();
+            iter != end; ++iter)
+        {
+            // Overflow happens when the next weight is larger than the
+            // room left below max().  (The comparison used to be '>',
+            // which reported overflow whenever there WAS room and let
+            // the real overflow wrap around.)
+            if((std::numeric_limits<WeightType>::max)() - result < *iter) {
+                return static_cast<WeightType>(0);
+            }
+            result += *iter;
+        }
+        return result;
+    }
+    /** Draws uniformly from the closed range [0, max - 1]. */
+    template<class URNG>
+    static WeightType generate_in_range(URNG &urng, WeightType max)
+    {
+        return uniform_int_distribution<WeightType>(
+            static_cast<WeightType>(0), max-1)(urng);
+    }
+    typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
+    /// Per bin: (weight kept by the bin itself, index of its alias).
+    alias_table_t _alias_table;
+    /// Quotient of the mean weight.
+    WeightType _average;
+    /// Remainder of the mean weight: number of bins holding _average + 1.
+    IntType _excess;
+};
+
+/**
+ * Alias-table policy used when @c WeightType is not integral.
+ *
+ * Weights are divided by their mean, so every bin has total weight 1
+ * and no draw ever has to be rejected.
+ */
+template<class IntType, class WeightType>
+struct real_alias_table {
+    typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
+
+    /** Every bin has the same normalized weight of 1. */
+    WeightType get_weight(IntType) const
+    {
+        return WeightType(1.0);
+    }
+
+    /**
+     * Sizes the alias table for the range [first, last) and
+     * returns the mean of the weights in it.
+     */
+    template<class Iter>
+    WeightType init_average(Iter first, Iter last)
+    {
+        std::size_t count = std::distance(first, last);
+        WeightType total =
+            std::accumulate(first, last, static_cast<WeightType>(0));
+        _alias_table.resize(count);
+        return total / count;
+    }
+
+    /** Resets the table to a single bin of weight 1 that aliases itself. */
+    void init_empty()
+    {
+        _alias_table.clear();
+        _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
+                                              static_cast<IntType>(0)));
+    }
+
+    /** Two tables are equal when all of their entries match. */
+    bool operator==(const real_alias_table& other) const
+    {
+        return _alias_table == other._alias_table;
+    }
+
+    /** Expresses @c val as a multiple of the mean weight. */
+    static WeightType normalize(WeightType val, WeightType average)
+    {
+        return val / average;
+    }
+
+    /** Divides every element of @c weights by the sum of all elements. */
+    static void normalize(std::vector<WeightType>& weights)
+    {
+        WeightType total =
+            std::accumulate(weights.begin(), weights.end(),
+                            static_cast<WeightType>(0));
+        for(std::size_t pos = 0, count = weights.size();
+            pos != count; ++pos)
+        {
+            weights[pos] /= total;
+        }
+    }
+
+    /** Draws the threshold that is compared against a bin's own weight. */
+    template<class URNG>
+    WeightType test(URNG &urng) const
+    {
+        uniform_01<WeightType> unit_interval;
+        return unit_interval(urng);
+    }
+
+    /** Real-valued draws are never rejected. */
+    bool accept(IntType, WeightType) const
+    {
+        return true;
+    }
+
+    /** Always reports 1, whatever the contents of @c weights. */
+    static WeightType try_get_sum(const std::vector<WeightType>& weights)
+    {
+        return static_cast<WeightType>(1);
+    }
+
+    /** Draws from uniform_01; the upper bound argument is ignored. */
+    template<class URNG>
+    static WeightType generate_in_range(URNG &urng, WeightType)
+    {
+        uniform_01<WeightType> unit_interval;
+        return unit_interval(urng);
+    }
+
+    /// Per bin: (weight kept by the bin itself, index of its alias).
+    alias_table_t _alias_table;
+};
+
+/**
+ * Compile-time switch that picks the alias-table policy:
+ * select_alias_table<B>::apply<I, W>::type is integer_alias_table<I, W>
+ * when B is true and real_alias_table<I, W> when B is false.
+ */
+template<bool IsIntegral>
+struct select_alias_table;
+
+/// Integral weights: exact integer table.
+template<>
+struct select_alias_table<true> {
+    template<class Index, class Weight>
+    struct apply { typedef integer_alias_table<Index, Weight> type; };
+};
+
+/// Non-integral weights: table normalized by the mean.
+template<>
+struct select_alias_table<false> {
+    template<class Index, class Weight>
+    struct apply { typedef real_alias_table<Index, Weight> type; };
+};
+
+}
/**
* The class @c discrete_distribution models a \random_distribution.
@@ -155,16 +335,7 @@ public:
{}
void normalize()
{
- WeightType sum =
- std::accumulate(_probabilities.begin(), _probabilities.end(),
- static_cast<WeightType>(0));
- for(typename std::vector<WeightType>::iterator
- iter = _probabilities.begin(),
- end = _probabilities.end();
- iter != end; ++iter)
- {
- *iter /= sum;
- }
+ impl_type::normalize(_probabilities); // delegated to the alias-table policy; presumably the same impl_type selection as the distribution's (divide by the sum for real weights, no-op for integral ones) - confirm
}
std::vector<WeightType> _probabilities;
/// @endcond
@@ -176,8 +347,7 @@ public:
*/
discrete_distribution()
{
- _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
- static_cast<IntType>(0)));
+ _impl.init_empty(); // single bin of weight 1 aliased to itself, so the only possible result is 0
}
/**
* Constructs a discrete_distribution from an iterator range.
@@ -257,13 +427,17 @@ public:
template<class URNG>
IntType operator()(URNG& urng) const
{
- BOOST_ASSERT(!_alias_table.empty());
- WeightType test = uniform_01<WeightType>()(urng);
- IntType result = uniform_int<IntType>((min)(), (max)())(urng);
- if(test < _alias_table[result].first) {
+ BOOST_ASSERT(!_impl._alias_table.empty());
+ IntType result;
+ WeightType test;
+ do { // pick a bin and a threshold; repeat while the policy rejects the pair
+ result = uniform_int_distribution<IntType>((min)(), (max)())(urng);
+ test = _impl.test(urng);
+ } while(!_impl.accept(result, test)); // the real-valued policy always accepts; the integer policy may reject
+ if(test < _impl._alias_table[result].first) { // threshold falls in the part of the bin kept by the bin itself
return result;
} else {
- return(_alias_table[result].second);
+ return(_impl._alias_table[result].second); // otherwise the draw belongs to the bin's alias
}
}
@@ -274,13 +448,13 @@ public:
template<class URNG>
IntType operator()(URNG& urng, const param_type& parm) const
{
- while(true) {
- WeightType val = uniform_01<WeightType>()(urng);
+ if(WeightType limit = impl_type::try_get_sum(parm._probabilities)) {
+ WeightType val = impl_type::generate_in_range(urng, limit);
WeightType sum = 0;
std::size_t result = 0;
for(typename std::vector<WeightType>::const_iterator
- iter = parm._probabilities.begin(),
- end = parm._probabilities.end();
+ iter = parm._probabilities.begin(),
+ end = parm._probabilities.end();
iter != end; ++iter, ++result)
{
sum += *iter;
@@ -288,6 +462,14 @@ public:
return result;
}
}
+ // This shouldn't be reachable, but round-off error
+ // can prevent any match from being found when val is
+ // very close to 1.
+ return static_cast<IntType>(parm._probabilities.size() - 1);
+ } else {
+ // WeightType is integral and sum(parm._probabilities)
+ // would overflow. Just use the easy solution.
+ return discrete_distribution(parm)(urng);
}
}
@@ -295,7 +477,7 @@ public:
result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const { return 0; }
/** Returns the largest value that the distribution can produce. */
result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const
- { return static_cast<result_type>(_alias_table.size() - 1); }
+ { return static_cast<result_type>(_impl._alias_table.size() - 1); } // index of the last entry in the alias table
/**
* Returns a vector containing the probabilities of each
@@ -307,22 +489,24 @@ public:
* @endcode
*
* the vector, p will contain {0.1, 0.4, 0.5}.
+ *
+ * If @c WeightType is integral, then the weights
+ * will be returned unchanged.
*/
std::vector<WeightType> probabilities() const
{
- std::vector<WeightType> result(_alias_table.size());
- const WeightType mean =
- static_cast<WeightType>(1) / _alias_table.size();
+ std::vector<WeightType> result(_impl._alias_table.size());
std::size_t i = 0;
- for(typename alias_table_t::const_iterator
- iter = _alias_table.begin(),
- end = _alias_table.end();
+ for(typename impl_type::alias_table_t::const_iterator
+ iter = _impl._alias_table.begin(),
+ end = _impl._alias_table.end();
iter != end; ++iter, ++i)
{
- WeightType val = iter->first * mean;
+ WeightType val = iter->first; // part of bin i's weight credited to outcome i itself
result[i] += val;
- result[iter->second] += mean - val;
+ result[iter->second] += _impl.get_weight(i) - val; // the rest of bin i is credited to its alias
}
+ impl_type::normalize(result); // divides by the sum for real weights; no-op for integral weights
return(result);
}
@@ -366,7 +550,7 @@ public:
*/
BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(discrete_distribution, lhs, rhs)
{
- return lhs._alias_table == rhs._alias_table;
+ return lhs._impl == rhs._impl; // compares the alias tables (and _average/_excess when WeightType is integral)
}
/**
* Returns true if the two distributions may return different
@@ -389,59 +573,58 @@ private:
{
std::vector<std::pair<WeightType, IntType> > below_average;
std::vector<std::pair<WeightType, IntType> > above_average;
- std::size_t size = std::distance(first, last);
- WeightType weight_sum =
- std::accumulate(first, last, static_cast<WeightType>(0));
- WeightType weight_average = weight_sum / size;
+ WeightType weight_average = _impl.init_average(first, last);
+ WeightType normalized_average = _impl.get_weight(0);
std::size_t i = 0;
for(; first != last; ++first, ++i) {
- WeightType val = *first / weight_average;
+ WeightType val = impl_type::normalize(*first, weight_average);
std::pair<WeightType, IntType> elem(val, static_cast<IntType>(i));
- if(val < static_cast<WeightType>(1)) {
+ if(val < normalized_average) {
below_average.push_back(elem);
} else {
above_average.push_back(elem);
}
}
- _alias_table.resize(size);
- typename alias_table_t::iterator
+ typename impl_type::alias_table_t::iterator
b_iter = below_average.begin(),
b_end = below_average.end(),
a_iter = above_average.begin(),
a_end = above_average.end()
;
while(b_iter != b_end && a_iter != a_end) {
- _alias_table[b_iter->second] =
+ _impl._alias_table[b_iter->second] =
std::make_pair(b_iter->first, a_iter->second);
- a_iter->first -= (static_cast<WeightType>(1) - b_iter->first);
- if(a_iter->first < static_cast<WeightType>(1)) {
+ a_iter->first -= (_impl.get_weight(b_iter->second) - b_iter->first);
+ if(a_iter->first < normalized_average) {
*b_iter = *a_iter++;
} else {
++b_iter;
}
}
for(; b_iter != b_end; ++b_iter) {
- _alias_table[b_iter->second].first = static_cast<WeightType>(1);
+ _impl._alias_table[b_iter->second].first =
+ _impl.get_weight(b_iter->second);
}
for(; a_iter != a_end; ++a_iter) {
- _alias_table[a_iter->second].first = static_cast<WeightType>(1);
+ _impl._alias_table[a_iter->second].first =
+ _impl.get_weight(a_iter->second);
}
}
template<class Iter>
void init(Iter first, Iter last)
{
if(first == last) {
- _alias_table.clear();
- _alias_table.push_back(std::make_pair(static_cast<WeightType>(1),
- static_cast<IntType>(0)));
+ _impl.init_empty(); // empty range: fall back to the single-bin table
} else {
typename std::iterator_traits<Iter>::iterator_category category;
init(first, last, category);
}
}
- typedef std::vector<std::pair<WeightType, IntType> > alias_table_t;
- alias_table_t _alias_table;
+ typedef typename detail::select_alias_table<
+ (::boost::is_integral<WeightType>::value)
+ >::template apply<IntType, WeightType>::type impl_type; // integer_alias_table for integral WeightType, real_alias_table otherwise
+ impl_type _impl; // owns the alias table (plus _average/_excess for integral weights)
/// @endcond
};
/// @endcond
};