author     Wanchao Liang <wanchaol@users.noreply.github.com>          2019-02-07 10:32:02 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-07 11:02:14 -0800
commit     ac00e85e36c236f141d0621bfbbbbe8c9ffeefd1 (patch)
tree       93d7b45194e53e0e9c19cf6eb1280118ff7459e6 /test
parent     0d366e1bde42265d764a6a9aad27a38753156fb0 (diff)
Remove undefined tensor in jit script (#16379)
Summary:
This PR is a follow-up to #15460. It does the following:

* remove the undefined tensor semantic in JIT script/tracing mode
* change the ATen/JIT schema of at::index and other index-related ops to `Tensor?[]`, to align with what at::index actually does and to adopt `Optional[Tensor]` in the JIT
* change python_print to correctly print the exported script
* register both TensorList and ListOfOptionalTensor in JIT ATen ops to support both
* keep backward compatibility for `torch.jit.annotate(Tensor, None)`

List of follow-ups:

* remove the undefined tensor semantic in JIT autograd, autodiff and grad_of
* remove prim::Undefined entirely

For easier review, please turn on `hide whitespace changes` in the diff settings.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16379

Differential Revision: D13855677

Pulled By: wanchaol

fbshipit-source-id: 0e21c14d7de250c62731227c81bfbfb7b7da20ab
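To make the change concrete, here is a small illustrative sketch (not part of this commit) of the `Optional[Tensor]` / `Tensor?` behavior described above. It assumes a PyTorch build that already includes this change; the function names are invented for the example.

from typing import Optional

import torch


@torch.jit.script
def annotate_none() -> Optional[torch.Tensor]:
    # With this change, None is a Tensor? (Optional[Tensor]) value in script,
    # rather than a prim::Undefined "undefined tensor".
    return torch.jit.annotate(Optional[torch.Tensor], None)


@torch.jit.script
def fill_at(target: torch.Tensor, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Advanced-indexing assignment lowers to aten::index_put_, whose index
    # argument now has type Tensor?[] (a list of optional tensors), as the
    # updated expect files in this diff show.
    target[indices] = values
    return target


print(annotate_none())  # None
t = torch.zeros(100)
print(fill_at(t, torch.tensor([1, 2, 3, 4]), torch.ones(4))[:6])

The two `torch.jit.annotate` tests added to test/test_jit.py below exercise exactly this path, including the backward-compatible `annotate(Tensor, None)` form.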
Diffstat (limited to 'test')
-rw-r--r--  test/expect/TestJit.test_conv.expect                              2
-rw-r--r--  test/expect/TestScript.test_index_put_trace_with_view.expect      2
-rw-r--r--  test/expect/TestScript.test_index_put_trace_without_view.expect   2
-rw-r--r--  test/test_jit.py                                                 12
4 files changed, 14 insertions, 4 deletions
diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect
index 3e8108b880..12942487d5 100644
--- a/test/expect/TestJit.test_conv.expect
+++ b/test/expect/TestJit.test_conv.expect
@@ -1,6 +1,6 @@
graph(%0 : Double(20, 16, 50, 40)
      %1 : Double(13, 16, 3, 3)) {
-  %2 : Tensor = prim::Undefined(), scope: Conv2d
+  %2 : Tensor? = prim::None(), scope: Conv2d
  %3 : int = prim::Constant[value=1](), scope: Conv2d
  %4 : int = prim::Constant[value=1](), scope: Conv2d
  %5 : int[] = prim::ListConstruct(%3, %4), scope: Conv2d
diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect
index 39c92a4593..936010e6d4 100644
--- a/test/expect/TestScript.test_index_put_trace_with_view.expect
+++ b/test/expect/TestScript.test_index_put_trace_with_view.expect
@@ -10,7 +10,7 @@ graph(%target : Double(100)
  %9 : bool = prim::Constant[value=0]()
  %10 : bool = prim::Constant[value=0]()
  %indices : Long(4) = aten::to(%indices.1, %6, %7, %8, %9, %10)
-  %12 : Tensor[] = prim::ListConstruct(%indices)
+  %12 : Tensor?[] = prim::ListConstruct(%indices)
  %13 : bool = prim::Constant[value=0]()
  %14 : Double(100) = aten::index_put_(%target, %12, %5, %13)
  return (%14);
diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect
index dfe403fa89..1da672897d 100644
--- a/test/expect/TestScript.test_index_put_trace_without_view.expect
+++ b/test/expect/TestScript.test_index_put_trace_without_view.expect
@@ -7,7 +7,7 @@ graph(%target : Double(100)
  %6 : bool = prim::Constant[value=0]()
  %7 : bool = prim::Constant[value=0]()
  %indices : Long(4) = aten::to(%indices.1, %3, %4, %5, %6, %7)
-  %9 : Tensor[] = prim::ListConstruct(%indices)
+  %9 : Tensor?[] = prim::ListConstruct(%indices)
  %10 : bool = prim::Constant[value=0]()
  %11 : Double(100) = aten::index_put_(%target, %9, %rhs, %10)
  return (%11);
diff --git a/test/test_jit.py b/test/test_jit.py
index d5be13f6c9..6beb003314 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2975,6 +2975,16 @@ class TestScript(JitTestCase):
            return torch.jit.annotate(float, a)
        self.checkScript(baz, (torch.rand(()),))
+        # test annotate none types
+        def annotate_none():
+            return torch.jit.annotate(Optional[torch.Tensor], None)
+
+        def annotate_none_no_optional():
+            return torch.jit.annotate(torch.Tensor, None)
+
+        self.checkScript(annotate_none, ())
+        self.checkScript(annotate_none_no_optional, ())
+
    def test_robust_op_resolution(self):
        neg = torch.add  # misleading name to make sure we resolve by function
@@ -3453,7 +3463,6 @@ a")
            formals = ''.join(map(', {}'.format, formals))
            inputs = [tensor] + values
-
            self._check_code(template.format(formals=formals, expr=indexing),
                             "func", inputs)
@@ -10470,6 +10479,7 @@ EXCLUDE_TRACED = {
    'test___getitem___adv_index_sub_2',
    'test___getitem___adv_index_sub_3',
    'test___getitem___adv_index_var',
+
}
EXCLUDE_TYPE_CHECK = {