summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/common_nn.py2
-rw-r--r--test/test_autograd.py8
2 files changed, 7 insertions, 3 deletions
diff --git a/test/common_nn.py b/test/common_nn.py
index cae61dcb9d..e6b6469800 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -324,6 +324,7 @@ criterion_tests = [
module_name='NLLLoss2d',
input_size=(2, 3, 5, 5),
target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
+ check_no_size_average=True,
),
dict(
module_name='NLLLoss2d',
@@ -356,6 +357,7 @@ criterion_tests = [
module_name='MultiLabelMarginLoss',
input_size=(5, 10),
target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
+ check_no_size_average=True,
check_gradgrad=False,
),
dict(
diff --git a/test/test_autograd.py b/test/test_autograd.py
index ea69d45f2a..1c338b0112 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1785,11 +1785,11 @@ method_tests = [
('fmod', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'),
('fmod', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'),
('fmod', (S, S, S), (Variable(torch.rand(S) + 1.5, requires_grad=False),), 'tensor_broadcast_rhs'),
- ('fmod', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broacast_all'),
+ ('fmod', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'),
('remainder', (S, S, S), (1.5,)),
('remainder', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'),
('remainder', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'),
- ('remainder', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broacast_all'),
+ ('remainder', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'),
('lerp', (S, S, S), ((S, S, S), 0.4)),
('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'),
('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'),
@@ -2212,7 +2212,9 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
for test in method_tests:
name, self_size, args = test[:3]
- basic_test_name = 'test_' + name + ('_' + test[3] if len(test) >= 4 else '')
+ basic_test_name = 'test_' + name
+ if len(test) >= 4 and test[3] != '':
+ basic_test_name += '_' + test[3]
dim_args_idx = test[4] if len(test) == 5 else []