summaryrefslogtreecommitdiff
path: root/test/test_jit.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_jit.py')
-rw-r--r--test/test_jit.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index d91b5b4653..b77846908f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5331,6 +5331,22 @@ a")
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs, export_raw_ir=True))
+ def test_onnx_export_script_non_alpha_add_sub(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
+
+ @torch.jit.script_method
+ def forward(self, x):
+ bs = x.size(0) + 1
+ return bs - 1
+
+ mte = ModuleToExport()
+ outputs = torch.LongTensor([mte(torch.rand(3, 4))])
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.rand(3, 4),), None, verbose=False,
+ example_outputs=outputs))
+
def test_onnx_export_script_module_if(self):
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self):