summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect22
-rw-r--r--test/onnx/expect/TestOperators.test_rsub.expect4
-rw-r--r--test/test_jit.py16
3 files changed, 40 insertions, 2 deletions
diff --git a/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect
new file mode 100644
index 0000000000..1c2b3c655d
--- /dev/null
+++ b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect
@@ -0,0 +1,22 @@
+ModelProto {
+ producer_name: "pytorch"
+ domain: ""
+ doc_string: ""
+ graph:
+ GraphProto {
+ name: "torch-jit-export"
+ inputs: [{name: "x", type:Tensor dims: 3 4}]
+ outputs: [{name: "7", type:Tensor dims: 1}]
+ initializers: []
+ nodes: [
+ Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Shape", inputs: [x], outputs: [2], attributes: []},
+ Node {type: "Gather", inputs: [2,1], outputs: [3], attributes: [{ name: 'axis', type: int, value: 0}]},
+ Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Add", inputs: [3,4], outputs: [5], attributes: []},
+ Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Sub", inputs: [5,6], outputs: [7], attributes: []}
+ ]
+ }
+ opset_import: [OperatorSetIdProto { domain: }],
+}
diff --git a/test/onnx/expect/TestOperators.test_rsub.expect b/test/onnx/expect/TestOperators.test_rsub.expect
index 49fa976a5e..7f2e5284e0 100644
--- a/test/onnx/expect/TestOperators.test_rsub.expect
+++ b/test/onnx/expect/TestOperators.test_rsub.expect
@@ -8,8 +8,8 @@ graph {
attribute {
name: "value"
t {
- data_type: DOUBLE
- raw_data: "\000\000\000\000\000\000\360?"
+ data_type: INT64
+ raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
diff --git a/test/test_jit.py b/test/test_jit.py
index d91b5b4653..b77846908f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5331,6 +5331,22 @@ a")
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs, export_raw_ir=True))
+ def test_onnx_export_script_non_alpha_add_sub(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
+
+ @torch.jit.script_method
+ def forward(self, x):
+ bs = x.size(0) + 1
+ return bs - 1
+
+ mte = ModuleToExport()
+ outputs = torch.LongTensor([mte(torch.rand(3, 4))])
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.rand(3, 4),), None, verbose=False,
+ example_outputs=outputs))
+
def test_onnx_export_script_module_if(self):
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self):