summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWanchao Liang <wanchaol@users.noreply.github.com>2018-09-18 13:41:11 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-18 13:55:39 -0700
commitd4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160 (patch)
tree48132b63f2c966cc17cc5bece0ddf60699b2aee2
parent7d25fa3c721f6b515af5bbfc704e96a978bef3c9 (diff)
downloadpytorch-d4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160.tar.gz
pytorch-d4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160.tar.bz2
pytorch-d4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160.zip
allow no-alpha add/sub in onnx symbolic (#10972)
Summary: The PR fixes #10873 The context is aten::add and aten::sub ST overloads don't have alpha, so onnx symbolic does not match. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10972 Reviewed By: jamesr66a Differential Revision: D9724224 Pulled By: wanchaol fbshipit-source-id: eb5d1b09fa8f1604b288f4a62b8d1f0bc66611af
-rw-r--r--test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect22
-rw-r--r--test/onnx/expect/TestOperators.test_rsub.expect4
-rw-r--r--test/test_jit.py16
-rw-r--r--torch/onnx/symbolic.py17
4 files changed, 47 insertions, 12 deletions
diff --git a/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect
new file mode 100644
index 0000000000..1c2b3c655d
--- /dev/null
+++ b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect
@@ -0,0 +1,22 @@
+ModelProto {
+ producer_name: "pytorch"
+ domain: ""
+ doc_string: ""
+ graph:
+ GraphProto {
+ name: "torch-jit-export"
+ inputs: [{name: "x", type:Tensor dims: 3 4}]
+ outputs: [{name: "7", type:Tensor dims: 1}]
+ initializers: []
+ nodes: [
+ Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Shape", inputs: [x], outputs: [2], attributes: []},
+ Node {type: "Gather", inputs: [2,1], outputs: [3], attributes: [{ name: 'axis', type: int, value: 0}]},
+ Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Add", inputs: [3,4], outputs: [5], attributes: []},
+ Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Sub", inputs: [5,6], outputs: [7], attributes: []}
+ ]
+ }
+ opset_import: [OperatorSetIdProto { domain: }],
+}
diff --git a/test/onnx/expect/TestOperators.test_rsub.expect b/test/onnx/expect/TestOperators.test_rsub.expect
index 49fa976a5e..7f2e5284e0 100644
--- a/test/onnx/expect/TestOperators.test_rsub.expect
+++ b/test/onnx/expect/TestOperators.test_rsub.expect
@@ -8,8 +8,8 @@ graph {
attribute {
name: "value"
t {
- data_type: DOUBLE
- raw_data: "\000\000\000\000\000\000\360?"
+ data_type: INT64
+ raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
diff --git a/test/test_jit.py b/test/test_jit.py
index d91b5b4653..b77846908f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5331,6 +5331,22 @@ a")
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs, export_raw_ir=True))
+ def test_onnx_export_script_non_alpha_add_sub(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
+
+ @torch.jit.script_method
+ def forward(self, x):
+ bs = x.size(0) + 1
+ return bs - 1
+
+ mte = ModuleToExport()
+ outputs = torch.LongTensor([mte(torch.rand(3, 4))])
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.rand(3, 4),), None, verbose=False,
+ example_outputs=outputs))
+
def test_onnx_export_script_module_if(self):
class ModuleToExport(torch.jit.ScriptModule):
def __init__(self):
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index d5b586c384..3f33430470 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -192,25 +192,22 @@ def unused(g):
return g.op("prim::Undefined")
-@parse_args('v', 'v', 't')
-def add(g, self, other, alpha):
- if _scalar(alpha) != 1:
+def add(g, self, other, alpha=None):
+ # default alpha arg is to allow no-alpha add (aten add st overload no alpha)
+ if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
return _unimplemented("add", "alpha != 1")
# See Note [Pointwise by scalar]
other = _maybe_get_scalar(other)
return g.op("Add", self, _if_scalar_type_as(g, other, self))
-@parse_args('v', 'v', 't')
-def sub(g, self, other, alpha):
- if _scalar(alpha) != 1:
+def sub(g, self, other, alpha=None):
+ # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha)
+ if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
return _unimplemented("sub", "alpha != 1")
# See Note [Pointwise by scalar]. Note that self or other may be scalars.
other = _maybe_get_scalar(other)
- self = _maybe_get_scalar(self)
- self = _if_scalar_type_as(g, self, other)
- other = _if_scalar_type_as(g, other, self)
- return g.op("Sub", self, other)
+ return g.op("Sub", self, _if_scalar_type_as(g, other, self))
def mul(g, self, other):