diff options
author | Fritz Obermeyer <fritz.obermeyer@gmail.com> | 2018-02-13 16:32:11 -0800 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2018-02-14 01:32:11 +0100 |
commit | a4d0a74ceefa9e18810f8caf0fdf66fea473008b (patch) | |
tree | 136ca16f1ce3017e7a0de672ac0375be0812ee56 /torch | |
parent | 1b71e78d133eb156bf037f0f12550032d1b90bd8 (diff) | |
download | pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.gz pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.tar.bz2 pytorch-a4d0a74ceefa9e18810f8caf0fdf66fea473008b.zip |
Ensure Distribution.sample() result is detached (#5086)
Diffstat (limited to 'torch')
-rw-r--r-- | torch/distributions/bernoulli.py | 3 | ||||
-rw-r--r-- | torch/distributions/binomial.py | 3 | ||||
-rw-r--r-- | torch/distributions/distribution.py | 4 | ||||
-rw-r--r-- | torch/distributions/geometric.py | 5 | ||||
-rw-r--r-- | torch/distributions/normal.py | 3 | ||||
-rw-r--r-- | torch/distributions/poisson.py | 3 | ||||
-rw-r--r-- | torch/distributions/transformed_distribution.py | 9 |
7 files changed, 18 insertions, 12 deletions
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index a8a532ab01..dd6c7052a0 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -72,7 +72,8 @@ class Bernoulli(ExponentialFamily): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - return torch.bernoulli(self.probs.expand(shape)) + with torch.no_grad(): + return torch.bernoulli(self.probs.expand(shape)) def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 471525e00a..c6dad21958 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -83,7 +83,8 @@ class Binomial(Distribution): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) + (self.total_count,) - return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1) + with torch.no_grad(): + return torch.bernoulli(self.probs.unsqueeze(-1).expand(shape)).sum(dim=-1) def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index ffa6dfbfe2..210df1f5d5 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -72,8 +72,8 @@ class Distribution(object): Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. """ - z = self.rsample(sample_shape) - return z.detach() if hasattr(z, 'detach') else z + with torch.no_grad(): + return self.rsample(sample_shape) def rsample(self, sample_shape=torch.Size()): """ diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index df09821985..c231ac30ce 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -63,8 +63,9 @@ class Geometric(Distribution): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1) - return (u.log() / (-self.probs).log1p()).floor() + with torch.no_grad(): + u = self.probs.new(shape).uniform_(_finfo(self.probs).tiny, 1) + return (u.log() / (-self.probs).log1p()).floor() def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 8720cedc35..ffbe042af8 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -51,7 +51,8 @@ class Normal(ExponentialFamily): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) + with torch.no_grad(): + return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index f51eb958e3..06dae94a88 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -51,7 +51,8 @@ class Poisson(ExponentialFamily): def sample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - return _poisson(self.rate.expand(shape)) + with torch.no_grad(): + return _poisson(self.rate.expand(shape)) def log_prob(self, value): self._validate_log_prob_arg(value) diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 282dd2cf7e..b4bcfdfc91 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -53,10 +53,11 @@ class TransformedDistribution(Distribution): base distribution and applies `transform()` for every transform in the list. """ - x = self.base_dist.sample(sample_shape) - for transform in self.transforms: - x = transform(x) - return x + with torch.no_grad(): + x = self.base_dist.sample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x def rsample(self, sample_shape=torch.Size()): """ |