cgit views: summary | refs | log | tree | commit | diff
diff options: context, space, mode (defaults)
authorFritz Obermeyer <fritz.obermeyer@gmail.com>2018-02-13 16:32:11 -0800
committerAdam Paszke <adam.paszke@gmail.com>2018-02-14 01:32:11 +0100
commita4d0a74ceefa9e18810f8caf0fdf66fea473008b (patch)
tree136ca16f1ce3017e7a0de672ac0375be0812ee56
parent1b71e78d133eb156bf037f0f12550032d1b90bd8 (diff)
downloadpytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.gz
pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.bz2
pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.zip
Ensure Distribution.sample() result is detached (#5086)
-rw-r--r--test/test_distributions.py25
-rw-r--r--torch/distributions/bernoulli.py3
-rw-r--r--torch/distributions/binomial.py3
-rw-r--r--torch/distributions/distribution.py4
-rw-r--r--torch/distributions/geometric.py5
-rw-r--r--torch/distributions/normal.py3
-rw-r--r--torch/distributions/poisson.py3
-rw-r--r--torch/distributions/transformed_distribution.py9
8 files changed, 43 insertions, 12 deletions
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 9a11db6c2d..c52c07357e 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -393,6 +393,31 @@ class TestDistributions(TestCase):
actual = dist(param).enumerate_support()
self.assertEqual(actual, expected)
+ def test_sample_detached(self):
+ for Dist, params in EXAMPLES:
+ for i, param in enumerate(params):
+ variable_params = [p for p in param.values() if getattr(p, 'requires_grad', False)]
+ if not variable_params:
+ continue
+ dist = Dist(**param)
+ sample = dist.sample()
+ self.assertFalse(sample.requires_grad,
+ msg='{} example {}/{}, .sample() is not detached'.format(
+ Dist.__name__, i + 1, len(params)))
+
+ def test_rsample_requires_grad(self):
+ for Dist, params in EXAMPLES:
+ for i, param in enumerate(params):
+ if not any(getattr(p, 'requires_grad', False) for p in param.values()):
+ continue
+ dist = Dist(**param)
+ if not dist.has_rsample:
+ continue
+ sample = dist.rsample()
+ self.assertTrue(sample.requires_grad,
+ msg='{} example {}/{}, .rsample() does not require grad'.format(
+ Dist.__name__, i + 1, len(params)))
+
def test_enumerate_support_type(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index a8a532ab01..dd6c7052a0 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -72,7 +72,8 @@ class Bernoulli(ExponentialFamily):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- return torch.bernoulli(self.probs.expand(shape))
+ with torch.no_grad():
+ return torch.bernoulli(self.probs.expand(shape))
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 471525e00a..c6dad21958 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -83,7 +83,8 @@ class Binomial(Distribution):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape) + (self.total_count,)
- return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1)
+ with torch.no_grad():
+ return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1)
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index ffa6dfbfe2..210df1f5d5 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -72,8 +72,8 @@ class Distribution(object):
Generates a sample_shape shaped sample or sample_shape shaped batch of
samples if the distribution parameters are batched.
"""
- z = self.rsample(sample_shape)
- return z.detach() if hasattr(z, 'detach') else z
+ with torch.no_grad():
+ return self.rsample(sample_shape)
def rsample(self, sample_shape=torch.Size()):
"""
diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py
index df09821985..c231ac30ce 100644
--- a/torch/distributions/geometric.py
+++ b/torch/distributions/geometric.py
@@ -63,8 +63,9 @@ class Geometric(Distribution):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1)
- return (u.log() / (-self.probs).log1p()).floor()
+ with torch.no_grad():
+ u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1)
+ return (u.log() / (-self.probs).log1p()).floor()
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 8720cedc35..ffbe042af8 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -51,7 +51,8 @@ class Normal(ExponentialFamily):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
+ with torch.no_grad():
+ return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py
index f51eb958e3..06dae94a88 100644
--- a/torch/distributions/poisson.py
+++ b/torch/distributions/poisson.py
@@ -51,7 +51,8 @@ class Poisson(ExponentialFamily):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- return _poisson(self.rate.expand(shape))
+ with torch.no_grad():
+ return _poisson(self.rate.expand(shape))
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py
index 282dd2cf7e..b4bcfdfc91 100644
--- a/torch/distributions/transformed_distribution.py
+++ b/torch/distributions/transformed_distribution.py
@@ -53,10 +53,11 @@ class TransformedDistribution(Distribution):
base distribution and applies `transform()` for every transform in the
list.
"""
- x = self.base_dist.sample(sample_shape)
- for transform in self.transforms:
- x = transform(x)
- return x
+ with torch.no_grad():
+ x = self.base_dist.sample(sample_shape)
+ for transform in self.transforms:
+ x = transform(x)
+ return x
def rsample(self, sample_shape=torch.Size()):
"""