summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorFritz Obermeyer <fritz.obermeyer@gmail.com>2018-02-13 16:32:11 -0800
committerAdam Paszke <adam.paszke@gmail.com>2018-02-14 01:32:11 +0100
commita4d0a74ceefa9e18810f8caf0fdf66fea473008b (patch)
tree136ca16f1ce3017e7a0de672ac0375be0812ee56 /torch
parent1b71e78d133eb156bf037f0f12550032d1b90bd8 (diff)
downloadpytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.gz
pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.bz2
pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.zip
Ensure Distribution.sample() result is detached (#5086)
Diffstat (limited to 'torch')
-rw-r--r--torch/distributions/bernoulli.py3
-rw-r--r--torch/distributions/binomial.py3
-rw-r--r--torch/distributions/distribution.py4
-rw-r--r--torch/distributions/geometric.py5
-rw-r--r--torch/distributions/normal.py3
-rw-r--r--torch/distributions/poisson.py3
-rw-r--r--torch/distributions/transformed_distribution.py9
7 files changed, 18 insertions, 12 deletions
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index a8a532ab01..dd6c7052a0 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -72,7 +72,8 @@ class Bernoulli(ExponentialFamily):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- return torch.bernoulli(self.probs.expand(shape))
+ with torch.no_grad():
+ return torch.bernoulli(self.probs.expand(shape))
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 471525e00a..c6dad21958 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -83,7 +83,8 @@ class Binomial(Distribution):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape) + (self.total_count,)
- return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1)
+ with torch.no_grad():
+ return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1)
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index ffa6dfbfe2..210df1f5d5 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -72,8 +72,8 @@ class Distribution(object):
Generates a sample_shape shaped sample or sample_shape shaped batch of
samples if the distribution parameters are batched.
"""
- z = self.rsample(sample_shape)
- return z.detach() if hasattr(z, 'detach') else z
+ with torch.no_grad():
+ return self.rsample(sample_shape)
def rsample(self, sample_shape=torch.Size()):
"""
diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py
index df09821985..c231ac30ce 100644
--- a/torch/distributions/geometric.py
+++ b/torch/distributions/geometric.py
@@ -63,8 +63,9 @@ class Geometric(Distribution):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1)
- return (u.log() / (-self.probs).log1p()).floor()
+ with torch.no_grad():
+ u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1)
+ return (u.log() / (-self.probs).log1p()).floor()
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 8720cedc35..ffbe042af8 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -51,7 +51,8 @@ class Normal(ExponentialFamily):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
+ with torch.no_grad():
+ return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py
index f51eb958e3..06dae94a88 100644
--- a/torch/distributions/poisson.py
+++ b/torch/distributions/poisson.py
@@ -51,7 +51,8 @@ class Poisson(ExponentialFamily):
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
- return _poisson(self.rate.expand(shape))
+ with torch.no_grad():
+ return _poisson(self.rate.expand(shape))
def log_prob(self, value):
self._validate_log_prob_arg(value)
diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py
index 282dd2cf7e..b4bcfdfc91 100644
--- a/torch/distributions/transformed_distribution.py
+++ b/torch/distributions/transformed_distribution.py
@@ -53,10 +53,11 @@ class TransformedDistribution(Distribution):
base distribution and applies `transform()` for every transform in the
list.
"""
- x = self.base_dist.sample(sample_shape)
- for transform in self.transforms:
- x = transform(x)
- return x
+ with torch.no_grad():
+ x = self.base_dist.sample(sample_shape)
+ for transform in self.transforms:
+ x = transform(x)
+ return x
def rsample(self, sample_shape=torch.Size()):
"""