author     Neeraj Pradhan <prad.neeraj@gmail.com>   2018-04-26 06:53:01 -0700
committer  Adam Paszke <adam.paszke@gmail.com>      2018-04-26 15:53:01 +0200
commit     3964253f943668fe076c992dbf91e975b3b4c3c4 (patch)
tree       6aae2f53dedccc8a2347450148f4c0b22cf55c1b
parent     f98b7780867c847eef27fbf492c9ee8dfe023c20 (diff)
Allowing for vectorized counts in Binomial Distribution (#6720)
-rw-r--r--  test/test_distributions.py       54
-rw-r--r--  torch/distributions/binomial.py  51
-rw-r--r--  torch/distributions/kl.py        10
3 files changed, 88 insertions(+), 27 deletions(-)
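Prior to this commit, Binomial required a single scalar total_count shared by all parameters; the patch below lets total_count be a tensor that broadcasts against probs/logits. A minimal usage sketch of what this enables (the sampled values are illustrative):

    import torch
    from torch.distributions import Binomial

    # Per-element trial counts: shape (2, 1) broadcasts against (2,),
    # giving a distribution with batch shape (2, 2).
    m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
    counts = m.sample()         # e.g. tensor([[4., 5.], [7., 6.]])
    log_p = m.log_prob(counts)  # elementwise log-probabilities, shape (2, 2)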
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 09ba06b9fb..917fd7200b 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -113,6 +113,12 @@ EXAMPLES = [
Example(Binomial, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
{'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
+ {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10])},
+ {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
+ {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
+ 'total_count': torch.tensor([[10., 8.], [5., 3.]])},
+ {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
+ 'total_count': torch.tensor(0.)},
]),
Example(Multinomial, [
{'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
@@ -795,6 +801,15 @@ class TestDistributions(TestCase):
logits = probs_to_logits(probs, is_binary=True)
self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)

+ @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+ def test_binomial_log_prob_vectorized_count(self):
+ probs = torch.tensor([0.2, 0.7, 0.9])
+ for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3., 9.])),
+ (torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.]))]:
+ log_prob = Binomial(total_count, probs).log_prob(sample)
+ expected = scipy.stats.binom(total_count.cpu().numpy(), probs.cpu().numpy()).logpmf(sample)
+ self.assertAlmostEqual(log_prob, expected, places=4)
+
def test_binomial_extreme_vals(self):
total_count = 100
bin0 = Binomial(total_count, 0)
@@ -805,6 +820,28 @@ class TestDistributions(TestCase):
self.assertEqual(bin1.sample(), total_count)
self.assertAlmostEqual(bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, places=3)
self.assertEqual(float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0, allow_inf=True)
+ zero_counts = torch.zeros(torch.Size((2, 2)))
+ bin2 = Binomial(zero_counts, 1)
+ self.assertEqual(bin2.sample(), zero_counts)
+ self.assertEqual(bin2.log_prob(zero_counts), zero_counts)
+
+ def test_binomial_vectorized_count(self):
+ set_rng_seed(0)
+ total_count = torch.tensor([[4, 7], [3, 8]])
+ bin0 = Binomial(total_count, torch.tensor(1.))
+ self.assertEqual(bin0.sample(), total_count)
+ bin1 = Binomial(total_count, torch.tensor(0.5))
+ samples = bin1.sample(torch.Size((100000,)))
+ self.assertTrue((samples <= total_count.type_as(samples)).all())
+ self.assertEqual(samples.mean(dim=0), bin1.mean, prec=0.02)
+ self.assertEqual(samples.var(dim=0), bin1.variance, prec=0.02)
+
+ def test_binomial_enumerate_support(self):
+ set_rng_seed(0)
+ bin0 = Binomial(0, torch.tensor(1.))
+ self.assertEqual(bin0.enumerate_support(), torch.tensor([0.]))
+ bin1 = Binomial(torch.tensor(5), torch.tensor(0.5))
+ self.assertEqual(bin1.enumerate_support(), torch.arange(6))

def test_multinomial_1d(self):
total_count = 10
@@ -1793,9 +1830,8 @@ class TestDistributions(TestCase):
self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
if indep_dist.has_rsample:
self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
- if indep_dist.has_enumerate_support:
- self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape)
try:
+ self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape)
self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
except NotImplementedError:
pass
@@ -2301,6 +2337,15 @@ class TestDistributionShapes(TestCase):
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

+ def test_binomial_shape_vectorized_n(self):
+ dist = Binomial(torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1]))
+ self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
+ self.assertEqual(dist._event_shape, torch.Size(()))
+ self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
+ self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))
+ self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+ self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
+
def test_multinomial_shape(self):
dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -2562,6 +2607,8 @@ class TestKL(TestCase):
# e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
+ binomial_vectorized_count = (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
+ Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])))
beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
categorical = pairwise(Categorical, [[0.4, 0.3, 0.3],
[0.2, 0.7, 0.1],
@@ -2607,6 +2654,7 @@ class TestKL(TestCase):
(beta, gamma),
(beta, normal),
(binomial30, binomial30),
+ (binomial_vectorized_count, binomial_vectorized_count),
(categorical, categorical),
(chi2, chi2),
(chi2, exponential),
@@ -2654,6 +2702,8 @@ class TestKL(TestCase):
(Beta(1, 2), Uniform(0.25, 0.75)),
(Beta(1, 2), Pareto(1, 2)),
(Binomial(31, 0.7), Binomial(30, 0.3)),
+ (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
+ Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8]))),
(Chi2(1), Beta(2, 3)),
(Chi2(1), Pareto(2, 3)),
(Chi2(1), Uniform(-2, 3)),
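The tests above exercise the new behavior implemented in torch/distributions/binomial.py below. The core of the change is in sample(): draw max(total_count) Bernoulli trials for every element, then zero out the trials past each element's own count before summing. A standalone sketch of that strategy, assuming a hypothetical helper name and inputs already broadcast to a common float shape:

    import torch

    def binomial_sample(total_count, probs, sample_shape=torch.Size()):
        # total_count and probs are assumed to be float tensors of the same shape.
        max_count = max(int(total_count.max()), 1)
        shape = sample_shape + probs.shape + (max_count,)
        # One Bernoulli draw per potential trial, for every batch element.
        bernoullis = torch.bernoulli(probs.unsqueeze(-1).expand(shape))
        # Positions at or beyond an element's own count are not real trials,
        # so force them to zero ...
        mask = torch.arange(max_count, dtype=probs.dtype) >= total_count.unsqueeze(-1)
        bernoullis.masked_fill_(mask, 0.)
        # ... and the sum over the trial axis is a Binomial draw per element.
        return bernoullis.sum(dim=-1)

    # Samples never exceed each element's own count:
    n = torch.tensor([[4., 7.], [3., 8.]])
    samples = binomial_sample(n, torch.full((2, 2), 0.5), torch.Size((100000,)))
    assert (samples <= n).all()  # samples has shape (100000, 2, 2)

This keeps sampling fully vectorized at the cost of O(max(total_count)) temporary memory per element, the same trade-off the patched sample() makes.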
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index f89e4dc577..ef52ff2986 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -1,19 +1,15 @@
from numbers import Number
import torch
-import math
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
-from torch.distributions.utils import clamp_probs

class Binomial(Distribution):
r"""
Creates a Binomial distribution parameterized by `total_count` and
- either `probs` or `logits` (but not both).
-
- - Requires a single shared `total_count` for all
- parameters and samples.
+ either `probs` or `logits` (but not both). `total_count` must be
+ broadcastable with `probs`/`logits`.

Example::
@@ -25,26 +21,32 @@ class Binomial(Distribution):
100
[torch.FloatTensor of size 4]]

+ >>> m = Binomial(torch.Tensor([[5.], [10.]]), torch.Tensor([0.5, 0.8]))
+ >>> x = m.sample()
+ 4 5
+ 7 6
+ [torch.FloatTensor of size (2,2)]
+
Args:
- total_count (int): number of Bernoulli trials
+ total_count (int or Tensor): number of Bernoulli trials
probs (Tensor): Event probabilities
logits (Tensor): Event log-odds
"""
- arg_constraints = {'probs': constraints.unit_interval}
+ arg_constraints = {'total_count': constraints.nonnegative_integer,
+ 'probs': constraints.unit_interval}
has_enumerate_support = True

def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
- if not isinstance(total_count, Number):
- raise NotImplementedError('inhomogeneous total_count is not supported')
- self.total_count = total_count
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
- is_scalar = isinstance(probs, Number)
- self.probs, = broadcast_all(probs)
+ self.total_count, self.probs, = broadcast_all(total_count, probs)
+ self.total_count = self.total_count.type_as(self.probs)
+ is_scalar = isinstance(self.probs, Number)
else:
- is_scalar = isinstance(logits, Number)
- self.logits, = broadcast_all(logits)
+ self.total_count, self.logits, = broadcast_all(total_count, logits)
+ self.total_count = self.total_count.type_as(self.logits)
+ is_scalar = isinstance(self.logits, Number)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@@ -81,14 +83,20 @@ class Binomial(Distribution):
return self._param.size()

def sample(self, sample_shape=torch.Size()):
- shape = self._extended_shape(sample_shape) + (self.total_count,)
with torch.no_grad():
- return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1)
+ max_count = max(int(self.total_count.max()), 1)
+ shape = self._extended_shape(sample_shape) + (max_count,)
+ bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape))
+ if self.total_count.min() != max_count:
+ arange = torch.arange(max_count, out=self.total_count.new_empty(max_count))
+ mask = arange >= self.total_count.unsqueeze(-1)
+ bernoullis.masked_fill_(mask, 0.)
+ return bernoullis.sum(dim=-1)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
- log_factorial_n = math.lgamma(self.total_count + 1)
+ log_factorial_n = torch.lgamma(self.total_count + 1)
log_factorial_k = torch.lgamma(value + 1)
log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
max_val = (-self.logits).clamp(min=0.0)
@@ -98,8 +106,11 @@ class Binomial(Distribution):
self.total_count * torch.log1p((self.logits + 2 * max_val).exp()))

def enumerate_support(self):
- values = self._new((self.total_count,))
- torch.arange(self.total_count, out=values.data)
+ total_count = int(self.total_count.max())
+ if not self.total_count.min() == total_count:
+ raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.")
+ values = self._new(1 + total_count,)
+ torch.arange(1 + total_count, out=values)
values = values.view((-1,) + (1,) * len(self._batch_shape))
values = values.expand((-1,) + self._batch_shape)
return values
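The log_prob context above leans on a numerically stable identity for log(1 - probs) in logits space: with max_val = clamp(-logits, min=0), log1p(-probs) equals max_val - log1p(exp(logits + 2 * max_val)), which avoids evaluating probs near 1 directly. A quick self-contained check of the identity (values illustrative; float64 keeps the naive side accurate enough for comparison):

    import torch

    logits = torch.tensor([-20., -1., 0., 1., 20.], dtype=torch.float64)
    probs = torch.sigmoid(logits)

    # Stable form used by Binomial.log_prob:
    max_val = (-logits).clamp(min=0.0)
    stable = max_val - torch.log1p((logits + 2 * max_val).exp())

    print(torch.allclose(stable, torch.log1p(-probs)))  # True

Switching math.lgamma to torch.lgamma in the same hunk is what lets the factorial terms broadcast elementwise over the now tensor-valued total_count.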
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index bdb9269c49..203e0726ad 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -198,12 +198,12 @@ def _kl_beta_beta(p, q):
def _kl_binomial_binomial(p, q):
# from https://math.stackexchange.com/questions/2214993/
# kullback-leibler-divergence-for-binomial-distributions-p-and-q
- if p.total_count > q.total_count:
- return _infinite_like(p.probs)
- elif p.total_count == q.total_count:
- return p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
- else:
+ if (p.total_count < q.total_count).any():
raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
+ kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
+ inf_idxs = p.total_count > q.total_count
+ kl[inf_idxs] = _infinite_like(kl[inf_idxs])
+ return kl

@register_kl(Categorical, Categorical)
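With vectorized counts, the KL registration above returns one value per batch element: it is infinite wherever p.total_count > q.total_count (p puts mass on outcomes impossible under q) and raises if any element has p.total_count < q.total_count. A usage sketch (values illustrative):

    import torch
    from torch.distributions import Binomial, kl_divergence

    p = Binomial(torch.tensor([3., 4.]), torch.tensor([0.4, 0.6]))
    q = Binomial(torch.tensor([3., 4.]), torch.tensor([0.5, 0.8]))
    print(kl_divergence(p, q))  # elementwise KL, shape (2,)

    # p has more trials than q in the first element, so KL is inf there:
    p_wide = Binomial(torch.tensor([5., 4.]), torch.tensor([0.4, 0.6]))
    print(kl_divergence(p_wide, q))  # first element inf, second finite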