summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJunjie Bai <jbai@fb.com>2019-03-28 10:18:46 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-28 10:26:34 -0700
commit77280b11e35c747c9ffc01c97509761aefcae37b (patch)
treec3befc6daa9a789c4f9b996fdb2781e8b4be7e64
parenteee760dbd341127cd9e33f04813532fbf9e63316 (diff)
downloadpytorch-77280b11e35c747c9ffc01c97509761aefcae37b.tar.gz
pytorch-77280b11e35c747c9ffc01c97509761aefcae37b.tar.bz2
pytorch-77280b11e35c747c9ffc01c97509761aefcae37b.zip
Revert D14635130: Improved onnx export for 3 onnx ops.
Differential Revision: D14635130 Original commit changeset: d54a2b6e2950 fbshipit-source-id: f624e2befdde245cb88435a95508b2a8e6b12e61
-rw-r--r--caffe2/onnx/backend.cc18
-rw-r--r--caffe2/onnx/backend.h2
-rw-r--r--caffe2/python/onnx/tests/onnx_backend_test.py1
-rw-r--r--torch/onnx/symbolic.py10
4 files changed, 3 insertions, 28 deletions
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 3564ebeaf6..e7c512a982 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -362,8 +362,7 @@ Caffe2Backend::get_special_operators() const {
{"Dropout", &Caffe2Backend::CreateDropout},
{"LRN", &Caffe2Backend::CreateLRN},
{"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
- {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
- {"Where", &Caffe2Backend::CreateWhereOp}};
+ {"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
return kSpecialOperators;
}
@@ -581,21 +580,6 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal(
return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
-Caffe2Ops Caffe2Backend::CreateWhereOp(
- OnnxNode* onnx_node,
- const ConversionContext& ctx) {
- // The native Caffe2 op doesn't support broadcasting, so we defer the handling
- // of this op to the ATen library that does.
- onnx::NodeProto converted;
- converted.CopyFrom(onnx_node->node);
- converted.set_op_type("ATen");
- onnx::AttributeProto* attr = converted.add_attribute();
- attr->set_name("operator");
- attr->set_s("where");
- OnnxNode new_node(converted);
- return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
-}
-
Caffe2Ops Caffe2Backend::CreateReciprocal(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h
index 8ee33ef2ca..d61af29f13 100644
--- a/caffe2/onnx/backend.h
+++ b/caffe2/onnx/backend.h
@@ -236,8 +236,6 @@ class CAFFE2_API Caffe2Backend {
OnnxNode* onnx_node,
const ConversionContext& ctx);
- Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
-
Caffe2Ops CreateBatchNormalization(
OnnxNode* onnx_node,
const ConversionContext& ctx);
diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py
index f353e22d9d..75d4b5a9b8 100644
--- a/caffe2/python/onnx/tests/onnx_backend_test.py
+++ b/caffe2/python/onnx/tests/onnx_backend_test.py
@@ -52,6 +52,7 @@ backend_test.exclude(r'(test_hardsigmoid' # Does not support Hardsigmoid.
'|test_isnan.*' # Needs implementation
'|test_scatter.*' # Should be similar to ScatterAssign
'|test_constantofshape_int.*' # Needs implementation
+ '|test_where.*' # Needs implementation
'|test_shrink.*' # Needs implementation
'|test_strnorm.*' # Needs implementation
'|test_nonzero.*' # Needs implementation
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 9a1911f451..fbb8d9765a 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -548,14 +548,6 @@ def relu(g, input):
return g.op("Relu", input)
-def ceil(g, input):
- return g.op("Ceil", input)
-
-
-def floor(g, input):
- return g.op("Floor", input)
-
-
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
# See Note [Export inplace]
@@ -930,7 +922,7 @@ def le(g, input, other):
def where(g, condition, self, other):
- return g.op("Where", condition, self, other)
+ return g.op("ATen", condition, self, other, operator_s="where")
@parse_args('v', 'i', 'i')