diff options
author | Wanchao Liang <wanchaol@users.noreply.github.com> | 2018-09-18 13:41:11 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-18 13:55:39 -0700 |
commit | d4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160 (patch) | |
tree | 48132b63f2c966cc17cc5bece0ddf60699b2aee2 | |
parent | 7d25fa3c721f6b515af5bbfc704e96a978bef3c9 (diff) | |
download | pytorch-d4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160.tar.gz pytorch-d4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160.tar.bz2 pytorch-d4e1fa45d055a1b00e8b7cfefa5c9f5db9ae6160.zip |
allow no-alpha add/sub in onnx symbolic (#10972)
Summary:
The PR fixes #10873
The context is that the aten::add and aten::sub ST (Scalar-Tensor) overloads don't have an alpha argument, so the ONNX symbolic signature does not match.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10972
Reviewed By: jamesr66a
Differential Revision: D9724224
Pulled By: wanchaol
fbshipit-source-id: eb5d1b09fa8f1604b288f4a62b8d1f0bc66611af
-rw-r--r-- | test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect | 22 | ||||
-rw-r--r-- | test/onnx/expect/TestOperators.test_rsub.expect | 4 | ||||
-rw-r--r-- | test/test_jit.py | 16 | ||||
-rw-r--r-- | torch/onnx/symbolic.py | 17 |
4 files changed, 47 insertions, 12 deletions
diff --git a/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect new file mode 100644 index 0000000000..1c2b3c655d --- /dev/null +++ b/test/expect/TestScript.test_onnx_export_script_non_alpha_add_sub.expect @@ -0,0 +1,22 @@ +ModelProto { + producer_name: "pytorch" + domain: "" + doc_string: "" + graph: + GraphProto { + name: "torch-jit-export" + inputs: [{name: "x", type:Tensor dims: 3 4}] + outputs: [{name: "7", type:Tensor dims: 1}] + initializers: [] + nodes: [ + Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, + Node {type: "Shape", inputs: [x], outputs: [2], attributes: []}, + Node {type: "Gather", inputs: [2,1], outputs: [3], attributes: [{ name: 'axis', type: int, value: 0}]}, + Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, + Node {type: "Add", inputs: [3,4], outputs: [5], attributes: []}, + Node {type: "Constant", inputs: [], outputs: [6], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, + Node {type: "Sub", inputs: [5,6], outputs: [7], attributes: []} + ] + } + opset_import: [OperatorSetIdProto { domain: }], +} diff --git a/test/onnx/expect/TestOperators.test_rsub.expect b/test/onnx/expect/TestOperators.test_rsub.expect index 49fa976a5e..7f2e5284e0 100644 --- a/test/onnx/expect/TestOperators.test_rsub.expect +++ b/test/onnx/expect/TestOperators.test_rsub.expect @@ -8,8 +8,8 @@ graph { attribute { name: "value" t { - data_type: DOUBLE - raw_data: "\000\000\000\000\000\000\360?" 
+ data_type: INT64 + raw_data: "\001\000\000\000\000\000\000\000" } type: TENSOR } diff --git a/test/test_jit.py b/test/test_jit.py index d91b5b4653..b77846908f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5331,6 +5331,22 @@ a") mte, (torch.zeros(1, 2, 3),), None, verbose=False, example_outputs=outputs, export_raw_ir=True)) + def test_onnx_export_script_non_alpha_add_sub(self): + class ModuleToExport(torch.jit.ScriptModule): + def __init__(self): + super(ModuleToExport, self).__init__() + + @torch.jit.script_method + def forward(self, x): + bs = x.size(0) + 1 + return bs - 1 + + mte = ModuleToExport() + outputs = torch.LongTensor([mte(torch.rand(3, 4))]) + self.assertExpected(torch.onnx.export_to_pretty_string( + mte, (torch.rand(3, 4),), None, verbose=False, + example_outputs=outputs)) + def test_onnx_export_script_module_if(self): class ModuleToExport(torch.jit.ScriptModule): def __init__(self): diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index d5b586c384..3f33430470 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -192,25 +192,22 @@ def unused(g): return g.op("prim::Undefined") -@parse_args('v', 'v', 't') -def add(g, self, other, alpha): - if _scalar(alpha) != 1: +def add(g, self, other, alpha=None): + # default alpha arg is to allow no-alpha add (aten add st overload no alpha) + if alpha and _scalar(_maybe_get_scalar(alpha)) != 1: return _unimplemented("add", "alpha != 1") # See Note [Pointwise by scalar] other = _maybe_get_scalar(other) return g.op("Add", self, _if_scalar_type_as(g, other, self)) -@parse_args('v', 'v', 't') -def sub(g, self, other, alpha): - if _scalar(alpha) != 1: +def sub(g, self, other, alpha=None): + # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha) + if alpha and _scalar(_maybe_get_scalar(alpha)) != 1: return _unimplemented("sub", "alpha != 1") # See Note [Pointwise by scalar]. Note that self or other may be scalars. 
other = _maybe_get_scalar(other) - self = _maybe_get_scalar(self) - self = _if_scalar_type_as(g, self, other) - other = _if_scalar_type_as(g, other, self) - return g.op("Sub", self, other) + return g.op("Sub", self, _if_scalar_type_as(g, other, self)) def mul(g, self, other): |