summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorEdward Yang <ezyang@fb.com>2019-02-04 19:23:57 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-04 19:30:04 -0800
commit6c04224cd84385910021580176413a98839cec72 (patch)
tree60739d235527165884dad0417f24efc16f253d14 /test
parent1409a2afc8d81d911dbcfc2d0b210cf62658237b (diff)
downloadpytorch-6c04224cd84385910021580176413a98839cec72.tar.gz
pytorch-6c04224cd84385910021580176413a98839cec72.tar.bz2
pytorch-6c04224cd84385910021580176413a98839cec72.zip
Revert "Move outplace ops to ATen (#12413)" (#16731)
Summary: This reverts commit f660d3ae19decc64390e894fbaf8de80d87585e0. cc zasdfgbnm Reasoning at https://github.com/pytorch/pytorch/pull/12413#issuecomment-460424129 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16731 Differential Revision: D13948022 Pulled By: ezyang fbshipit-source-id: b10669cf03679e306850314b7b5b08bed0839e19
Diffstat (limited to 'test')
-rw-r--r--test/common_methods_invocations.py4
-rw-r--r--test/test_jit.py2
-rw-r--r--test/test_torch.py12
3 files changed, 6 insertions, 12 deletions
diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index b998c6fef6..41c8763e16 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -744,7 +744,7 @@ def method_tests():
('masked_select', (), (mask_not_all_zeros((M, M)),), 'scalar_broadcast_lhs'),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)),
('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), torch.tensor(10)), 'tensor'),
- ('masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),
+ # no lhs or all broadcast on masked_fill or masked_scatter because it's always inplace
('masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10), 'broadcast_rhs'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10), 'scalar'),
('masked_fill', (), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), torch.tensor(10)),
@@ -752,8 +752,6 @@ def method_tests():
('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8, requires_grad=False).bernoulli_(), 10),
'scalar_broadcast_rhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))),
- ('masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)),
- 'broadcast_lhs'),
('masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)),
'broadcast_rhs'),
('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)), 'scalar'),
diff --git a/test/test_jit.py b/test/test_jit.py
index 42ec3d34c3..b0ef8326da 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8834,7 +8834,7 @@ a")
def test_builtin_error_messsage(self):
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
- with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+ with self.assertRaisesRegex(RuntimeError, "aten::masked_fill_"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)
diff --git a/test/test_torch.py b/test/test_torch.py
index 1c538d5a2d..f3e8d8234c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3450,7 +3450,6 @@ class _TestTorchMixin(object):
for fn in fns:
(dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
- full1d = cast(torch.randn(*dims_full).flatten().float())
small = cast(torch.randn(*dims_small).float())
large = cast(torch.randn(*dims_large).float())
small_expanded = small.expand(*dims_full)
@@ -3467,7 +3466,8 @@ class _TestTorchMixin(object):
# map and map2 are not implementd on CUDA tensors
continue
- if hasattr(large_expanded, fn):
+ # TODO: fix masked_scatter and masked_fill broadcasting
+ if hasattr(large_expanded, fn) and fn not in ['masked_scatter', 'masked_fill']:
# run through tensor versions of functions
# and verify fully expanded inputs give same results
expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
@@ -3477,10 +3477,6 @@ class _TestTorchMixin(object):
return myfn(t1, 0.5)
elif fn == "masked_select":
return myfn(t1 < 0)
- elif fn == "masked_scatter":
- return myfn(t1 < 0.5, full1d)
- elif fn == "masked_fill":
- return myfn(t1 < 0.5, 1.0)
elif fn in fns_3_args:
return myfn(1, t1, t2)
else:
@@ -3508,7 +3504,7 @@ class _TestTorchMixin(object):
elif fn == "masked_select":
return fntorch(t1, t2 < 0)
elif fn == "masked_scatter":
- return fntorch(t1, t2 < 0.5, full1d)
+ return fntorch(t1, t2 < 0.5, cast(torch.arange(1, t1.nelement() + 1).float()))
elif fn == "masked_fill":
return fntorch(t1, t2 < 0.5, 1.0)
elif fn in fns_3_args:
@@ -3539,7 +3535,7 @@ class _TestTorchMixin(object):
if fn == "lerp":
return t0_fn(t1, 0.5)
elif fn == "masked_scatter":
- return t0_fn(t1 < 0.5, full1d)
+ return t0_fn(t1 < 0.5, cast(torch.arange(1, t0.nelement() + 1).float()))
elif fn == "masked_fill":
return t0_fn(t1 < 0.5, 1.0)
elif fn == "map":