diff options
author | Wanchao Liang <wanchaol@users.noreply.github.com> | 2019-02-07 10:32:02 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-07 11:02:14 -0800 |
commit | ac00e85e36c236f141d0621bfbbbbe8c9ffeefd1 (patch) | |
tree | 93d7b45194e53e0e9c19cf6eb1280118ff7459e6 /test | |
parent | 0d366e1bde42265d764a6a9aad27a38753156fb0 (diff) | |
download | pytorch-ac00e85e36c236f141d0621bfbbbbe8c9ffeefd1.tar.gz pytorch-ac00e85e36c236f141d0621bfbbbbe8c9ffeefd1.tar.bz2 pytorch-ac00e85e36c236f141d0621bfbbbbe8c9ffeefd1.zip |
Remove undefined tensor in jit script (#16379)
Summary:
This PR is a follow up of #15460, it did the following things:
* remove the undefined tensor semantic in jit script/tracing mode
* change ATen/JIT schema for at::index and other index related ops with `Tensor?[]` to align with what at::index is really doing and to adopt `optional[tensor]` in JIT
* change python_print to correctly print the exported script
* register both TensorList and ListOfOptionalTensor in JIT ATen ops to support both
* Backward compatibility for `torch.jit.annotate(Tensor, None)`
List of follow ups:
* remove the undefined tensor semantic in jit autograd, autodiff and grad_of
* remove prim::Undefined fully
For easy reviews, please turn on `hide white space changes` in diff settings.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16379
Differential Revision: D13855677
Pulled By: wanchaol
fbshipit-source-id: 0e21c14d7de250c62731227c81bfbfb7b7da20ab
Diffstat (limited to 'test')
-rw-r--r-- | test/expect/TestJit.test_conv.expect | 2 | ||||
-rw-r--r-- | test/expect/TestScript.test_index_put_trace_with_view.expect | 2 | ||||
-rw-r--r-- | test/expect/TestScript.test_index_put_trace_without_view.expect | 2 | ||||
-rw-r--r-- | test/test_jit.py | 12 |
4 files changed, 14 insertions, 4 deletions
diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect index 3e8108b880..12942487d5 100644 --- a/test/expect/TestJit.test_conv.expect +++ b/test/expect/TestJit.test_conv.expect @@ -1,6 +1,6 @@ graph(%0 : Double(20, 16, 50, 40) %1 : Double(13, 16, 3, 3)) { - %2 : Tensor = prim::Undefined(), scope: Conv2d + %2 : Tensor? = prim::None(), scope: Conv2d %3 : int = prim::Constant[value=1](), scope: Conv2d %4 : int = prim::Constant[value=1](), scope: Conv2d %5 : int[] = prim::ListConstruct(%3, %4), scope: Conv2d diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect index 39c92a4593..936010e6d4 100644 --- a/test/expect/TestScript.test_index_put_trace_with_view.expect +++ b/test/expect/TestScript.test_index_put_trace_with_view.expect @@ -10,7 +10,7 @@ graph(%target : Double(100) %9 : bool = prim::Constant[value=0]() %10 : bool = prim::Constant[value=0]() %indices : Long(4) = aten::to(%indices.1, %6, %7, %8, %9, %10) - %12 : Tensor[] = prim::ListConstruct(%indices) + %12 : Tensor?[] = prim::ListConstruct(%indices) %13 : bool = prim::Constant[value=0]() %14 : Double(100) = aten::index_put_(%target, %12, %5, %13) return (%14); diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect index dfe403fa89..1da672897d 100644 --- a/test/expect/TestScript.test_index_put_trace_without_view.expect +++ b/test/expect/TestScript.test_index_put_trace_without_view.expect @@ -7,7 +7,7 @@ graph(%target : Double(100) %6 : bool = prim::Constant[value=0]() %7 : bool = prim::Constant[value=0]() %indices : Long(4) = aten::to(%indices.1, %3, %4, %5, %6, %7) - %9 : Tensor[] = prim::ListConstruct(%indices) + %9 : Tensor?[] = prim::ListConstruct(%indices) %10 : bool = prim::Constant[value=0]() %11 : Double(100) = aten::index_put_(%target, %9, %rhs, %10) return (%11); diff --git a/test/test_jit.py b/test/test_jit.py index d5be13f6c9..6beb003314 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2975,6 +2975,16 @@ class TestScript(JitTestCase): return torch.jit.annotate(float, a) self.checkScript(baz, (torch.rand(()),)) + # test annotate none types + def annotate_none(): + return torch.jit.annotate(Optional[torch.Tensor], None) + + def annotate_none_no_optional(): + return torch.jit.annotate(torch.Tensor, None) + + self.checkScript(annotate_none, ()) + self.checkScript(annotate_none_no_optional, ()) + def test_robust_op_resolution(self): neg = torch.add # misleading name to make sure we resolve by function @@ -3453,7 +3463,6 @@ a") formals = ''.join(map(', {}'.format, formals)) inputs = [tensor] + values - self._check_code(template.format(formals=formals, expr=indexing), "func", inputs) @@ -10470,6 +10479,7 @@ EXCLUDE_TRACED = { 'test___getitem___adv_index_sub_2', 'test___getitem___adv_index_sub_3', 'test___getitem___adv_index_var', + } EXCLUDE_TYPE_CHECK = { |