summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorBenoit Steiner <benoitsteiner@fb.com>2019-03-28 08:52:01 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-28 08:55:21 -0700
commiteee760dbd341127cd9e33f04813532fbf9e63316 (patch)
tree40bda869fdf314101804fa3c12ca4e9c7937c6f8 /caffe2
parentffc7158bf2f97916305217e4203ef846c00161ce (diff)
downloadpytorch-eee760dbd341127cd9e33f04813532fbf9e63316.tar.gz
pytorch-eee760dbd341127cd9e33f04813532fbf9e63316.tar.bz2
pytorch-eee760dbd341127cd9e33f04813532fbf9e63316.zip
Improved onnx export for 3 onnx ops. (#18512)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18512 Ceil and Floor have been supported since version 6 of ONNX: export them using the native onnx ops instead of an Aten op. Similarly, support for the Where op has been added in version 9, so we don't need to wrap these op in an Aten op. Reviewed By: houseroad Differential Revision: D14635130 fbshipit-source-id: d54a2b6e295074a6214b5939b21051a6735c9958
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/onnx/backend.cc18
-rw-r--r--caffe2/onnx/backend.h2
-rw-r--r--caffe2/python/onnx/tests/onnx_backend_test.py1
3 files changed, 19 insertions, 2 deletions
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index e7c512a982..3564ebeaf6 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -362,7 +362,8 @@ Caffe2Backend::get_special_operators() const {
{"Dropout", &Caffe2Backend::CreateDropout},
{"LRN", &Caffe2Backend::CreateLRN},
{"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
- {"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
+ {"RandomNormal", &Caffe2Backend::CreateRandomNormal},
+ {"Where", &Caffe2Backend::CreateWhereOp}};
return kSpecialOperators;
}
@@ -580,6 +581,21 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal(
return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
+Caffe2Ops Caffe2Backend::CreateWhereOp(
+ OnnxNode* onnx_node,
+ const ConversionContext& ctx) {
+ // The native Caffe2 op doesn't support broadcasting, so we defer the handling
+ // of this op to the ATen library that does.
+ onnx::NodeProto converted;
+ converted.CopyFrom(onnx_node->node);
+ converted.set_op_type("ATen");
+ onnx::AttributeProto* attr = converted.add_attribute();
+ attr->set_name("operator");
+ attr->set_s("where");
+ OnnxNode new_node(converted);
+ return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
+}
+
Caffe2Ops Caffe2Backend::CreateReciprocal(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h
index d61af29f13..8ee33ef2ca 100644
--- a/caffe2/onnx/backend.h
+++ b/caffe2/onnx/backend.h
@@ -236,6 +236,8 @@ class CAFFE2_API Caffe2Backend {
OnnxNode* onnx_node,
const ConversionContext& ctx);
+ Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
+
Caffe2Ops CreateBatchNormalization(
OnnxNode* onnx_node,
const ConversionContext& ctx);
diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py
index 75d4b5a9b8..f353e22d9d 100644
--- a/caffe2/python/onnx/tests/onnx_backend_test.py
+++ b/caffe2/python/onnx/tests/onnx_backend_test.py
@@ -52,7 +52,6 @@ backend_test.exclude(r'(test_hardsigmoid' # Does not support Hardsigmoid.
'|test_isnan.*' # Needs implementation
'|test_scatter.*' # Should be similar to ScatterAssign
'|test_constantofshape_int.*' # Needs implementation
- '|test_where.*' # Needs implementation
'|test_shrink.*' # Needs implementation
'|test_strnorm.*' # Needs implementation
'|test_nonzero.*' # Needs implementation