author    lazypanda1 <35884075+lazypanda1@users.noreply.github.com>  2018-03-15 17:11:20 -0500
committer Adam Paszke <adam.paszke@gmail.com>  2018-03-15 23:11:20 +0100
commit    7f864bbe52ad295b08f57fef7ad77226b303e492 (patch)
tree      8d5a4a18457fadcce9fe8ac7b556e9f274842306 /test/test_distributions.py
parent    eeb90d9c95065e332c7f1545fa11e162f4ea3f62 (diff)
Fixed distribution constraints and added some test cases for distribution parameter checks (#5358)
Diffstat (limited to 'test/test_distributions.py')
-rw-r--r--  test/test_distributions.py | 715
1 file changed, 476 insertions(+), 239 deletions(-)
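
The tests below exercise eager parameter validation: distributions are
constructed with the per-instance `validate_args` keyword and toggled globally
via `Distribution.set_default_validate_args`, both of which appear in this
diff. A minimal sketch of the behavior under test, assuming nothing beyond
that API:

    import torch
    from torch.distributions import Bernoulli, Distribution

    Distribution.set_default_validate_args(True)
    Bernoulli(torch.tensor([0.3]))      # probs in [0, 1]: accepted
    try:
        Bernoulli(torch.tensor([1.1]))  # violates the unit-interval constraint
    except ValueError:
        pass                            # invalid parameters are rejected at construction
    Distribution.set_default_validate_args(False)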
diff --git a/test/test_distributions.py b/test/test_distributions.py
index c2661811ed..3a076564ad 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -31,7 +31,7 @@ from random import shuffle
import torch
from common import TestCase, run_tests, set_rng_seed
-from torch.autograd import Variable, grad, gradcheck
+from torch.autograd import Variable, grad, gradcheck, variable
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
Cauchy, Chi2, Dirichlet, Distribution,
Exponential, ExponentialFamily,
@@ -86,138 +86,138 @@ def is_all_nan(tensor):
Example = namedtuple('Example', ['Dist', 'params'])
EXAMPLES = [
Example(Bernoulli, [
- {'probs': Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True)},
- {'probs': Variable(torch.Tensor([0.3]), requires_grad=True)},
+ {'probs': variable([0.7, 0.2, 0.4], requires_grad=True)},
+ {'probs': variable([0.3], requires_grad=True)},
{'probs': 0.3},
]),
Example(Geometric, [
- {'probs': Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True)},
- {'probs': Variable(torch.Tensor([0.3]), requires_grad=True)},
+ {'probs': variable([0.7, 0.2, 0.4], requires_grad=True)},
+ {'probs': variable([0.3], requires_grad=True)},
{'probs': 0.3},
]),
Example(Beta, [
{
- 'concentration1': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
- 'concentration0': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
+ 'concentration1': variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
+ 'concentration0': variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
},
{
- 'concentration1': Variable(torch.exp(torch.randn(4)), requires_grad=True),
- 'concentration0': Variable(torch.exp(torch.randn(4)), requires_grad=True),
+ 'concentration1': variable(torch.exp(torch.randn(4)), requires_grad=True),
+ 'concentration0': variable(torch.exp(torch.randn(4)), requires_grad=True),
},
]),
Example(Categorical, [
- {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
- {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
+ {'probs': variable([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
+ {'probs': variable([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
]),
Example(Binomial, [
- {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
- {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
+ {'probs': variable([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
+ {'probs': variable([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
]),
Example(Multinomial, [
- {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True), 'total_count': 10},
- {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True), 'total_count': 10},
+ {'probs': variable([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
+ {'probs': variable([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
]),
Example(Cauchy, [
{'loc': 0.0, 'scale': 1.0},
- {'loc': Variable(torch.Tensor([0.0])), 'scale': 1.0},
- {'loc': Variable(torch.Tensor([[0.0], [0.0]])),
- 'scale': Variable(torch.Tensor([[1.0], [1.0]]))}
+ {'loc': variable([0.0]), 'scale': 1.0},
+ {'loc': variable([[0.0], [0.0]]),
+ 'scale': variable([[1.0], [1.0]])}
]),
Example(Chi2, [
- {'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
- {'df': Variable(torch.exp(torch.randn(1)), requires_grad=True)},
+ {'df': variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
+ {'df': variable(torch.exp(torch.randn(1)), requires_grad=True)},
]),
Example(StudentT, [
- {'df': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
- {'df': Variable(torch.exp(torch.randn(1)), requires_grad=True)},
+ {'df': variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
+ {'df': variable(torch.exp(torch.randn(1)), requires_grad=True)},
]),
Example(Dirichlet, [
- {'concentration': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
- {'concentration': Variable(torch.exp(torch.randn(4)), requires_grad=True)},
+ {'concentration': variable(torch.exp(torch.randn(2, 3)), requires_grad=True)},
+ {'concentration': variable(torch.exp(torch.randn(4)), requires_grad=True)},
]),
Example(Exponential, [
- {'rate': Variable(torch.randn(5, 5).abs(), requires_grad=True)},
- {'rate': Variable(torch.randn(1).abs(), requires_grad=True)},
+ {'rate': variable(torch.randn(5, 5).abs(), requires_grad=True)},
+ {'rate': variable(torch.randn(1).abs(), requires_grad=True)},
]),
Example(FisherSnedecor, [
{
- 'df1': Variable(torch.randn(5, 5).abs(), requires_grad=True),
- 'df2': Variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'df1': variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'df2': variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
- 'df1': Variable(torch.randn(1).abs(), requires_grad=True),
- 'df2': Variable(torch.randn(1).abs(), requires_grad=True),
+ 'df1': variable(torch.randn(1).abs(), requires_grad=True),
+ 'df2': variable(torch.randn(1).abs(), requires_grad=True),
},
{
- 'df1': Variable(torch.Tensor([1.0])),
+ 'df1': variable([1.0]),
'df2': 1.0,
}
]),
Example(Gamma, [
{
- 'concentration': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
- 'rate': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
+ 'concentration': variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
+ 'rate': variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
},
{
- 'concentration': Variable(torch.exp(torch.randn(1)), requires_grad=True),
- 'rate': Variable(torch.exp(torch.randn(1)), requires_grad=True),
+ 'concentration': variable(torch.exp(torch.randn(1)), requires_grad=True),
+ 'rate': variable(torch.exp(torch.randn(1)), requires_grad=True),
},
]),
Example(Gumbel, [
{
- 'loc': Variable(torch.randn(5, 5), requires_grad=True),
- 'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(5, 5), requires_grad=True),
+ 'scale': variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
- 'loc': Variable(torch.randn(1), requires_grad=True),
- 'scale': Variable(torch.randn(1).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(1), requires_grad=True),
+ 'scale': variable(torch.randn(1).abs(), requires_grad=True),
},
]),
Example(Laplace, [
{
- 'loc': Variable(torch.randn(5, 5), requires_grad=True),
- 'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(5, 5), requires_grad=True),
+ 'scale': variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
- 'loc': Variable(torch.randn(1), requires_grad=True),
- 'scale': Variable(torch.randn(1).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(1), requires_grad=True),
+ 'scale': variable(torch.randn(1).abs(), requires_grad=True),
},
{
- 'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
- 'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
+ 'loc': variable([1.0, 0.0], requires_grad=True),
+ 'scale': variable([1e-5, 1e-5], requires_grad=True),
},
]),
Example(LogNormal, [
{
- 'loc': Variable(torch.randn(5, 5), requires_grad=True),
- 'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(5, 5), requires_grad=True),
+ 'scale': variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
- 'loc': Variable(torch.randn(1), requires_grad=True),
- 'scale': Variable(torch.randn(1).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(1), requires_grad=True),
+ 'scale': variable(torch.randn(1).abs(), requires_grad=True),
},
{
- 'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
- 'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
+ 'loc': variable([1.0, 0.0], requires_grad=True),
+ 'scale': variable([1e-5, 1e-5], requires_grad=True),
},
]),
Example(Normal, [
{
- 'loc': Variable(torch.randn(5, 5), requires_grad=True),
- 'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(5, 5), requires_grad=True),
+ 'scale': variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
- 'loc': Variable(torch.randn(1), requires_grad=True),
- 'scale': Variable(torch.randn(1).abs(), requires_grad=True),
+ 'loc': variable(torch.randn(1), requires_grad=True),
+ 'scale': variable(torch.randn(1).abs(), requires_grad=True),
},
{
- 'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
- 'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
+ 'loc': variable([1.0, 0.0], requires_grad=True),
+ 'scale': variable([1e-5, 1e-5], requires_grad=True),
},
]),
Example(OneHotCategorical, [
- {'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
- {'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
+ {'probs': variable([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
+ {'probs': variable([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
]),
Example(Pareto, [
{
@@ -225,20 +225,20 @@ EXAMPLES = [
'alpha': 1.0
},
{
- 'scale': Variable(torch.randn(5, 5).abs(), requires_grad=True),
- 'alpha': Variable(torch.randn(5, 5).abs(), requires_grad=True)
+ 'scale': variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'alpha': variable(torch.randn(5, 5).abs(), requires_grad=True)
},
{
- 'scale': torch.tensor([1.0]),
+ 'scale': variable([1.0]),
'alpha': 1.0
}
]),
Example(Poisson, [
{
- 'rate': Variable(torch.randn(5, 5).abs(), requires_grad=True),
+ 'rate': variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
- 'rate': Variable(torch.randn(3).abs(), requires_grad=True),
+ 'rate': variable(torch.randn(3).abs(), requires_grad=True),
},
{
'rate': 0.2,
@@ -246,66 +246,263 @@ EXAMPLES = [
]),
Example(RelaxedBernoulli, [
{
- 'temperature': Variable(torch.Tensor([0.5]), requires_grad=True),
- 'probs': Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True),
+ 'temperature': variable([0.5], requires_grad=True),
+ 'probs': variable([0.7, 0.2, 0.4], requires_grad=True),
},
{
- 'temperature': Variable(torch.Tensor([2.0])),
- 'probs': Variable(torch.Tensor([0.3])),
+ 'temperature': variable([2.0]),
+ 'probs': variable([0.3]),
},
{
- 'temperature': Variable(torch.Tensor([7.2])),
- 'logits': Variable(torch.Tensor([-2.0, 2.0, 1.0, 5.0]))
+ 'temperature': variable([7.2]),
+ 'logits': variable([-2.0, 2.0, 1.0, 5.0])
}
]),
Example(RelaxedOneHotCategorical, [
{
- 'temperature': Variable(torch.Tensor([0.5]), requires_grad=True),
- 'probs': Variable(torch.Tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]]), requires_grad=True)
+ 'temperature': variable([0.5], requires_grad=True),
+ 'probs': variable([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)
},
{
- 'temperature': Variable(torch.Tensor([2.0])),
- 'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]))
+ 'temperature': variable([2.0]),
+ 'probs': variable([[1.0, 0.0], [0.0, 1.0]])
},
{
- 'temperature': Variable(torch.Tensor([7.2])),
- 'logits': Variable(torch.Tensor([[-2.0, 2.0], [1.0, 5.0]]))
+ 'temperature': variable([7.2]),
+ 'logits': variable([[-2.0, 2.0], [1.0, 5.0]])
}
]),
Example(TransformedDistribution, [
{
- 'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
- Variable(torch.randn(2, 3).abs(), requires_grad=True)),
+ 'base_distribution': Normal(variable(torch.randn(2, 3), requires_grad=True),
+ variable(torch.randn(2, 3).abs(), requires_grad=True)),
'transforms': [],
},
{
- 'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
- Variable(torch.randn(2, 3).abs(), requires_grad=True)),
+ 'base_distribution': Normal(variable(torch.randn(2, 3), requires_grad=True),
+ variable(torch.randn(2, 3).abs(), requires_grad=True)),
'transforms': ExpTransform(),
},
{
- 'base_distribution': Normal(Variable(torch.randn(2, 3, 5), requires_grad=True),
- Variable(torch.randn(2, 3, 5).abs(), requires_grad=True)),
- 'transforms': [AffineTransform(Variable(torch.randn(3, 5)), Variable(torch.randn(3, 5))),
+ 'base_distribution': Normal(variable(torch.randn(2, 3, 5), requires_grad=True),
+ variable(torch.randn(2, 3, 5).abs(), requires_grad=True)),
+ 'transforms': [AffineTransform(variable(torch.randn(3, 5)), variable(torch.randn(3, 5))),
ExpTransform()],
},
]),
Example(Uniform, [
{
- 'low': Variable(torch.zeros(5, 5), requires_grad=True),
- 'high': Variable(torch.ones(5, 5), requires_grad=True),
+ 'low': variable(torch.zeros(5, 5), requires_grad=True),
+ 'high': variable(torch.ones(5, 5), requires_grad=True),
},
{
- 'low': Variable(torch.zeros(1), requires_grad=True),
- 'high': Variable(torch.ones(1), requires_grad=True),
+ 'low': variable(torch.zeros(1), requires_grad=True),
+ 'high': variable(torch.ones(1), requires_grad=True),
},
{
- 'low': Variable(torch.Tensor([1.0, 1.0]), requires_grad=True),
- 'high': Variable(torch.Tensor([2.0, 3.0]), requires_grad=True),
+ 'low': variable([1.0, 1.0], requires_grad=True),
+ 'high': variable([2.0, 3.0], requires_grad=True),
},
]),
]
+BAD_EXAMPLES = [
+ Example(Bernoulli, [
+ {'probs': variable([1.1, 0.2, 0.4], requires_grad=True)},
+ {'probs': variable([-0.5], requires_grad=True)},
+ {'probs': 1.00001},
+ ]),
+ Example(Beta, [
+ {
+ 'concentration1': variable([0.0], requires_grad=True),
+ 'concentration0': variable([0.0], requires_grad=True),
+ },
+ {
+ 'concentration1': variable([-1.0], requires_grad=True),
+ 'concentration0': variable([-2.0], requires_grad=True),
+ },
+ ]),
+ Example(Geometric, [
+ {'probs': variable([1.1, 0.2, 0.4], requires_grad=True)},
+ {'probs': variable([-0.3], requires_grad=True)},
+ {'probs': 1.00000001},
+ ]),
+ Example(Categorical, [
+ {'probs': variable([[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
+ {'probs': variable([[-1.0, 10.0], [0.0, -1.0]], requires_grad=True)},
+ ]),
+ Example(Binomial, [
+ {'probs': variable([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
+ 'total_count': 10},
+ {'probs': variable([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
+ 'total_count': 10},
+ ]),
+ Example(Cauchy, [
+ {'loc': 0.0, 'scale': -1.0},
+ {'loc': variable([0.0]), 'scale': 0.0},
+ {'loc': variable([[0.0], [-2.0]]),
+ 'scale': variable([[-0.000001], [1.0]])}
+ ]),
+ Example(Chi2, [
+ {'df': variable([0], requires_grad=True)},
+ {'df': variable([-2], requires_grad=True)},
+ ]),
+ Example(StudentT, [
+ {'df': variable([0], requires_grad=True)},
+ {'df': variable([-2], requires_grad=True)},
+ ]),
+ Example(Dirichlet, [
+ {'concentration': variable([0], requires_grad=True)},
+ {'concentration': variable([-2], requires_grad=True)}
+ ]),
+ Example(Exponential, [
+ {'rate': variable([0, 0], requires_grad=True)},
+ {'rate': variable([-2], requires_grad=True)}
+ ]),
+ Example(FisherSnedecor, [
+ {
+ 'df1': variable([0, 0], requires_grad=True),
+ 'df2': variable([-1, -100], requires_grad=True),
+ },
+ {
+ 'df1': variable([1, 1], requires_grad=True),
+ 'df2': variable([0, 0], requires_grad=True),
+ }
+ ]),
+ Example(Gamma, [
+ {
+ 'concentration': variable([0, 0], requires_grad=True),
+ 'rate': variable([-1, -100], requires_grad=True),
+ },
+ {
+ 'concentration': variable([1, 1], requires_grad=True),
+ 'rate': variable([0, 0], requires_grad=True),
+ }
+ ]),
+ Example(Gumbel, [
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([0, 1], requires_grad=True),
+ },
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([1, -1], requires_grad=True),
+ },
+ ]),
+ Example(Laplace, [
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([0, 1], requires_grad=True),
+ },
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([1, -1], requires_grad=True),
+ },
+ ]),
+ Example(LogNormal, [
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([0, 1], requires_grad=True),
+ },
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([1, -1], requires_grad=True),
+ },
+ ]),
+ Example(Normal, [
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([0, 1], requires_grad=True),
+ },
+ {
+ 'loc': variable([1, 1], requires_grad=True),
+ 'scale': variable([1, -1], requires_grad=True),
+ },
+ {
+ 'loc': variable([1.0, 0.0], requires_grad=True),
+ 'scale': variable([1e-5, -1e-5], requires_grad=True),
+ },
+ ]),
+ Example(OneHotCategorical, [
+ {'probs': variable([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
+ {'probs': variable([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
+ ]),
+ Example(Pareto, [
+ {
+ 'scale': 0.0,
+ 'alpha': 0.0
+ },
+ {
+ 'scale': variable([0.0, 0.0], requires_grad=True),
+ 'alpha': variable([-1e-5, 0.0], requires_grad=True)
+ },
+ {
+ 'scale': variable([1.0]),
+ 'alpha': -1.0
+ }
+ ]),
+ Example(Poisson, [
+ {
+ 'rate': variable([0.0], requires_grad=True),
+ },
+ {
+ 'rate': -1.0,
+ }
+ ]),
+ Example(RelaxedBernoulli, [
+ {
+ 'temperature': variable([1.5], requires_grad=True),
+ 'probs': variable([1.7, 0.2, 0.4], requires_grad=True),
+ },
+ {
+ 'temperature': variable([2.0]),
+ 'probs': variable([-1.0]),
+ }
+ ]),
+ Example(RelaxedOneHotCategorical, [
+ {
+ 'temperature': variable([0.5], requires_grad=True),
+ 'probs': variable([[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)
+ },
+ {
+ 'temperature': variable([2.0]),
+ 'probs': variable([[-1.0, 0.0], [-1.0, 1.1]])
+ }
+ ]),
+ Example(TransformedDistribution, [
+ {
+ 'base_distribution': Normal(variable([1, 1], requires_grad=True),
+ variable([0, 1], requires_grad=True)),
+ 'transforms': [],
+ },
+ {
+ 'base_distribution': Normal(variable([1, 1], requires_grad=True),
+ variable([-1, -1], requires_grad=True)),
+ 'transforms': ExpTransform(),
+ },
+ {
+ 'base_distribution': Normal(variable([1, 1, 0], requires_grad=True),
+ variable([-1, -2, 3], requires_grad=True)),
+ 'transforms': [AffineTransform(variable(torch.randn(3, 5)), variable(torch.randn(3, 5))),
+ ExpTransform()],
+ },
+ ]),
+ Example(Uniform, [
+ {
+ 'low': variable([2.0], requires_grad=True),
+ 'high': variable([2.0], requires_grad=True),
+ },
+ {
+ 'low': variable([0.0], requires_grad=True),
+ 'high': variable([0.0], requires_grad=True),
+ },
+ {
+ 'low': variable([1.0], requires_grad=True),
+ 'high': variable([0.0], requires_grad=True),
+ }
+ ])
+]
+
def unwrap(value):
if isinstance(value, Variable):
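
The BAD_EXAMPLES table added above parallels EXAMPLES, but every entry carries
parameters that violate the distribution's constraints: probabilities outside
[0, 1], non-positive scales, rates and concentrations, and low >= high for
Uniform. TestValidation.test_invalid at the bottom of this file asserts that
each entry makes the constructor raise ValueError. A hedged stand-alone
version of that check for a single case:

    import torch
    from torch.distributions import Normal

    try:
        # scale must be strictly positive; mirrors a BAD_EXAMPLES Normal entry
        Normal(torch.tensor([1., 1.]), torch.tensor([0., 1.]), validate_args=True)
    except ValueError:
        print('constructor rejected the non-positive scale')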
@@ -385,8 +582,8 @@ class TestDistributions(TestCase):
expected = torch.Tensor(expected)
actual = dist(param).enumerate_support()
self.assertEqual(actual, expected)
- param = Variable(param)
- expected = Variable(expected)
+ param = variable(param)
+ expected = variable(expected)
actual = dist(param).enumerate_support()
self.assertEqual(actual, expected)
@@ -468,15 +665,15 @@ class TestDistributions(TestCase):
self._check_enumerate_support(Bernoulli, examples)
def test_bernoulli_3d(self):
- p = Variable(torch.Tensor(2, 3, 5).fill_(0.5), requires_grad=True)
+ p = variable(torch.Tensor(2, 3, 5).fill_(0.5), requires_grad=True)
self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
self.assertEqual(Bernoulli(p).sample(sample_shape=(2, 5)).size(),
(2, 5, 2, 3, 5))
self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
def test_geometric(self):
- p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
- r = torch.tensor(0.3, requires_grad=True)
+ p = variable([0.7, 0.2, 0.4], requires_grad=True)
+ r = variable(0.3, requires_grad=True)
s = 0.3
self.assertEqual(Geometric(p).sample((8,)).size(), (8, 3))
self.assertEqual(Geometric(1).sample(), 0)
@@ -493,7 +690,7 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_geometric_log_prob_and_entropy(self):
- p = Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True)
+ p = variable([0.7, 0.2, 0.4], requires_grad=True)
s = 0.3
def ref_log_prob(idx, val, log_prob):
@@ -516,7 +713,7 @@ class TestDistributions(TestCase):
'Geometric(prob={})'.format(prob))
def test_binomial(self):
- p = Variable(torch.arange(0.05, 1, 0.1), requires_grad=True)
+ p = variable(torch.arange(0.05, 1, 0.1), requires_grad=True)
for total_count in [1, 2, 10]:
self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
self._gradcheck_log_prob(lambda p: Binomial(total_count, None, p.log()), [p])
@@ -525,7 +722,7 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_binomial_log_prob(self):
- probs = Variable(torch.arange(0.05, 1, 0.1))
+ probs = variable(torch.arange(0.05, 1, 0.1))
for total_count in [1, 2, 10]:
def ref_log_prob(idx, x, log_prob):
@@ -550,7 +747,7 @@ class TestDistributions(TestCase):
def test_multinomial_1d(self):
total_count = 10
- p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ p = variable([0.1, 0.2, 0.3], requires_grad=True)
self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
self.assertEqual(Multinomial(total_count, p).sample((1,)).size(), (1, 3))
@@ -561,7 +758,7 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_multinomial_1d_log_prob(self):
total_count = 10
- p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ p = variable([0.1, 0.2, 0.3], requires_grad=True)
dist = Multinomial(total_count, probs=p)
x = dist.sample()
log_prob = dist.log_prob(x)
@@ -578,8 +775,8 @@ class TestDistributions(TestCase):
total_count = 10
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
- p = Variable(torch.Tensor(probabilities), requires_grad=True)
- s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+ p = variable(probabilities, requires_grad=True)
+ s = variable(probabilities_1, requires_grad=True)
self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
self.assertEqual(Multinomial(total_count, p).sample((6,)).size(), (6, 2, 3))
@@ -596,7 +793,7 @@ class TestDistributions(TestCase):
self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)
def test_categorical_1d(self):
- p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ p = variable([0.1, 0.2, 0.3], requires_grad=True)
self.assertTrue(is_all_nan(Categorical(p).mean))
self.assertTrue(is_all_nan(Categorical(p).variance))
self.assertEqual(Categorical(p).sample().size(), ())
@@ -609,8 +806,8 @@ class TestDistributions(TestCase):
def test_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
- p = Variable(torch.Tensor(probabilities), requires_grad=True)
- s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+ p = variable(probabilities, requires_grad=True)
+ s = variable(probabilities_1, requires_grad=True)
self.assertEqual(Categorical(p).mean.size(), (2,))
self.assertEqual(Categorical(p).variance.size(), (2,))
self.assertTrue(is_all_nan(Categorical(p).mean))
@@ -644,7 +841,7 @@ class TestDistributions(TestCase):
self._check_enumerate_support(Categorical, examples)
def test_one_hot_categorical_1d(self):
- p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ p = variable([0.1, 0.2, 0.3], requires_grad=True)
self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
self.assertTrue(isinstance(OneHotCategorical(p).sample().data, torch.Tensor))
self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
@@ -655,8 +852,8 @@ class TestDistributions(TestCase):
def test_one_hot_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
- p = Variable(torch.Tensor(probabilities), requires_grad=True)
- s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+ p = variable(probabilities, requires_grad=True)
+ s = variable(probabilities_1, requires_grad=True)
self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
@@ -674,8 +871,8 @@ class TestDistributions(TestCase):
self._check_enumerate_support(OneHotCategorical, examples)
def test_poisson_shape(self):
- rate = Variable(torch.randn(2, 3).abs(), requires_grad=True)
- rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+ rate = variable(torch.randn(2, 3).abs(), requires_grad=True)
+ rate_1d = variable(torch.randn(1).abs(), requires_grad=True)
self.assertEqual(Poisson(rate).sample().size(), (2, 3))
self.assertEqual(Poisson(rate).sample((7,)).size(), (7, 2, 3))
self.assertEqual(Poisson(rate_1d).sample().size(), (1,))
@@ -684,8 +881,8 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_poisson_log_prob(self):
- rate = Variable(torch.randn(2, 3).abs(), requires_grad=True)
- rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+ rate = variable(torch.randn(2, 3).abs(), requires_grad=True)
+ rate_1d = variable(torch.randn(1).abs(), requires_grad=True)
def ref_log_prob(idx, x, log_prob):
l = rate.data.view(-1)[idx]
@@ -758,7 +955,7 @@ class TestDistributions(TestCase):
self.assertEqual(equal_probs, s)
def test_relaxed_one_hot_categorical_1d(self):
- p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ p = variable([0.1, 0.2, 0.3], requires_grad=True)
temp = torch.tensor(0.67, requires_grad=True)
self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().size(), (3,))
self.assertTrue(isinstance(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().data, torch.Tensor))
@@ -769,10 +966,10 @@ class TestDistributions(TestCase):
def test_relaxed_one_hot_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
- temp = Variable(torch.Tensor([3.00]), requires_grad=True)
- temp_2 = Variable(torch.Tensor([0.2]), requires_grad=True)
- p = Variable(torch.Tensor(probabilities), requires_grad=True)
- s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+ temp = variable([3.00], requires_grad=True)
+ temp_2 = variable([0.2], requires_grad=True)
+ p = variable(probabilities, requires_grad=True)
+ s = variable(probabilities_1, requires_grad=True)
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
self.assertEqual(RelaxedOneHotCategorical(temp, p).sample_n(6).size(), (6, 2, 3))
@@ -814,10 +1011,10 @@ class TestDistributions(TestCase):
self.assertEqual(equal_probs, s)
def test_uniform(self):
- low = Variable(torch.zeros(5, 5), requires_grad=True)
- high = Variable(torch.ones(5, 5) * 3, requires_grad=True)
- low_1d = Variable(torch.zeros(1), requires_grad=True)
- high_1d = Variable(torch.ones(1) * 3, requires_grad=True)
+ low = variable(torch.zeros(5, 5), requires_grad=True)
+ high = variable(torch.ones(5, 5) * 3, requires_grad=True)
+ low_1d = variable(torch.zeros(1), requires_grad=True)
+ high_1d = variable(torch.ones(1) * 3, requires_grad=True)
self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
@@ -826,8 +1023,8 @@ class TestDistributions(TestCase):
# Check log_prob computation when value outside range
uniform = Uniform(low_1d, high_1d)
- above_high = Variable(torch.Tensor([4.0]))
- below_low = Variable(torch.Tensor([-1.0]))
+ above_high = variable([4.0])
+ below_low = variable([-1.0])
self.assertEqual(uniform.log_prob(above_high).item(), -float('inf'), allow_inf=True)
self.assertEqual(uniform.log_prob(below_low).item(), -float('inf'), allow_inf=True)
@@ -847,10 +1044,10 @@ class TestDistributions(TestCase):
high.grad.zero_()
def test_cauchy(self):
- loc = Variable(torch.zeros(5, 5), requires_grad=True)
- scale = Variable(torch.ones(5, 5), requires_grad=True)
- loc_1d = Variable(torch.zeros(1), requires_grad=True)
- scale_1d = Variable(torch.ones(1), requires_grad=True)
+ loc = variable(torch.zeros(5, 5), requires_grad=True)
+ scale = variable(torch.ones(5, 5), requires_grad=True)
+ loc_1d = variable(torch.zeros(1), requires_grad=True)
+ scale_1d = variable(torch.ones(1), requires_grad=True)
self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
self.assertEqual(Cauchy(loc_1d, scale_1d).variance, float('inf'), allow_inf=True)
self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
@@ -875,12 +1072,12 @@ class TestDistributions(TestCase):
scale.grad.zero_()
def test_lognormal(self):
- mean = Variable(torch.randn(5, 5), requires_grad=True)
- std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
- mean_1d = Variable(torch.randn(1), requires_grad=True)
- std_1d = Variable(torch.randn(1), requires_grad=True)
- mean_delta = torch.Tensor([1.0, 0.0])
- std_delta = torch.Tensor([1e-5, 1e-5])
+ mean = variable(torch.randn(5, 5), requires_grad=True)
+ std = variable(torch.randn(5, 5).abs(), requires_grad=True)
+ mean_1d = variable(torch.randn(1), requires_grad=True)
+ std_1d = variable(torch.randn(1), requires_grad=True)
+ mean_delta = variable([1.0, 0.0])
+ std_delta = variable([1e-5, 1e-5])
self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
self.assertEqual(LogNormal(mean, std).sample((7,)).size(), (7, 5, 5))
self.assertEqual(LogNormal(mean_1d, std_1d).sample((1,)).size(), (1, 1))
@@ -900,8 +1097,8 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lognormal_logprob(self):
- mean = Variable(torch.randn(5, 1), requires_grad=True)
- std = Variable(torch.randn(5, 1).abs(), requires_grad=True)
+ mean = variable(torch.randn(5, 1), requires_grad=True)
+ std = variable(torch.randn(5, 1).abs(), requires_grad=True)
def ref_log_prob(idx, x, log_prob):
m = mean.data.view(-1)[idx]
@@ -920,12 +1117,12 @@ class TestDistributions(TestCase):
'LogNormal(loc={}, scale={})'.format(mean, std))
def test_normal(self):
- loc = Variable(torch.randn(5, 5), requires_grad=True)
- scale = Variable(torch.randn(5, 5).abs(), requires_grad=True)
- loc_1d = Variable(torch.randn(1), requires_grad=True)
- scale_1d = Variable(torch.randn(1), requires_grad=True)
- loc_delta = torch.Tensor([1.0, 0.0])
- scale_delta = torch.Tensor([1e-5, 1e-5])
+ loc = variable(torch.randn(5, 5), requires_grad=True)
+ scale = variable(torch.randn(5, 5).abs(), requires_grad=True)
+ loc_1d = variable(torch.randn(1), requires_grad=True)
+ scale_1d = variable(torch.randn(1), requires_grad=True)
+ loc_delta = variable([1.0, 0.0])
+ scale_delta = variable([1e-5, 1e-5])
self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
@@ -972,8 +1169,8 @@ class TestDistributions(TestCase):
'Normal(mean={}, std={})'.format(loc, scale))
def test_exponential(self):
- rate = Variable(torch.randn(5, 5).abs(), requires_grad=True)
- rate_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+ rate = variable(torch.randn(5, 5).abs(), requires_grad=True)
+ rate_1d = variable(torch.randn(1).abs(), requires_grad=True)
self.assertEqual(Exponential(rate).sample().size(), (5, 5))
self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
@@ -1007,10 +1204,10 @@ class TestDistributions(TestCase):
'Exponential(rate={})'.format(rate))
def test_laplace(self):
- loc = Variable(torch.randn(5, 5), requires_grad=True)
- scale = Variable(torch.randn(5, 5).abs(), requires_grad=True)
- loc_1d = Variable(torch.randn(1), requires_grad=True)
- scale_1d = Variable(torch.randn(1), requires_grad=True)
+ loc = variable(torch.randn(5, 5), requires_grad=True)
+ scale = variable(torch.randn(5, 5).abs(), requires_grad=True)
+ loc_1d = variable(torch.randn(1), requires_grad=True)
+ scale_1d = variable(torch.randn(1), requires_grad=True)
loc_delta = torch.Tensor([1.0, 0.0])
scale_delta = torch.Tensor([1e-5, 1e-5])
self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
@@ -1059,10 +1256,10 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gamma_shape(self):
- alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
- beta = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
- alpha_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
- beta_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
+ alpha = variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+ beta = variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+ alpha_1d = variable(torch.exp(torch.randn(1)), requires_grad=True)
+ beta_1d = variable(torch.exp(torch.randn(1)), requires_grad=True)
self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
@@ -1088,10 +1285,10 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_pareto(self):
- scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
- alpha = Variable(torch.randn(2, 3).abs(), requires_grad=True)
- scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
- alpha_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+ scale = variable(torch.randn(2, 3).abs(), requires_grad=True)
+ alpha = variable(torch.randn(2, 3).abs(), requires_grad=True)
+ scale_1d = variable(torch.randn(1).abs(), requires_grad=True)
+ alpha_1d = variable(torch.randn(1).abs(), requires_grad=True)
self.assertEqual(Pareto(scale_1d, 0.5).mean, float('inf'), allow_inf=True)
self.assertEqual(Pareto(scale_1d, 0.5).variance, float('inf'), allow_inf=True)
self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
@@ -1119,10 +1316,10 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_gumbel(self):
- loc = Variable(torch.randn(2, 3), requires_grad=True)
- scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
- loc_1d = Variable(torch.randn(1), requires_grad=True)
- scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+ loc = variable(torch.randn(2, 3), requires_grad=True)
+ scale = variable(torch.randn(2, 3).abs(), requires_grad=True)
+ loc_1d = variable(torch.randn(1), requires_grad=True)
+ scale_1d = variable(torch.randn(1).abs(), requires_grad=True)
self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
@@ -1148,8 +1345,8 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_fishersnedecor(self):
- df1 = Variable(torch.randn(2, 3).abs(), requires_grad=True)
- df2 = Variable(torch.randn(2, 3).abs(), requires_grad=True)
+ df1 = variable(torch.randn(2, 3).abs(), requires_grad=True)
+ df2 = variable(torch.randn(2, 3).abs(), requires_grad=True)
df1_1d = torch.randn(1).abs()
df2_1d = torch.randn(1).abs()
self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))
@@ -1179,8 +1376,8 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_chi2_shape(self):
- df = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
- df_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
+ df = variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+ df_1d = variable(torch.exp(torch.randn(1)), requires_grad=True)
self.assertEqual(Chi2(df).sample().size(), (2, 3))
self.assertEqual(Chi2(df).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Chi2(df_1d).sample((1,)).size(), (1, 1))
@@ -1206,8 +1403,8 @@ class TestDistributions(TestCase):
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_studentT(self):
- df = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
- df_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
+ df = variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+ df_1d = variable(torch.exp(torch.randn(1)), requires_grad=True)
self.assertTrue(is_all_nan(StudentT(1).mean))
self.assertTrue(is_all_nan(StudentT(1).variance))
self.assertEqual(StudentT(2).variance, float('inf'), allow_inf=True)
@@ -1247,8 +1444,8 @@ class TestDistributions(TestCase):
self.assertAlmostEqual(float(actual_log_prob[i]), float(expected_log_prob), places=3)
def test_dirichlet_shape(self):
- alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
- alpha_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
+ alpha = variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+ alpha_1d = variable(torch.exp(torch.randn(4)), requires_grad=True)
self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
@@ -1275,10 +1472,10 @@ class TestDistributions(TestCase):
multivariate=True)
def test_beta_shape(self):
- con1 = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
- con0 = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
- con1_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
- con0_1d = Variable(torch.exp(torch.randn(4)), requires_grad=True)
+ con1 = variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+ con0 = variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
+ con1_1d = variable(torch.exp(torch.randn(4)), requires_grad=True)
+ con0_1d = variable(torch.exp(torch.randn(4)), requires_grad=True)
self.assertEqual(Beta(con1, con0).sample().size(), (2, 3))
self.assertEqual(Beta(con1, con0).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Beta(con1_1d, con0_1d).sample().size(), (4,))
@@ -1333,7 +1530,7 @@ class TestDistributions(TestCase):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
dist = Dist(**param)
- samples = Variable(dist.sample().data, requires_grad=True)
+ samples = variable(dist.sample().data, requires_grad=True)
try:
cdfs = dist.cdf(samples)
pdfs = dist.log_prob(samples).exp()
@@ -1502,8 +1699,8 @@ class TestRsample(TestCase):
def test_gamma(self):
num_samples = 100
for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
- alphas = Variable(torch.FloatTensor([alpha] * num_samples), requires_grad=True)
- betas = Variable(torch.ones(num_samples).type_as(alphas))
+ alphas = variable(torch.FloatTensor([alpha] * num_samples), requires_grad=True)
+ betas = variable(torch.ones(num_samples).type_as(alphas))
x = Gamma(alphas, betas).rsample()
x.sum().backward()
x, ind = x.data.sort()
@@ -1531,7 +1728,7 @@ class TestRsample(TestCase):
def test_chi2(self):
num_samples = 100
for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
- dfs = Variable(torch.FloatTensor([df] * num_samples), requires_grad=True)
+ dfs = variable(torch.FloatTensor([df] * num_samples), requires_grad=True)
x = Chi2(dfs).rsample()
x.sum().backward()
x, ind = x.data.sort()
@@ -1559,7 +1756,7 @@ class TestRsample(TestCase):
num_samples = 20
grid = [1e-1, 1e0, 1e1]
for a0, a1, a2 in product(grid, grid, grid):
- alphas = Variable(torch.FloatTensor([[a0, a1, a2]] * num_samples), requires_grad=True)
+ alphas = variable(torch.FloatTensor([[a0, a1, a2]] * num_samples), requires_grad=True)
x = Dirichlet(alphas).rsample()[:, 0]
x.sum().backward()
x, ind = x.data.sort()
@@ -1590,8 +1787,8 @@ class TestRsample(TestCase):
num_samples = 20
grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
for con1, con0 in product(grid, grid):
- con1s = Variable(torch.FloatTensor([con1] * num_samples), requires_grad=True)
- con0s = Variable(torch.FloatTensor([con0] * num_samples).type_as(con1s))
+ con1s = variable(torch.FloatTensor([con1] * num_samples), requires_grad=True)
+ con0s = variable(torch.FloatTensor([con0] * num_samples).type_as(con1s))
x = Beta(con1s, con0s).rsample()
x.sum().backward()
x, ind = x.data.sort()
@@ -1620,8 +1817,8 @@ class TestRsample(TestCase):
num_samples = 20
grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
for con1, con0 in product(grid, grid):
- con0s = Variable(torch.FloatTensor([con0] * num_samples), requires_grad=True)
- con1s = Variable(torch.FloatTensor([con1] * num_samples).type_as(con0s))
+ con0s = variable(torch.FloatTensor([con0] * num_samples), requires_grad=True)
+ con1s = variable(torch.FloatTensor([con1] * num_samples).type_as(con0s))
x = Beta(con1s, con0s).rsample()
x.sum().backward()
x, ind = x.data.sort()
@@ -1650,7 +1847,7 @@ class TestRsample(TestCase):
num_samples = 100000
for shift in [-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10]:
alpha = alpha_crit + shift
- alpha = Variable(torch.FloatTensor([alpha]), requires_grad=True)
+ alpha = variable(torch.FloatTensor([alpha]), requires_grad=True)
alpha_vec = torch.cat([alpha, alpha, alpha.new([1])])
z = Dirichlet(alpha_vec.expand(num_samples, 3)).rsample()
mean_z3 = 1.0 / (2.0 * alpha + 1.0)
@@ -1679,7 +1876,7 @@ class TestRsample(TestCase):
], dim=-1)
for a1, a2, a3 in product(alpha_grid, alpha_grid, alpha_grid):
- alpha = Variable(torch.Tensor([a1, a2, a3]).expand(num_samples, 3), requires_grad=True)
+ alpha = variable([a1, a2, a3], requires_grad=True).expand(num_samples, 3)
x = Dirichlet(alpha).rsample()
dlogp_da = grad([Dirichlet(alpha).log_prob(x.detach()).sum()],
[alpha], retain_graph=True)[0].data[:, 0]
@@ -1711,13 +1908,18 @@ class TestDistributionShapes(TestCase):
def setUp(self):
super(TestCase, self).setUp()
self.scalar_sample = 1
- self.tensor_sample_1 = Variable(torch.ones(3, 2))
- self.tensor_sample_2 = Variable(torch.ones(3, 2, 3))
+ self.tensor_sample_1 = variable(torch.ones(3, 2))
+ self.tensor_sample_2 = variable(torch.ones(3, 2, 3))
+ Distribution.set_default_validate_args(True)
+
+ def tearDown(self):
+ super(TestCase, self).tearDown()
+ Distribution.set_default_validate_args(False)
def test_entropy_shape(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
- dist = Dist(**param)
+ dist = Dist(validate_args=False, **param)
try:
actual_shape = dist.entropy().size()
expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
@@ -1745,7 +1947,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
- self.assertEqual(bernoulli.log_prob(Variable(torch.ones(3, 1, 1))).size(), torch.Size((3, 3, 2)))
+ self.assertEqual(bernoulli.log_prob(variable(torch.ones(3, 1, 1))).size(), torch.Size((3, 3, 2)))
def test_geometric_shape_scalar_params(self):
geometric = Geometric(0.3)
@@ -1765,7 +1967,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, geometric.log_prob, self.tensor_sample_2)
- self.assertEqual(geometric.log_prob(Variable(torch.ones(3, 1, 1))).size(), torch.Size((3, 3, 2)))
+ self.assertEqual(geometric.log_prob(variable(torch.ones(3, 1, 1))).size(), torch.Size((3, 3, 2)))
def test_beta_shape_scalar_params(self):
dist = Beta(0.1, 0.1)
@@ -1786,7 +1988,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
- self.assertEqual(dist.log_prob(Variable(torch.ones(3, 1, 1))).size(), torch.Size((3, 3, 2)))
+ self.assertEqual(dist.log_prob(variable(torch.ones(3, 1, 1))).size(), torch.Size((3, 3, 2)))
def test_binomial_shape(self):
dist = Binomial(10, torch.tensor([0.6, 0.3]))
@@ -1805,7 +2007,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
- self.assertEqual(dist.log_prob(Variable(torch.ones(3, 1, 2))).size(), torch.Size((3, 3)))
+ self.assertEqual(dist.log_prob(variable(torch.ones(3, 1, 2))).size(), torch.Size((3, 3)))
def test_categorical_shape(self):
# unbatched
@@ -1816,7 +2018,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2,)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
- self.assertEqual(dist.log_prob(Variable(torch.ones(3, 1))).size(), torch.Size((3, 1)))
+ self.assertEqual(dist.log_prob(variable(torch.ones(3, 1))).size(), torch.Size((3, 1)))
# batched
dist = Categorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -1825,7 +2027,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
- self.assertEqual(dist.log_prob(Variable(torch.ones(3, 1))).size(), torch.Size((3, 3)))
+ self.assertEqual(dist.log_prob(variable(torch.ones(3, 1))).size(), torch.Size((3, 3)))
def test_one_hot_categorical_shape(self):
# unbatched
@@ -1835,19 +2037,23 @@ class TestDistributionShapes(TestCase):
self.assertEqual(dist.sample().size(), torch.Size((3,)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
- self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2,)))
+ simplex_sample = self.tensor_sample_2 / self.tensor_sample_2.sum(-1, keepdim=True)
+ self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 2,)))
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
- self.assertEqual(dist.log_prob(Variable(torch.ones(3, 3))).size(), torch.Size((3,)))
+ simplex_sample = variable(torch.ones(3, 3)) / 3
+ self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
# batched
dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
self.assertEqual(dist._event_shape, torch.Size((2,)))
self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
- self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
+ simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
+ self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
- self.assertEqual(dist.log_prob(Variable(torch.ones((3, 1, 2)))).size(), torch.Size((3, 3)))
+ simplex_sample = variable(torch.ones(3, 1, 2)) / 2
+ self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))
def test_cauchy_shape_scalar_params(self):
cauchy = Cauchy(0, 1)
@@ -1867,7 +2073,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
self.assertEqual(cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
- self.assertEqual(cauchy.log_prob(Variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
+ self.assertEqual(cauchy.log_prob(variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
def test_dirichlet_shape(self):
dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
@@ -1875,9 +2081,12 @@ class TestDistributionShapes(TestCase):
self.assertEqual(dist._event_shape, torch.Size((2,)))
self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
- self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
+ simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
+ self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
- self.assertEqual(dist.log_prob(Variable(torch.ones((3, 1, 2)))).size(), torch.Size((3, 3)))
+ simplex_sample = torch.ones((3, 1, 2))
+ simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
+ self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))
def test_gamma_shape_scalar_params(self):
gamma = Gamma(1, 1)
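
The OneHotCategorical and Dirichlet shape tests above now project their probe
tensors onto the probability simplex before calling log_prob: with validation
switched on in setUp, values outside the support are rejected rather than
scored. The normalization pattern in isolation (a sketch, not part of the
diff):

    import torch

    sample = torch.ones(3, 2)
    simplex_sample = sample / sample.sum(-1, keepdim=True)  # each row sums to 1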
@@ -1897,7 +2106,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
- self.assertEqual(gamma.log_prob(Variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
+ self.assertEqual(gamma.log_prob(variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
def test_chi2_shape_scalar_params(self):
chi2 = Chi2(1)
@@ -1917,7 +2126,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2, 2)))
self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
- self.assertEqual(chi2.log_prob(Variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
+ self.assertEqual(chi2.log_prob(variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
def test_studentT_shape_scalar_params(self):
st = StudentT(1)
@@ -1937,7 +2146,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
- self.assertEqual(st.log_prob(Variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
+ self.assertEqual(st.log_prob(variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
def test_pareto_shape_scalar_params(self):
pareto = Pareto(1, 1)
@@ -1945,9 +2154,8 @@ class TestDistributionShapes(TestCase):
self.assertEqual(pareto._event_shape, torch.Size())
self.assertEqual(pareto.sample().size(), torch.Size())
self.assertEqual(pareto.sample((3, 2)).size(), torch.Size((3, 2)))
- self.assertRaises(ValueError, pareto.log_prob, self.scalar_sample)
- self.assertEqual(pareto.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
- self.assertEqual(pareto.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+ self.assertEqual(pareto.log_prob(self.tensor_sample_1 + 1).size(), torch.Size((3, 2)))
+ self.assertEqual(pareto.log_prob(self.tensor_sample_2 + 1).size(), torch.Size((3, 2, 3)))
def test_gumbel_shape_scalar_params(self):
gumbel = Gumbel(1, 1)
@@ -1955,7 +2163,6 @@ class TestDistributionShapes(TestCase):
self.assertEqual(gumbel._event_shape, torch.Size())
self.assertEqual(gumbel.sample().size(), torch.Size())
self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
- self.assertRaises(ValueError, gumbel.log_prob, self.scalar_sample)
self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
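
In the Pareto shape test above, the probe tensors are shifted by +1,
presumably so they land strictly inside the support (x >= scale with
scale = 1) now that sample values can be validated; the scalar-sample
ValueError assertions are dropped for both Pareto and Gumbel. A sketch of the
shifted probe, under that same assumption:

    import torch
    from torch.distributions import Pareto

    pareto = Pareto(1, 1)
    value = torch.ones(3, 2) + 1          # >= scale, inside the support
    print(pareto.log_prob(value).size())  # torch.Size([3, 2])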
@@ -1977,7 +2184,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2, 2)))
self.assertEqual(normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, normal.log_prob, self.tensor_sample_2)
- self.assertEqual(normal.log_prob(Variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
+ self.assertEqual(normal.log_prob(variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
def test_uniform_shape_scalar_params(self):
uniform = Uniform(0, 1)
@@ -1997,7 +2204,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
self.assertEqual(uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, uniform.log_prob, self.tensor_sample_2)
- self.assertEqual(uniform.log_prob(Variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
+ self.assertEqual(uniform.log_prob(variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
def test_exponential_shape_scalar_param(self):
expon = Exponential(1.)
@@ -2017,7 +2224,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2, 2)))
self.assertEqual(expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, expon.log_prob, self.tensor_sample_2)
- self.assertEqual(expon.log_prob(Variable(torch.ones(2, 2))).size(), torch.Size((2, 2)))
+ self.assertEqual(expon.log_prob(variable(torch.ones(2, 2))).size(), torch.Size((2, 2)))
def test_laplace_shape_scalar_params(self):
laplace = Laplace(0, 1)
@@ -2037,7 +2244,7 @@ class TestDistributionShapes(TestCase):
self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
- self.assertEqual(laplace.log_prob(Variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
+ self.assertEqual(laplace.log_prob(variable(torch.ones(2, 1))).size(), torch.Size((2, 2)))
class TestKL(TestCase):
@@ -2312,8 +2519,10 @@ class TestConstraints(TestCase):
constraint = dist.params[name]
except KeyError:
continue # ignore optional parameters
+
if is_dependent(constraint):
continue
+
message = '{} example {}/{} parameter {} = {}'.format(
Dist.__name__, i + 1, len(params), name, value)
self.assertTrue(constraint.check(value).all(), msg=message)
@@ -2417,35 +2626,35 @@ class TestNumericalStability(TestCase):
def test_categorical_log_prob(self):
for tensor_type in ([torch.FloatTensor, torch.DoubleTensor]):
- p = Variable(tensor_type([0, 1]), requires_grad=True)
+ p = variable(tensor_type([0, 1]), requires_grad=True)
categorical = OneHotCategorical(p)
- log_pdf = categorical.log_prob(Variable(tensor_type([0, 1])))
+ log_pdf = categorical.log_prob(variable(tensor_type([0, 1])))
self.assertEqual(log_pdf.item(), 0)
def test_categorical_log_prob_with_logits(self):
for tensor_type in ([torch.FloatTensor, torch.DoubleTensor]):
- p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
+ p = variable(tensor_type([-float('inf'), 0]), requires_grad=True)
categorical = OneHotCategorical(logits=p)
- log_pdf_prob_1 = categorical.log_prob(Variable(tensor_type([0, 1])))
+ log_pdf_prob_1 = categorical.log_prob(variable(tensor_type([0, 1])))
self.assertEqual(log_pdf_prob_1.item(), 0)
- log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
+ log_pdf_prob_0 = categorical.log_prob(variable(tensor_type([1, 0])))
self.assertEqual(log_pdf_prob_0.item(), -float('inf'), allow_inf=True)
def test_multinomial_log_prob(self):
for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
- p = Variable(tensor_type([0, 1]), requires_grad=True)
- s = Variable(tensor_type([0, 10]))
+ p = variable(tensor_type([0, 1]), requires_grad=True)
+ s = variable(tensor_type([0, 10]))
multinomial = Multinomial(10, p)
log_pdf = multinomial.log_prob(s)
self.assertEqual(log_pdf.item(), 0)
def test_multinomial_log_prob_with_logits(self):
for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
- p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
+ p = variable(tensor_type([-float('inf'), 0]), requires_grad=True)
multinomial = Multinomial(10, logits=p)
- log_pdf_prob_1 = multinomial.log_prob(Variable(tensor_type([0, 10])))
+ log_pdf_prob_1 = multinomial.log_prob(variable(tensor_type([0, 10])))
self.assertEqual(log_pdf_prob_1.item(), 0)
- log_pdf_prob_0 = multinomial.log_prob(Variable(tensor_type([10, 0])))
+ log_pdf_prob_0 = multinomial.log_prob(variable(tensor_type([10, 0])))
self.assertEqual(log_pdf_prob_0.item(), -float('inf'), allow_inf=True)
@@ -2462,7 +2671,7 @@ class TestLazyLogitsInitialization(TestCase):
param['logits'] = probs_to_logits(probs)
dist = Dist(**param)
shape = (1,) if not dist.event_shape else dist.event_shape
- dist.log_prob(Variable(torch.ones(shape)))
+ dist.log_prob(variable(torch.ones(shape)))
message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
self.assertFalse('probs' in vars(dist), msg=message)
try:
@@ -2493,9 +2702,9 @@ class TestLazyLogitsInitialization(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
class TestAgainstScipy(TestCase):
def setUp(self):
- positive_var = Variable(torch.Tensor(20,).normal_()).exp()
- positive_var2 = Variable(torch.Tensor(20,).normal_()).exp()
- random_var = Variable(torch.Tensor(20,).normal_())
+ positive_var = variable(torch.Tensor(20,).normal_()).exp()
+ positive_var2 = variable(torch.Tensor(20,).normal_()).exp()
+ random_var = variable(torch.Tensor(20,).normal_())
random_tensor = torch.Tensor(20,).normal_()
simplex_tensor = softmax(random_tensor)
self.distribution_pairs = [
@@ -2606,7 +2815,7 @@ class TestAgainstScipy(TestCase):
def test_icdf(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
- samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
+ samples = variable(torch.rand((5,) + pytorch_dist.batch_shape))
try:
icdf = pytorch_dist.icdf(samples)
except NotImplementedError:
@@ -2623,23 +2832,23 @@ class TestTransforms(TestCase):
AbsTransform(cache_size=cache_size),
ExpTransform(cache_size=cache_size),
SigmoidTransform(cache_size=cache_size),
- AffineTransform(Variable(torch.Tensor(5).normal_()),
- Variable(torch.Tensor(5).normal_()),
+ AffineTransform(variable(torch.Tensor(5).normal_()),
+ variable(torch.Tensor(5).normal_()),
cache_size=cache_size),
- AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
- Variable(torch.Tensor(4, 5).normal_()),
+ AffineTransform(variable(torch.Tensor(4, 5).normal_()),
+ variable(torch.Tensor(4, 5).normal_()),
cache_size=cache_size),
BoltzmannTransform(cache_size=cache_size),
StickBreakingTransform(cache_size=cache_size),
LowerCholeskyTransform(cache_size=cache_size),
ComposeTransform([
- AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
- Variable(torch.Tensor(4, 5).normal_()),
+ AffineTransform(variable(torch.Tensor(4, 5).normal_()),
+ variable(torch.Tensor(4, 5).normal_()),
cache_size=cache_size),
]),
ComposeTransform([
- AffineTransform(Variable(torch.Tensor(4, 5).normal_()),
- Variable(torch.Tensor(4, 5).normal_()),
+ AffineTransform(variable(torch.Tensor(4, 5).normal_()),
+ variable(torch.Tensor(4, 5).normal_()),
cache_size=cache_size),
ExpTransform(cache_size=cache_size),
]),
@@ -2690,7 +2899,7 @@ class TestTransforms(TestCase):
def test_forward_inverse_cache(self):
for transform in self.transforms:
- x = Variable(self._generate_data(transform), requires_grad=True)
+ x = variable(self._generate_data(transform), requires_grad=True)
try:
y = transform(x)
except NotImplementedError:
@@ -2717,7 +2926,7 @@ class TestTransforms(TestCase):
def test_forward_inverse_no_cache(self):
for transform in self.transforms:
- x = Variable(self._generate_data(transform), requires_grad=True)
+ x = variable(self._generate_data(transform), requires_grad=True)
try:
y = transform(x)
x2 = transform.inv(y.clone()) # bypass cache
@@ -2744,7 +2953,7 @@ class TestTransforms(TestCase):
def test_univariate_forward_jacobian(self):
for transform in self.transforms:
- x = Variable(self._generate_data(transform), requires_grad=True)
+ x = variable(self._generate_data(transform), requires_grad=True)
try:
y = transform(x)
actual = transform.log_abs_det_jacobian(x, y)
@@ -2759,7 +2968,7 @@ class TestTransforms(TestCase):
def test_univariate_inverse_jacobian(self):
for transform in self.transforms:
- y = Variable(self._generate_data(transform.inv), requires_grad=True)
+ y = variable(self._generate_data(transform.inv), requires_grad=True)
try:
x = transform.inv(y)
actual = transform.log_abs_det_jacobian(x, y)
@@ -2788,8 +2997,8 @@ class TestTransforms(TestCase):
transform0 = ExpTransform()
transform1 = BoltzmannTransform()
transform2 = LowerCholeskyTransform()
- base_dist0 = Normal(Variable(torch.zeros(4, 4)), Variable(torch.ones(4, 4)))
- base_dist1 = Dirichlet(Variable(torch.ones(4, 4)))
+ base_dist0 = Normal(variable(torch.zeros(4, 4)), variable(torch.ones(4, 4)))
+ base_dist1 = Dirichlet(variable(torch.ones(4, 4)))
examples = [
((4, 4), (), base_dist0),
((4,), (4,), base_dist1),
@@ -2842,7 +3051,7 @@ class TestConstraintRegistry(TestCase):
except NotImplementedError:
continue
self.assertTrue(t.bijective, "biject_to({}) is not bijective".format(constraint))
- x = Variable(torch.Tensor(5, 5)).normal_()
+ x = variable(torch.Tensor(5, 5)).normal_()
y = t(x)
self.assertTrue(constraint.check(y).all(), '\n'.join([
"Failed to biject_to({})".format(constraint),
@@ -2855,7 +3064,7 @@ class TestConstraintRegistry(TestCase):
def test_transform_to(self):
for constraint in self.constraints:
t = transform_to(constraint)
- x = Variable(torch.Tensor(5, 5)).normal_()
+ x = variable(torch.Tensor(5, 5)).normal_()
y = t(x)
self.assertTrue(constraint.check(y).all(), "Failed to transform_to({})".format(constraint))
x2 = t.inv(y)
@@ -2863,5 +3072,33 @@ class TestConstraintRegistry(TestCase):
self.assertEqual(y, y2, message="Error in transform_to({}) pseudoinverse".format(constraint))
+class TestValidation(TestCase):
+ def setUp(self):
+ super(TestCase, self).setUp()
+ Distribution.set_default_validate_args(True)
+
+ def test_valid(self):
+ for Dist, params in EXAMPLES:
+ if constraints.is_dependent(Dist.params): # skipping transformed dist
+ continue
+ for i, param in enumerate(params):
+ Dist(validate_args=True, **param)
+
+ def test_invalid(self):
+ for Dist, params in BAD_EXAMPLES:
+ if constraints.is_dependent(Dist.params): # skipping transformed dist
+ continue
+ for i, param in enumerate(params):
+ try:
+ with self.assertRaises(ValueError):
+ Dist(validate_args=True, **param)
+ except AssertionError:
+ fail_string = 'ValueError not raised for {} example {}/{}'
+ raise AssertionError(fail_string.format(Dist.__name__, i + 1, len(params)))
+
+ def tearDown(self):
+ super(TestCase, self).tearDown()
+ Distribution.set_default_validate_args(False)
+
if __name__ == '__main__':
run_tests()
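
The new TestValidation suite flips global validation on in setUp and back off
in tearDown, so the other suites keep running with checks disabled. The same
pattern outside the test harness, sketched using only the toggle shown in
this diff:

    import torch
    from torch.distributions import Distribution, Uniform

    Distribution.set_default_validate_args(True)
    try:
        Uniform(torch.tensor([1.0]), torch.tensor([0.0]))  # low >= high: rejected
    except ValueError:
        pass
    finally:
        Distribution.set_default_validate_args(False)  # restore the default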