diff options
Diffstat (limited to 'test/test_jit.py')
-rw-r--r-- | test/test_jit.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index d91b5b4653..b77846908f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5331,6 +5331,22 @@ a") mte, (torch.zeros(1, 2, 3),), None, verbose=False, example_outputs=outputs, export_raw_ir=True)) + def test_onnx_export_script_non_alpha_add_sub(self): + class ModuleToExport(torch.jit.ScriptModule): + def __init__(self): + super(ModuleToExport, self).__init__() + + @torch.jit.script_method + def forward(self, x): + bs = x.size(0) + 1 + return bs - 1 + + mte = ModuleToExport() + outputs = torch.LongTensor([mte(torch.rand(3, 4))]) + self.assertExpected(torch.onnx.export_to_pretty_string( + mte, (torch.rand(3, 4),), None, verbose=False, + example_outputs=outputs)) + def test_onnx_export_script_module_if(self): class ModuleToExport(torch.jit.ScriptModule): def __init__(self): |