diff options
author | Neeraj Pradhan <prad.neeraj@gmail.com> | 2018-04-26 06:53:01 -0700 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2018-04-26 15:53:01 +0200 |
commit | 3964253f943668fe076c992dbf91e975b3b4c3c4 (patch) | |
tree | 6aae2f53dedccc8a2347450148f4c0b22cf55c1b | |
parent | f98b7780867c847eef27fbf492c9ee8dfe023c20 (diff) | |
download | pytorch-3964253f943668fe076c992dbf91e975b3b4c3c4.tar.gz pytorch-3964253f943668fe076c992dbf91e975b3b4c3c4.tar.bz2 pytorch-3964253f943668fe076c992dbf91e975b3b4c3c4.zip |
Allowing for vectorized counts in Binomial Distribution (#6720)
-rw-r--r-- | test/test_distributions.py | 54 | ||||
-rw-r--r-- | torch/distributions/binomial.py | 51 | ||||
-rw-r--r-- | torch/distributions/kl.py | 10 |
3 files changed, 88 insertions, 27 deletions
diff --git a/test/test_distributions.py b/test/test_distributions.py index 09ba06b9fb..917fd7200b 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -113,6 +113,12 @@ EXAMPLES = [ Example(Binomial, [ {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10}, {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10])}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10, 8])}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), + 'total_count': torch.tensor([[10., 8.], [5., 3.]])}, + {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), + 'total_count': torch.tensor(0.)}, ]), Example(Multinomial, [ {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10}, @@ -795,6 +801,15 @@ class TestDistributions(TestCase): logits = probs_to_logits(probs, is_binary=True) self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob) + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_binomial_log_prob_vectorized_count(self): + probs = torch.tensor([0.2, 0.7, 0.9]) + for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3., 9.])), + (torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.]))]: + log_prob = Binomial(total_count, probs).log_prob(sample) + expected = scipy.stats.binom(total_count.cpu().numpy(), probs.cpu().numpy()).logpmf(sample) + self.assertAlmostEqual(log_prob, expected, places=4) + def test_binomial_extreme_vals(self): total_count = 100 bin0 = Binomial(total_count, 0) @@ -805,6 +820,28 @@ class TestDistributions(TestCase): self.assertEqual(bin1.sample(), total_count) self.assertAlmostEqual(bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, places=3) self.assertEqual(float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0, allow_inf=True) + zero_counts = torch.zeros(torch.Size((2, 2))) + bin2 = Binomial(zero_counts, 1) + self.assertEqual(bin2.sample(), zero_counts) + self.assertEqual(bin2.log_prob(zero_counts), zero_counts) + + def test_binomial_vectorized_count(self): + set_rng_seed(0) + total_count = torch.tensor([[4, 7], [3, 8]]) + bin0 = Binomial(total_count, torch.tensor(1.)) + self.assertEqual(bin0.sample(), total_count) + bin1 = Binomial(total_count, torch.tensor(0.5)) + samples = bin1.sample(torch.Size((100000,))) + self.assertTrue((samples <= total_count.type_as(samples)).all()) + self.assertEqual(samples.mean(dim=0), bin1.mean, prec=0.02) + self.assertEqual(samples.var(dim=0), bin1.variance, prec=0.02) + + def test_binomial_enumerate_support(self): + set_rng_seed(0) + bin0 = Binomial(0, torch.tensor(1.)) + self.assertEqual(bin0.enumerate_support(), torch.tensor([0.])) + bin1 = Binomial(torch.tensor(5), torch.tensor(0.5)) + self.assertEqual(bin1.enumerate_support(), torch.arange(6)) def test_multinomial_1d(self): total_count = 10 @@ -1793,9 +1830,8 @@ class TestDistributions(TestCase): self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample) if indep_dist.has_rsample: self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape) - if indep_dist.has_enumerate_support: - self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape) try: + self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape) self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape) except NotImplementedError: pass @@ -2301,6 +2337,15 @@ class TestDistributionShapes(TestCase): self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2) + def test_binomial_shape_vectorized_n(self): + dist = Binomial(torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1])) + self.assertEqual(dist._batch_shape, torch.Size((2, 3))) + self.assertEqual(dist._event_shape, torch.Size(())) + self.assertEqual(dist.sample().size(), torch.Size((2, 3))) + self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3))) + self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3))) + self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1) + def test_multinomial_shape(self): dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]])) self.assertEqual(dist._batch_shape, torch.Size((3,))) @@ -2562,6 +2607,8 @@ class TestKL(TestCase): # e.g. bernoulli[1] varies row-wise; that way we test all param pairs. bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9]) binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9]) + binomial_vectorized_count = (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])), + Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8]))) beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5]) categorical = pairwise(Categorical, [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], @@ -2607,6 +2654,7 @@ class TestKL(TestCase): (beta, gamma), (beta, normal), (binomial30, binomial30), + (binomial_vectorized_count, binomial_vectorized_count), (categorical, categorical), (chi2, chi2), (chi2, exponential), @@ -2654,6 +2702,8 @@ class TestKL(TestCase): (Beta(1, 2), Uniform(0.25, 0.75)), (Beta(1, 2), Pareto(1, 2)), (Binomial(31, 0.7), Binomial(30, 0.3)), + (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])), + Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8]))), (Chi2(1), Beta(2, 3)), (Chi2(1), Pareto(2, 3)), (Chi2(1), Uniform(-2, 3)), diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index f89e4dc577..ef52ff2986 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -1,19 +1,15 @@ from numbers import Number import torch -import math from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs -from torch.distributions.utils import clamp_probs class Binomial(Distribution): r""" Creates a Binomial distribution parameterized by `total_count` and - either `probs` or `logits` (but not both). - - - Requires a single shared `total_count` for all - parameters and samples. + either `probs` or `logits` (but not both). `total_count` must be + broadcastable with `probs`/`logits`. Example:: @@ -25,26 +21,32 @@ class Binomial(Distribution): 100 [torch.FloatTensor of size 4]] + >>> m = Binomial(torch.Tensor([[5.], [10.]]), torch.Tensor([0.5, 0.8])) + >>> x = m.sample() + 4 5 + 7 6 + [torch.FloatTensor of size (2,2)] + Args: - total_count (int): number of Bernoulli trials + total_count (int or Tensor): number of Bernoulli trials probs (Tensor): Event probabilities logits (Tensor): Event log-odds """ - arg_constraints = {'probs': constraints.unit_interval} + arg_constraints = {'total_count': constraints.nonnegative_integer, + 'probs': constraints.unit_interval} has_enumerate_support = True def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): - if not isinstance(total_count, Number): - raise NotImplementedError('inhomogeneous total_count is not supported') - self.total_count = total_count if (probs is None) == (logits is None): raise ValueError("Either `probs` or `logits` must be specified, but not both.") if probs is not None: - is_scalar = isinstance(probs, Number) - self.probs, = broadcast_all(probs) + self.total_count, self.probs, = broadcast_all(total_count, probs) + self.total_count = self.total_count.type_as(self.logits) + is_scalar = isinstance(self.probs, Number) else: - is_scalar = isinstance(logits, Number) - self.logits, = broadcast_all(logits) + self.total_count, self.logits, = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + is_scalar = isinstance(self.logits, Number) self._param = self.probs if probs is not None else self.logits if is_scalar: @@ -81,14 +83,20 @@ class Binomial(Distribution): return self._param.size() def sample(self, sample_shape=torch.Size()): - shape = self._extended_shape(sample_shape) + (self.total_count,) with torch.no_grad(): - return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1) + max_count = max(int(self.total_count.max()), 1) + shape = self._extended_shape(sample_shape) + (max_count,) + bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)) + if self.total_count.min() != max_count: + arange = torch.arange(max_count, out=self.total_count.new_empty(max_count)) + mask = arange >= self.total_count.unsqueeze(-1) + bernoullis.masked_fill_(mask, 0.) + return bernoullis.sum(dim=-1) def log_prob(self, value): if self._validate_args: self._validate_sample(value) - log_factorial_n = math.lgamma(self.total_count + 1) + log_factorial_n = torch.lgamma(self.total_count + 1) log_factorial_k = torch.lgamma(value + 1) log_factorial_nmk = torch.lgamma(self.total_count - value + 1) max_val = (-self.logits).clamp(min=0.0) @@ -98,8 +106,11 @@ class Binomial(Distribution): self.total_count * torch.log1p((self.logits + 2 * max_val).exp())) def enumerate_support(self): - values = self._new((self.total_count,)) - torch.arange(self.total_count, out=values.data) + total_count = int(self.total_count.max()) + if not self.total_count.min() == total_count: + raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.") + values = self._new(1 + total_count,) + torch.arange(1 + total_count, out=values) values = values.view((-1,) + (1,) * len(self._batch_shape)) values = values.expand((-1,) + self._batch_shape) return values diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index bdb9269c49..203e0726ad 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -198,12 +198,12 @@ def _kl_beta_beta(p, q): def _kl_binomial_binomial(p, q): # from https://math.stackexchange.com/questions/2214993/ # kullback-leibler-divergence-for-binomial-distributions-p-and-q - if p.total_count > q.total_count: - return _infinite_like(p.probs) - elif p.total_count == q.total_count: - return p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()) - else: + if (p.total_count < q.total_count).any(): raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented') + kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()) + inf_idxs = p.total_count > q.total_count + kl[inf_idxs] = _infinite_like(kl[inf_idxs]) + return kl @register_kl(Categorical, Categorical) |