summaryrefslogtreecommitdiff
path: root/torch/distributions/transforms.py
diff options
context:
space:
mode:
authorWei Yang <weiyang@fb.com>2018-07-02 15:44:12 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-07-02 15:54:46 -0700
commitcb1bfe91afcab1f7431202afe03381a5d59e4ee3 (patch)
treeefcf92c5200984cb5c014bbd1f4bea6ec2e6093b /torch/distributions/transforms.py
parent50392cc55444672a4b63b335455040a6ba53ae32 (diff)
downloadpytorch-cb1bfe91afcab1f7431202afe03381a5d59e4ee3.tar.gz
pytorch-cb1bfe91afcab1f7431202afe03381a5d59e4ee3.tar.bz2
pytorch-cb1bfe91afcab1f7431202afe03381a5d59e4ee3.zip
Deprecated several functions at torch.nn.functional (#8748)
Summary: 1. fixes #6245 2. deprecated tanh, sigmoid Closes https://github.com/pytorch/pytorch/pull/8748 Differential Revision: D8697975 Pulled By: weiyangfb fbshipit-source-id: f30714aa0611a1fe870040692f3dbcc8238aece9
Diffstat (limited to 'torch/distributions/transforms.py')
-rw-r--r--torch/distributions/transforms.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py
index 6852a71df9..9b241f6eba 100644
--- a/torch/distributions/transforms.py
+++ b/torch/distributions/transforms.py
@@ -6,7 +6,7 @@ import torch
from torch.distributions import constraints
from torch.distributions.utils import (_sum_rightmost, broadcast_all,
lazy_property)
-from torch.nn.functional import pad, sigmoid
+from torch.nn.functional import pad
__all__ = [
'AbsTransform',
@@ -341,7 +341,7 @@ class SigmoidTransform(Transform):
return isinstance(other, SigmoidTransform)
def _call(self, x):
- return sigmoid(x)
+ return torch.sigmoid(x)
def _inverse(self, y):
return y.log() - (-y).log1p()
@@ -483,7 +483,7 @@ class StickBreakingTransform(Transform):
def _call(self, x):
offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
- z = sigmoid(x - offset.log())
+ z = torch.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
return y
@@ -497,7 +497,7 @@ class StickBreakingTransform(Transform):
def log_abs_det_jacobian(self, x, y):
offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
- z = sigmoid(x - offset.log())
+ z = torch.sigmoid(x - offset.log())
detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
return detJ