diff options
author | Fritz Obermeyer <fritz.obermeyer@gmail.com> | 2018-02-13 16:32:11 -0800 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2018-02-14 01:32:11 +0100 |
commit | a4d0a74ceefa9e18810f8caf0fdf66fea473008b (patch) | |
tree | 136ca16f1ce3017e7a0de672ac0375be0812ee56 | |
parent | 1b71e78d133eb156bf037f0f12550032d1b90bd8 (diff) | |
download | pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.gz pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.bz2 pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.zip |
Ensure Distribution.sample() result is detached (#5086)
-rw-r--r-- | test/test_distributions.py | 25 | ||||
-rw-r--r-- | torch/distributions/bernoulli.py | 3 | ||||
-rw-r--r-- | torch/distributions/binomial.py | 3 | ||||
-rw-r--r-- | torch/distributions/distribution.py | 4 | ||||
-rw-r--r-- | torch/distributions/geometric.py | 5 | ||||
-rw-r--r-- | torch/distributions/normal.py | 3 | ||||
-rw-r--r-- | torch/distributions/poisson.py | 3 | ||||
-rw-r--r-- | torch/distributions/transformed_distribution.py | 9 |
8 files changed, 43 insertions, 12 deletions
diff --git a/test/test_distributions.py b/test/test_distributions.py index 9a11db6c2d..c52c07357e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -393,6 +393,31 @@ class TestDistributions(TestCase): actual = dist(param).enumerate_support() self.assertEqual(actual, expected) + def test_sample_detached(self): + for Dist, params in EXAMPLES: + for i, param in enumerate(params): + variable_params = [p for p in param.values() if getattr(p, 'requires_grad', False)] + if not variable_params: + continue + dist = Dist(**param) + sample = dist.sample() + self.assertFalse(sample.requires_grad, + msg='{} example {}/{}, .sample() is not detached'.format( + Dist.__name__, i + 1, len(params))) + + def test_rsample_requires_grad(self): + for Dist, params in EXAMPLES: + for i, param in enumerate(params): + if not any(getattr(p, 'requires_grad', False) for p in param.values()): + continue + dist = Dist(**param) + if not dist.has_rsample: + continue + sample = dist.rsample() + self.assertTrue(sample.requires_grad, + msg='{} example {}/{}, .rsample() does not require grad'.format( + Dist.__name__, i + 1, len(params))) + def test_enumerate_support_type(self): for Dist, params in EXAMPLES: for i, param in enumerate(params): diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index a8a532ab01..dd6c7052a0 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -72,7 +72,8 @@ class Bernoulli(ExponentialFamily): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - return torch.bernoulli(self.probs.expand(shape)) + with torch.no_grad(): + return torch.bernoulli(self.probs.expand(shape)) def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 471525e00a..c6dad21958 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -83,7 +83,8 @@ class Binomial(Distribution): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) + (self.total_count,) - return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1) + with torch.no_grad(): + return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1) def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index ffa6dfbfe2..210df1f5d5 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -72,8 +72,8 @@ class Distribution(object): Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. """ - z = self.rsample(sample_shape) - return z.detach() if hasattr(z, 'detach') else z + with torch.no_grad(): + return self.rsample(sample_shape) def rsample(self, sample_shape=torch.Size()): """ diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index df09821985..c231ac30ce 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -63,8 +63,9 @@ class Geometric(Distribution): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1) - return (u.log() / (-self.probs).log1p()).floor() + with torch.no_grad(): + u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1) + return (u.log() / (-self.probs).log1p()).floor() def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 8720cedc35..ffbe042af8 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -51,7 +51,8 @@ class Normal(ExponentialFamily): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) + with torch.no_grad(): + return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index f51eb958e3..06dae94a88 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -51,7 +51,8 @@ class Poisson(ExponentialFamily): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - return _poisson(self.rate.expand(shape)) + with torch.no_grad(): + return _poisson(self.rate.expand(shape)) def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 282dd2cf7e..b4bcfdfc91 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -53,10 +53,11 @@ class TransformedDistribution(Distribution): base distribution and applies `transform()` for every transform in the list. """ - x = self.base_dist.sample(sample_shape) - for transform in self.transforms: - x = transform(x) - return x + with torch.no_grad(): + x = self.base_dist.sample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x def rsample(self, sample_shape=torch.Size()): """ |