diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect | 22 | ||||
-rw-r--r-- | test/onnx/expect/TestOperators.test_rsub.expect | 4 | ||||
-rw-r--r-- | test/test_jit.py | 16 |
3 files changed, 40 insertions, 2 deletions
diff --git a/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect new file mode 100644 index 0000000000..1c2b3c655d --- /dev/null +++ b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect @@ -0,0 +1,22 @@ +ModelProto { + producer_name: "pytorch" + domain: "" + doc_string: "" + graph: + GraphProto { + name: "torch-jit-export" + inputs: [{name: "x", type:Tensor dims: 3 4}] + outputs: [{name: "7", type:Tensor dims: 1}] + initializers: [] + nodes: [ + Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, + Node {type: "Shape", inputs: [x], outputs: [2], attributes: []}, + Node {type: "Gather", inputs: [2,1], outputs: [3], attributes: [{ name: 'axis', type: int, value: 0}]}, + Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, + Node {type: "Add", inputs: [3,4], outputs: [5], attributes: []}, + Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, + Node {type: "Sub", inputs: [5,6], outputs: [7], attributes: []} + ] + } + opset_import: [OperatorSetIdProto { domain: }], +} diff --git a/test/onnx/expect/TestOperators.test_rsub.expect b/test/onnx/expect/TestOperators.test_rsub.expect index 49fa976a5e..7f2e5284e0 100644 --- a/test/onnx/expect/TestOperators.test_rsub.expect +++ b/test/onnx/expect/TestOperators.test_rsub.expect @@ -8,8 +8,8 @@ graph { attribute { name: "value" t { - data_type: DOUBLE - raw_data: "\000\000\000\000\000\000\360?" + data_type: INT64 + raw_data: "\001\000\000\000\000\000\000\000" } type: TENSOR } diff --git a/test/test_jit.py b/test/test_jit.py index d91b5b4653..b77846908f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5331,6 +5331,22 @@ a") mte, (torch.zeros(1, 2, 3),), None, verbose=False, example_outputs=outputs, export_raw_ir=True)) + def test_onnx_export_script_non_alpha_add_sub(self): + class ModuleToExport(torch.jit.ScriptModule): + def __init__(self): + super(ModuleToExport, self).__init__() + + @torch.jit.script_method + def forward(self, x): + bs = x.size(0) + 1 + return bs - 1 + + mte = ModuleToExport() + outputs = torch.LongTensor([mte(torch.rand(3, 4))]) + self.assertExpected(torch.onnx.export_to_pretty_string( + mte, (torch.rand(3, 4),), None, verbose=False, + example_outputs=outputs)) + def test_onnx_export_script_module_if(self): class ModuleToExport(torch.jit.ScriptModule): def __init__(self): |