diff options
author | Junjie Bai <jbai@fb.com> | 2019-03-28 10:18:46 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-28 10:26:34 -0700 |
commit | 77280b11e35c747c9ffc01c97509761aefcae37b (patch) | |
tree | c3befc6daa9a789c4f9b996fdb2781e8b4be7e64 | |
parent | eee760dbd341127cd9e33f04813532fbf9e63316 (diff) | |
download | pytorch-77280b11e35c747c9ffc01c97509761aefcae37b.tar.gz pytorch-77280b11e35c747c9ffc01c97509761aefcae37b.tar.bz2 pytorch-77280b11e35c747c9ffc01c97509761aefcae37b.zip |
Revert D14635130: Improved onnx export for 3 onnx ops.
Differential Revision:
D14635130
Original commit changeset: d54a2b6e2950
fbshipit-source-id: f624e2befdde245cb88435a95508b2a8e6b12e61
-rw-r--r-- | caffe2/onnx/backend.cc | 18 | ||||
-rw-r--r-- | caffe2/onnx/backend.h | 2 | ||||
-rw-r--r-- | caffe2/python/onnx/tests/onnx_backend_test.py | 1 | ||||
-rw-r--r-- | torch/onnx/symbolic.py | 10 |
4 files changed, 3 insertions, 28 deletions
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc index 3564ebeaf6..e7c512a982 100644 --- a/caffe2/onnx/backend.cc +++ b/caffe2/onnx/backend.cc @@ -362,8 +362,7 @@ Caffe2Backend::get_special_operators() const { {"Dropout", &Caffe2Backend::CreateDropout}, {"LRN", &Caffe2Backend::CreateLRN}, {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice}, - {"RandomNormal", &Caffe2Backend::CreateRandomNormal}, - {"Where", &Caffe2Backend::CreateWhereOp}}; + {"RandomNormal", &Caffe2Backend::CreateRandomNormal}}; return kSpecialOperators; } @@ -581,21 +580,6 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal( return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); } -Caffe2Ops Caffe2Backend::CreateWhereOp( - OnnxNode* onnx_node, - const ConversionContext& ctx) { - // The native Caffe2 op doesn't support broadcasting, so we defer the handling - // of this op to the ATen library that does. - onnx::NodeProto converted; - converted.CopyFrom(onnx_node->node); - converted.set_op_type("ATen"); - onnx::AttributeProto* attr = converted.add_attribute(); - attr->set_name("operator"); - attr->set_s("where"); - OnnxNode new_node(converted); - return CommonOnnxNodeToCaffe2Ops(&new_node, ctx); -} - Caffe2Ops Caffe2Backend::CreateReciprocal( OnnxNode* onnx_node, const ConversionContext& ctx) { diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h index 8ee33ef2ca..d61af29f13 100644 --- a/caffe2/onnx/backend.h +++ b/caffe2/onnx/backend.h @@ -236,8 +236,6 @@ class CAFFE2_API Caffe2Backend { OnnxNode* onnx_node, const ConversionContext& ctx); - Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx); - Caffe2Ops CreateBatchNormalization( OnnxNode* onnx_node, const ConversionContext& ctx); diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index f353e22d9d..75d4b5a9b8 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -52,6 +52,7 @@ backend_test.exclude(r'(test_hardsigmoid' # Does not support Hardsigmoid. '|test_isnan.*' # Needs implementation '|test_scatter.*' # Should be similar to ScatterAssign '|test_constantofshape_int.*' # Needs implementation + '|test_where.*' # Needs implementation '|test_shrink.*' # Needs implementation '|test_strnorm.*' # Needs implementation '|test_nonzero.*' # Needs implementation diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 9a1911f451..fbb8d9765a 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -548,14 +548,6 @@ def relu(g, input): return g.op("Relu", input) -def ceil(g, input): - return g.op("Ceil", input) - - -def floor(g, input): - return g.op("Floor", input) - - @parse_args('v', 't', 't') def threshold(g, self, threshold, value): # See Note [Export inplace] @@ -930,7 +922,7 @@ def le(g, input, other): def where(g, condition, self, other): - return g.op("Where", condition, self, other) + return g.op("ATen", condition, self, other, operator_s="where") @parse_args('v', 'i', 'i') |