diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/common_nn.py | 2 | ||||
-rw-r--r-- | test/test_autograd.py | 8 |
2 files changed, 7 insertions, 3 deletions
diff --git a/test/common_nn.py b/test/common_nn.py index cae61dcb9d..e6b6469800 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -324,6 +324,7 @@ criterion_tests = [ module_name='NLLLoss2d', input_size=(2, 3, 5, 5), target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(), + check_no_size_average=True, ), dict( module_name='NLLLoss2d', @@ -356,6 +357,7 @@ criterion_tests = [ module_name='MultiLabelMarginLoss', input_size=(5, 10), target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(), + check_no_size_average=True, check_gradgrad=False, ), dict( diff --git a/test/test_autograd.py b/test/test_autograd.py index ea69d45f2a..1c338b0112 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -1785,11 +1785,11 @@ method_tests = [ ('fmod', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'), ('fmod', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'), ('fmod', (S, S, S), (Variable(torch.rand(S) + 1.5, requires_grad=False),), 'tensor_broadcast_rhs'), - ('fmod', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broacast_all'), + ('fmod', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'), ('remainder', (S, S, S), (1.5,)), ('remainder', (S, S, S), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor'), ('remainder', (S,), (Variable(torch.rand(S, S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_lhs'), - ('remainder', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broacast_all'), + ('remainder', (S, 1, S), (Variable(torch.rand(S, S) + 1.5, requires_grad=False),), 'tensor_broadcast_all'), ('lerp', (S, S, S), ((S, S, S), 0.4)), ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'), ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'), @@ -2212,7 +2212,9 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, for test in method_tests: name, self_size, args = test[:3] - basic_test_name = 'test_' + name + ('_' + test[3] if len(test) >= 4 else '') + basic_test_name = 'test_' + name + if len(test) >= 4 and test[3] != '': + basic_test_name += '_' + test[3] dim_args_idx = test[4] if len(test) == 5 else [] |