path: root/caffe2/onnx
author    Yinghai Lu <yinghai@fb.com>  2018-03-26 19:23:45 -0700
committer GitHub <noreply@github.com>  2018-03-26 19:23:45 -0700
commit    b6e80a1ec407e866ba154c00b942338d1ef5228a (patch)
tree      c57e281dc0cafe115b1162858d88ee372ab46a29 /caffe2/onnx
parent    b6b2edb96fcf350d4ec31e6cc90b14de94f90790 (diff)
Caffe2-onnx exporter (#2248)
* caffe2-onnx frontend
* Remove Python part of the conversion code
* nit
* convert more ops
* Address comments
Diffstat (limited to 'caffe2/onnx')
-rw-r--r--  caffe2/onnx/backend.cc         5
-rw-r--r--  caffe2/onnx/helper.cc         22
-rw-r--r--  caffe2/onnx/helper.h          76
-rw-r--r--  caffe2/onnx/onnx_exporter.cc  592
-rw-r--r--  caffe2/onnx/onnx_exporter.h   102
5 files changed, 784 insertions, 13 deletions
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 852d3df2be..e776922196 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -19,6 +19,7 @@
#include "caffe2/onnx/backend.h"
#include "caffe2/onnx/device.h"
#include "caffe2/onnx/helper.h"
+#include "caffe2/utils/map_utils.h"
#if !CAFFE2_MOBILE
#include "onnx/checker.h"
@@ -866,7 +867,7 @@ Caffe2Ops Caffe2Backend::CommonOnnxNodeToCaffe2Ops(
c2_op->set_name(node.name());
const auto onnx_op_type = node.op_type();
- auto broken_version = LookUpWithDefault(
+ auto broken_version = caffe2::get_default(
get_broken_operators(), onnx_op_type, std::numeric_limits<int>::max());
if (broken_version <= opset_version) {
CAFFE_THROW(
@@ -878,7 +879,7 @@ Caffe2Ops Caffe2Backend::CommonOnnxNodeToCaffe2Ops(
broken_version);
}
c2_op->set_type(
- LookUpWithDefault(get_renamed_operators(), onnx_op_type, onnx_op_type));
+ caffe2::get_default(get_renamed_operators(), onnx_op_type, onnx_op_type));
if (!IsOperator(c2_op->type())) {
CAFFE_THROW(
"Don't know how to translate op ", onnx_op_type);
diff --git a/caffe2/onnx/helper.cc b/caffe2/onnx/helper.cc
index f37f6e5dfc..7b5cd5893d 100644
--- a/caffe2/onnx/helper.cc
+++ b/caffe2/onnx/helper.cc
@@ -44,4 +44,26 @@ void DummyName::Reset(const std::unordered_set<std::string> &used_names) {
counter_ = 0;
}
+NodeProto MakeNode(
+ const std::string& type,
+ const std::vector<std::string>& inputs,
+ const std::vector<std::string>& outputs,
+ const std::vector<AttributeProto>& attributes,
+ const std::string& name) {
+ NodeProto node;
+ if (!name.empty()) {
+ node.set_name(name);
+ }
+ node.set_op_type(type);
+ for (const auto& input: inputs) {
+ node.add_input(input);
+ }
+ for (const auto& output: outputs) {
+ node.add_output(output);
+ }
+ for (const auto& attr: attributes) {
+ node.add_attribute()->CopyFrom(attr);
+ }
+ return node;
+}
}}
diff --git a/caffe2/onnx/helper.h b/caffe2/onnx/helper.h
index e210cdee89..065502ec3a 100644
--- a/caffe2/onnx/helper.h
+++ b/caffe2/onnx/helper.h
@@ -16,24 +16,78 @@
#pragma once
+#include "onnx/onnx_pb.h"
+
#include <set>
#include <string>
#include <unordered_set>
-namespace caffe2 { namespace onnx {
+namespace caffe2 {
+namespace onnx {
+
+using ::ONNX_NAMESPACE::AttributeProto;
+using ::ONNX_NAMESPACE::NodeProto;
+
+// \brief This class generates unique dummy names
class DummyName {
- public:
- static std::string NewDummyName();
+ public:
+ static std::string NewDummyName();
- static void Reset(const std::unordered_set<std::string>& used_names);
+ static void Reset(const std::unordered_set<std::string>& used_names);
- static void AddName(const std::string& new_used) {
- get_used_names().insert(new_used);
- }
+ static void AddName(const std::string& new_used) {
+ get_used_names().insert(new_used);
+ }
- private:
- static std::unordered_set<std::string>& get_used_names();
- static size_t counter_;
+ private:
+ static std::unordered_set<std::string>& get_used_names();
+ static size_t counter_;
};
-}}
+inline AttributeProto MakeAttribute(
+ const std::string& name,
+ const std::vector<int64_t>& vals) {
+ AttributeProto attr;
+ attr.set_name(name);
+ for (const auto v : vals) {
+ attr.add_ints(v);
+ }
+ attr.set_type(AttributeProto::INTS);
+ return attr;
+}
+
+inline AttributeProto MakeAttribute(const std::string& name, int64_t val) {
+ AttributeProto attr;
+ attr.set_name(name);
+ attr.set_i(val);
+ attr.set_type(AttributeProto::INT);
+ return attr;
+}
+
+inline AttributeProto MakeAttribute(
+ const std::string& name,
+ const std::string& val) {
+ AttributeProto attr;
+ attr.set_name(name);
+ attr.set_s(val);
+ attr.set_type(AttributeProto::STRING);
+ return attr;
+}
+
+NodeProto MakeNode(
+ const std::string& type,
+ const std::vector<std::string>& inputs,
+ const std::vector<std::string>& outputs,
+ const std::vector<AttributeProto>& attributes,
+ const std::string& name = "");
+
+inline NodeProto MakeNode(
+ const std::string& type,
+ const std::vector<std::string>& inputs,
+ const std::vector<std::string>& outputs,
+ const std::string& name = "") {
+ return MakeNode(type, inputs, outputs, {}, name);
+}
+
+} // namespace onnx
+} // namespace caffe2
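
The helpers above compose as follows; a minimal usage sketch (the tensor
names "in" and "out" are made up for illustration):

    // Builds out = Transpose(in) with perm = [1, 0].
    NodeProto transpose = MakeNode(
        "Transpose",
        {"in"},
        {"out"},
        {MakeAttribute("perm", std::vector<int64_t>{1, 0})});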
diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc
new file mode 100644
index 0000000000..3b0393c1b9
--- /dev/null
+++ b/caffe2/onnx/onnx_exporter.cc
@@ -0,0 +1,592 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/core/logging.h"
+#include "caffe2/onnx/onnx_exporter.h"
+#include "caffe2/onnx/helper.h"
+#include "caffe2/proto/caffe2_legacy.pb.h"
+#include "caffe2/utils/map_utils.h"
+
+#include <numeric>
+#include <unordered_set>
+
+namespace caffe2 {
+namespace onnx {
+
+namespace {
+// rewrite padding attributes
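+// e.g. {kernel_h, kernel_w} collapses into kernel_shape = [h, w] and
+// {pad_t, pad_l, pad_b, pad_r} into pads = [t, l, b, r]; a scalar
+// attribute is broadcast to `dim` entries. For global pooling the
+// collapsed attribute is dropped instead of re-emitted.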
+void ApplyTrans(
+ std::unordered_map<std::string, AttributeProto>* attrs,
+ bool global,
+ const std::string& k,
+ int dim = 2,
+ const std::string& ks = "") {
+ std::string ks2 = ks.empty() ? (k + "s") : ks;
+ std::string k_h, k_w, k_t, k_l, k_b, k_r;
+ if (dim == 2) {
+ k_h = k + "_h";
+ k_w = k + "_w";
+ } else {
+ k_t = k + "_t";
+ k_l = k + "_l";
+ k_b = k + "_b";
+ k_r = k + "_r";
+ }
+
+ std::vector<int64_t> vals;
+ if (dim == 2 && attrs->count(k_h) && attrs->count(k_w)) {
+ auto it = attrs->find(k_h);
+ vals.push_back(it->second.i());
+ attrs->erase(it);
+ it = attrs->find(k_w);
+ vals.push_back(it->second.i());
+ attrs->erase(it);
+ } else if (
+ dim == 4 && attrs->count(k_t) && attrs->count(k_b) && attrs->count(k_l) &&
+ attrs->count(k_r)) {
+ auto it = attrs->find(k_t);
+ vals.push_back(it->second.i());
+ attrs->erase(it);
+ it = attrs->find(k_l);
+ vals.push_back(it->second.i());
+ attrs->erase(it);
+ it = attrs->find(k_b);
+ vals.push_back(it->second.i());
+ attrs->erase(it);
+ it = attrs->find(k_r);
+ vals.push_back(it->second.i());
+ attrs->erase(it);
+ } else if (attrs->count(k)) {
+ auto it = attrs->find(k);
+ auto tmp = it->second.i();
+ for (int i = 0; i < dim; ++i) {
+ vals.push_back(tmp);
+ }
+ attrs->erase(it);
+ }
+
+ if (!vals.empty() && !global) {
+ attrs->emplace(ks2, MakeAttribute(ks2, vals));
+ }
+}
+
+int64_t DimProd(const caffe2::TensorShape& shape, int start, int end) {
+ int64_t acc = 1;
+ for (int i = start; i < end; ++i) {
+ acc *= shape.dims(i);
+ }
+ return acc;
+}
+
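+// Builds a 1-D INT64 constant tensor holding `shape`, used as the second
+// input of ONNX Reshape nodes.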
+TensorProto CreateOnnxShapeTensor(const std::vector<int64_t>& shape) {
+ TensorProto tensor;
+ tensor.set_name(DummyName::NewDummyName());
+ tensor.set_data_type(TensorProto::INT64);
+ tensor.add_dims(shape.size());
+ tensor.mutable_raw_data()->assign(
+ reinterpret_cast<const char*>(shape.data()), sizeof(int64_t) * shape.size());
+ return tensor;
+}
+} // namespace
+
+const std::unordered_map<std::string, std::string>&
+OnnxExporter::get_renamed_operators() const {
+ const static std::unordered_map<std::string, std::string> kRenamedOperators{
+ {"SpatialBN", "BatchNormalization"},
+ {"Conv1D", "Conv"},
+ {"Conv2D", "Conv"},
+ {"Conv3D", "Conv"},
+ {"ConvTranspose1D", "ConvTranspose"},
+ {"ConvTranspose2D", "ConvTranspose"},
+ {"ConvTranspose3D", "ConvTranspose"},
+ {"MaxPool1D", "MaxPool"},
+ {"MaxPool2D", "MaxPool"},
+ {"MaxPool3D", "MaxPool"},
+ {"AveragePool1D", "AveragePool"},
+ {"AveragePool2D", "AveragePool"},
+ {"AveragePool3D", "AveragePool"}};
+ return kRenamedOperators;
+}
+
+const std::unordered_map<std::string, std::string>&
+OnnxExporter::get_renamed_attrs() const {
+ const static std::unordered_map<std::string, std::string> kRenamedAttrs{
+ {"kernels", "kernel_shape"}};
+ return kRenamedAttrs;
+}
+
+const std::
+ unordered_map<std::string, std::unordered_map<std::string, std::string>>&
+ OnnxExporter::get_per_op_renamed_attrs() const {
+ const static std::
+ unordered_map<std::string, std::unordered_map<std::string, std::string>>
+ kPerOpRenamedAttrs = {{"Squeeze", {{"dims", "axes"}}},
+ {"Unsqueeze", {{"dims", "axes"}}},
+ {"Transpose", {{"axes", "perm"}}},
+ {"ConvTranspose", {{"adjs", "output_padding"}}},
+ {"Selu", {{"scale", "gamma"}}}};
+
+ return kPerOpRenamedAttrs;
+}
+
+const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
+OnnxExporter::get_special_operators() const {
+ const static std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>
+ kSpecialOperators = {
+ {"Conv", &OnnxExporter::CreateConvPoolNodes},
+ {"ConvTranspose", &OnnxExporter::CreateConvPoolNodes},
+ {"MaxPool", &OnnxExporter::CreateConvPoolNodes},
+ {"AveragePool", &OnnxExporter::CreateConvPoolNodes},
+ {"FC", &OnnxExporter::CreateGemmNodes},
+ {"Concat", &OnnxExporter::CreateConcatNodes},
+ {"LRN", &OnnxExporter::CreateLrnNodes},
+ {"Reshape", &OnnxExporter::CreateReshapeNodes},
+ {"Slice", &OnnxExporter::CreateSliceNodes},
+ {"ChannelShuffle", &OnnxExporter::CreateChannelShuffleNodes}
+ };
+ return kSpecialOperators;
+}
+
+void OnnxExporter::CopyCaffe2ArgToOnnxAttr(
+ AttributeProto* attr,
+ const std::string& op_type,
+ const caffe2::Argument& arg) {
+ std::string name;
+ const auto& per_op_renamed_attr_lut = get_per_op_renamed_attrs();
+ const auto it = per_op_renamed_attr_lut.find(op_type);
+ if (it != per_op_renamed_attr_lut.end()) {
+ name = caffe2::get_default(it->second, arg.name(), arg.name());
+ } else {
+ name = caffe2::get_default(get_renamed_attrs(), arg.name(), arg.name());
+ }
+ attr->set_name(name);
+
+ if (arg.has_f()) {
+ attr->set_f(arg.f());
+ attr->set_type(AttributeProto::FLOAT);
+ } else if (arg.has_i()) {
+ attr->set_i(arg.i());
+ attr->set_type(AttributeProto::INT);
+ } else if (arg.has_s()) {
+ attr->set_s(arg.s());
+ attr->set_type(AttributeProto::STRING);
+ } else if (arg.floats_size()) {
+ attr->mutable_floats()->CopyFrom(arg.floats());
+ attr->set_type(AttributeProto::FLOATS);
+ } else if (arg.ints_size()) {
+ attr->mutable_ints()->CopyFrom(arg.ints());
+ attr->set_type(AttributeProto::INTS);
+ } else if (arg.strings_size()) {
+ attr->mutable_strings()->CopyFrom(arg.strings());
+ attr->set_type(AttributeProto::STRINGS);
+ } else {
+ CAFFE_THROW(
+ caffe2::MakeString("Unsupported Caffe2 argument: ", arg.name()));
+ }
+}
+
+bool OnnxExporter::IsBlackListed(const caffe2::Argument& arg) {
+ const static std::unordered_map<std::string, std::unordered_set<std::string>>
+ kBlackListString = {{"order", {"NCHW"}}};
+ const static std::unordered_map<std::string, std::unordered_set<int64_t>>
+ kBlackListInt = {{"cudnn_exhaustive_search", {0, 1}},
+ {"use_cudnn", {0, 1}}};
+
+ if (arg.has_i()) {
+ const auto it = kBlackListInt.find(arg.name());
+ if (it != kBlackListInt.end()) {
+ return it->second.count(arg.i());
+ }
+ } else if (arg.has_s()) {
+ const auto it = kBlackListString.find(arg.name());
+ if (it != kBlackListString.end()) {
+ return it->second.count(arg.s());
+ }
+ }
+
+ return false;
+}
+
+ConvertedResult OnnxExporter::Caffe2OpToOnnxNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ std::string type = def.type();
+ const auto& renamed_op_lut = get_renamed_operators();
+ const auto it = renamed_op_lut.find(type);
+ if (it != renamed_op_lut.end()) {
+ type = it->second;
+ }
+ const auto& special_op_lut = get_special_operators();
+ const auto it_op = special_op_lut.find(type);
+ if (it_op != special_op_lut.end()) {
+ return (this->*(it_op->second))(def, shapes);
+ } else {
+ return CommonCaffe2OpToOnnxNodes(def);
+ }
+}
+
+ConvertedResult OnnxExporter::CommonCaffe2OpToOnnxNodes(
+ const caffe2::OperatorDef& def) {
+ ConvertedResult result;
+ auto& nodes = result.first;
+ nodes.emplace_back();
+ NodeProto& node = nodes.back();
+ node.set_name(def.name());
+ node.set_op_type(
+ caffe2::get_default(get_renamed_operators(), def.type(), def.type()));
+ for (const auto& i : def.input()) {
+ node.add_input(i);
+ }
+ for (const auto& o : def.output()) {
+ node.add_output(o);
+ }
+ for (const auto& a : def.arg()) {
+ if (!IsBlackListed(a)) {
+ auto* attr = node.add_attribute();
+ CopyCaffe2ArgToOnnxAttr(attr, def.type(), a);
+ }
+ }
+ return result;
+}
+
+ConvertedResult OnnxExporter::CreateConvPoolNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ auto result = CommonCaffe2OpToOnnxNodes(def);
+ auto& nodes = result.first;
+ auto& node = nodes.back();
+
+ std::unordered_map<std::string, AttributeProto> attrs;
+ for (const auto& attr : node.attribute()) {
+ attrs.emplace(attr.name(), attr);
+ }
+
+ // Handle global pooling
+ bool global = false;
+ if (node.op_type() == "MaxPool" || node.op_type() == "AveragePool") {
+ auto it = attrs.find("global_pooling");
+ if (it != attrs.end() && it->second.has_i() && it->second.i()) {
+ node.set_op_type("Global" + node.op_type());
+ global = true;
+ attrs.erase(it);
+ }
+ }
+
+ ApplyTrans(&attrs, global, "kernel", 2, "kernel_shape");
+ ApplyTrans(&attrs, global, "stride");
+ ApplyTrans(&attrs, global, "dilation");
+ ApplyTrans(&attrs, global, "adj");
+ ApplyTrans(&attrs, global, "pad", 4);
+
+ // Fix legacy pad attr
+ auto it = attrs.find("legacy_pad");
+ if (it != attrs.end()) {
+ auto legacy_pad_attr = it->second;
+ attrs.erase(it);
+ CAFFE_ENFORCE(
+ node.op_type().size() >= 4 &&
+ (node.op_type().rfind("Pool") == node.op_type().size() - 4));
+ CAFFE_ENFORCE(!global);
+ const auto& input_size = shapes.at(node.input(0));
+ const auto& output_size = shapes.at(node.output(0));
+ CAFFE_ENFORCE(output_size.dims().size() == 4);
+ if (legacy_pad_attr.i() ==
+ static_cast<int64_t>(caffe2::LegacyPadding::VALID)) {
+ CAFFE_ENFORCE(!attrs.count("pads"));
+ attrs.emplace("auto_pad", MakeAttribute("auto_pad", "VALID"));
+ } else if (
+ legacy_pad_attr.i() ==
+ static_cast<int64_t>(caffe2::LegacyPadding::SAME)) {
+ CAFFE_ENFORCE(!attrs.count("pads"));
+ // default behavior in Caffe2 is SAME_UPPER
+ // https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h#L39
+ attrs.emplace("auto_pad", MakeAttribute("auto_pad", "SAME_UPPER"));
+ } else if (
+ legacy_pad_attr.i() ==
+ static_cast<int64_t>(caffe2::LegacyPadding::CAFFE_LEGACY_POOLING)) {
+ // The problem here is that a Pool op in Caffe may add an additional
+ // pixel when the last window is smaller than the stride, so we replace
+ // legacy_pad with explicit padding:
+ //   pad[end] = output_size[start + 2] * stride[start] - pad[start] - 1
+ //              + kernel[start] - input[start + 2]
+ // where end = start + len(pad) / 2.
+ LOG(WARNING) << "Converting legacy padding to explicit padding.";
+ auto* pads_attr = attrs.at("pads").mutable_ints();
+ auto& strides_attr = attrs.at("strides").ints();
+ auto& kernel_shape_attr = attrs.at("kernel_shape").ints();
+ for (int i = 0; i < 2; ++i) {
+ int64_t tmp_pad = output_size.dims(i + 2) * strides_attr.Get(i) -
+ pads_attr->Get(i) - 1 + kernel_shape_attr.Get(i) -
+ input_size.dims(i + 2);
+ pads_attr->Set(i + 2, tmp_pad);
+ }
+ } else if (
+ legacy_pad_attr.i() !=
+ static_cast<int64_t>(caffe2::LegacyPadding::NOTSET)) {
+ CAFFE_THROW(caffe2::MakeString(
+ "Don't know how to handle legacy_pad while processing operator: ",
+ def.type()));
+ }
+ }
+
+ node.clear_attribute();
+ for (const auto& kv : attrs) {
+ auto* attr = node.add_attribute();
+ attr->CopyFrom(kv.second);
+ }
+
+ return result;
+}
+
+ConvertedResult OnnxExporter::CreateLrnNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ auto result = CommonCaffe2OpToOnnxNodes(def);
+ auto& nodes = result.first;
+
+ CAFFE_ENFORCE_EQ(nodes.size(), 1);
+ auto& node = nodes.back();
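+ // Drop the optional second output of Caffe2 LRN, which has no ONNX
+ // counterpart.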
+ if (node.output_size() == 2) {
+ node.mutable_output()->RemoveLast();
+ }
+
+ return result;
+}
+
+ConvertedResult OnnxExporter::CreateConcatNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ auto result = CommonCaffe2OpToOnnxNodes(def);
+ auto& nodes = result.first;
+
+ CAFFE_ENFORCE_EQ(nodes.size(), 1);
+ auto& node = nodes.back();
+ if (node.output_size() == 2) {
+ node.mutable_output()->RemoveLast();
+ }
+
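+ // Caffe2 Concat defaults to axis = 1; make the default explicit for ONNX.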
+ bool explicit_axis = false;
+ for (const auto& a: def.arg()) {
+ if (a.name() == "axis") {
+ explicit_axis = true;
+ break;
+ }
+ }
+ if (!explicit_axis) {
+ node.add_attribute()->CopyFrom(MakeAttribute("axis", 1L));
+ }
+
+ return result;
+}
+
+ConvertedResult OnnxExporter::CreateChannelShuffleNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ const auto& x = def.input(0);
+ const auto& y = def.output(0);
+ const auto& x_shape = shapes.at(x);
+ CAFFE_ENFORCE_EQ(
+ x_shape.dims().size(),
+ 4,
+ "Input shape of ChannelShuffle needs to be in NCHW format");
+ auto n = x_shape.dims(0);
+ auto c = x_shape.dims(1);
+ auto h = x_shape.dims(2);
+ auto w = x_shape.dims(3);
+ int64_t g = 0;
+ for (const auto& arg: def.arg()) {
+ if (arg.name() == "group") {
+ g = arg.i();
+ break;
+ }
+ }
+ CAFFE_ENFORCE(g && c % g == 0);
+ ConvertedResult result;
+ auto& nodes = result.first;
+ auto& const_tensors = result.second;
+
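+ // Express ChannelShuffle as Reshape (N, C, H, W) -> (N, g, C/g, H, W),
+ // then Transpose with perm (0, 2, 1, 3, 4), then Reshape back to
+ // (N, C, H, W).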
+ const auto reshape_output = DummyName::NewDummyName();
+ std::vector<int64_t> dims = {n, g, c / g, h, w};
+ const_tensors.emplace_back(CreateOnnxShapeTensor(dims));
+ nodes.emplace_back(
+ MakeNode("Reshape", {x, const_tensors.back().name()}, {reshape_output}));
+
+ const auto transpose_output = DummyName::NewDummyName();
+ dims = {0, 2, 1, 3, 4};
+ nodes.emplace_back(MakeNode(
+ "Transpose",
+ {reshape_output},
+ {transpose_output},
+ {MakeAttribute("perm", dims)}));
+
+ dims = {n, c, h, w};
+ const_tensors.emplace_back(CreateOnnxShapeTensor(dims));
+ nodes.emplace_back(MakeNode(
+ "Reshape", {transpose_output, const_tensors.back().name()}, {y}));
+
+ return result;
+}
+
+ConvertedResult OnnxExporter::CreateSliceNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ CAFFE_ENFORCE_EQ(
+ def.input_size(),
+ 1,
+ "ONNX Slice operator does not support dynamic slice.");
+ auto result = CommonCaffe2OpToOnnxNodes(def);
+ auto& nodes = result.first;
+ CAFFE_ENFORCE_EQ(nodes.size(), 1);
+ auto& node = nodes.back();
+ const auto& shape = shapes.at(node.input(0));
+
+ std::vector<int64_t> dims;
+ for (auto& attr: *node.mutable_attribute()) {
+ if (attr.name() == "starts") {
+ auto len = attr.ints_size();
+ if (len) {
+ dims.resize(len);
+ std::iota(dims.begin(), dims.end(), 0);
+ }
+ } else if (attr.name() == "ends") {
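+ // Caffe2 allows end == -1 to mean "through the last element", and other
+ // negative ends are inclusive; ONNX ends are exclusive, so rewrite them.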
+ for (int i = 0; i < attr.ints_size(); ++i) {
+ auto end = attr.ints(i);
+ if (end >= 0) {
+ continue;
+ }
+ if (end == -1) {
+ end = shape.dims(i);
+ } else {
+ ++end;
+ }
+ attr.set_ints(i, end);
+ }
+ }
+ }
+ if (!dims.empty()) {
+ node.add_attribute()->CopyFrom(MakeAttribute("axes", dims));
+ }
+
+ return result;
+}
+
+ConvertedResult OnnxExporter::CreateReshapeNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ auto result = CommonCaffe2OpToOnnxNodes(def);
+ auto& nodes = result.first;
+ auto& const_tensors = result.second;
+ CAFFE_ENFORCE_EQ(nodes.size(), 1);
+ auto& node = nodes.back();
+
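+ // ONNX Reshape takes the target shape as a second input rather than as
+ // an attribute, so move the Caffe2 "shape" argument into a constant
+ // tensor.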
+ int i = 0;
+ int attr_size = node.attribute_size();
+ for (; i < attr_size; ++i) {
+ const auto& attr = node.attribute(i);
+ if (attr.name() == "shape") {
+ std::vector<int64_t> shape;
+ for (const auto k: attr.ints()) {
+ shape.push_back(k);
+ }
+ const_tensors.emplace_back(CreateOnnxShapeTensor(shape));
+ node.add_input(const_tensors.back().name());
+ break;
+ }
+ }
+ if (i != attr_size) {
+ if (i != attr_size - 1) {
+ node.mutable_attribute()->SwapElements(i, attr_size - 1);
+ }
+ node.mutable_attribute()->RemoveLast();
+ }
+
+ if (node.output_size() == 2) {
+ node.mutable_output()->RemoveLast();
+ }
+
+ return result;
+}
+
+ConvertedResult OnnxExporter::CreateGemmNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
+ CAFFE_ENFORCE_EQ(def.input_size(), 3);
+ CAFFE_ENFORCE_GE(def.output_size(), 1);
+ auto x = def.input(0);
+ auto w = def.input(1);
+ const auto& b = def.input(2);
+ const auto& y = def.output(0);
+ const auto& x_shape = shapes.at(x);
+
+ ConvertedResult result;
+ auto& nodes = result.first;
+ auto& const_tensors = result.second;
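+ // Caffe2 FC maps to ONNX Gemm with transB = 1. When axis / axis_w are
+ // given, X and W are first flattened to 2-D around that axis, and the
+ // Gemm output is reshaped back afterwards.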
+ std::unordered_map<std::string, const caffe2::Argument*> args;
+ for (const auto& a : def.arg()) {
+ args.emplace(a.name(), &a);
+ }
+
+ auto it = args.find("axis");
+ bool has_axis = (it != args.end());
+ int64_t axis = 0;
+ if (has_axis) {
+ axis = it->second->i();
+ auto outer = DimProd(x_shape, 0, axis);
+ auto inner = DimProd(x_shape, axis, x_shape.dims().size());
+ std::vector<int64_t> dims = {outer, inner};
+ auto reshaped_x = DummyName::NewDummyName();
+ const_tensors.emplace_back(CreateOnnxShapeTensor(dims));
+ nodes.emplace_back(
+ MakeNode("Reshape", {x, const_tensors.back().name()}, {reshaped_x}));
+ x = reshaped_x;
+ }
+
+ it = args.find("axis_w");
+ if (it != args.end()) {
+ auto axis_w = it->second->i();
+ const auto& w_shape = shapes.at(w);
+ auto outer = DimProd(w_shape, 0, axis_w);
+ auto inner = DimProd(w_shape, axis_w, w_shape.dims().size());
+ std::vector<int64_t> dims = {outer, inner};
+ auto reshaped_w = DummyName::NewDummyName();
+ const_tensors.emplace_back(CreateOnnxShapeTensor(dims));
+ nodes.emplace_back(
+ MakeNode("Reshape", {w, const_tensors.back().name()}, {reshaped_w}));
+ w = reshaped_w;
+ }
+
+ auto gemm_y_output = (has_axis) ? DummyName::NewDummyName() : y;
+ nodes.emplace_back(MakeNode(
+ "Gemm",
+ {x, w, b},
+ {gemm_y_output},
+ {MakeAttribute("transB", 1L), MakeAttribute("broadcast", 1)},
+ def.name()));
+
+ if (has_axis) {
+ std::vector<int64_t> dims;
+ for (int i = 0; i < axis; ++i) {
+ dims.push_back(x_shape.dims(i));
+ }
+ dims.push_back(-1);
+ const_tensors.emplace_back(CreateOnnxShapeTensor(dims));
+ nodes.emplace_back(
+ MakeNode("Reshape", {gemm_y_output, const_tensors.back().name()}, {y}));
+ }
+
+ return result;
+}
+} // namespace onnx
+} // namespace caffe2
+
diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h
new file mode 100644
index 0000000000..4eb3f3ff3c
--- /dev/null
+++ b/caffe2/onnx/onnx_exporter.h
@@ -0,0 +1,102 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "caffe2/core/common.h"
+#include "caffe2/proto/caffe2.pb.h"
+#include "onnx/onnx_pb.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace caffe2 {
+namespace onnx {
+
+namespace {
+using ::ONNX_NAMESPACE::AttributeProto;
+using ::ONNX_NAMESPACE::GraphProto;
+using ::ONNX_NAMESPACE::ModelProto;
+using ::ONNX_NAMESPACE::NodeProto;
+using ::ONNX_NAMESPACE::TensorProto;
+using ConvertedResult =
+ std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
+} // namespace
+
+class OnnxExporter {
+ using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
+ const caffe2::OperatorDef&,
+ const std::unordered_map<std::string, caffe2::TensorShape>&);
+
+ public:
+ ConvertedResult Caffe2OpToOnnxNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ private:
+ ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
+
+ ConvertedResult CreateConvPoolNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ ConvertedResult CreateGemmNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ ConvertedResult CreateReshapeNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ ConvertedResult CreateSliceNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ ConvertedResult CreateChannelShuffleNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ ConvertedResult CreateConcatNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ ConvertedResult CreateLrnNodes(
+ const caffe2::OperatorDef& def,
+ const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
+
+ // \brief Check blacklisted arguments that we won't pass down when
+ // converting to an ONNX node
+ bool IsBlackListed(const caffe2::Argument& arg);
+
+ // \brief Convert Caffe2 argument to Onnx attribute
+ void CopyCaffe2ArgToOnnxAttr(
+ AttributeProto* attr,
+ const std::string& op_type,
+ const caffe2::Argument& arg);
+
+ // LUT getters
+ const std::unordered_map<std::string, std::string>& get_renamed_operators()
+ const;
+ const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
+ const std::
+ unordered_map<std::string, std::unordered_map<std::string, std::string>>&
+ get_per_op_renamed_attrs() const;
+ const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
+ get_special_operators() const;
+};
+} // namespace onnx
+} // namespace caffe2
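
A minimal sketch of driving the exporter (illustrative only: `net` is
assumed to be a caffe2::NetDef, and the shape map would normally be filled
from Caffe2 shape inference):

    caffe2::onnx::OnnxExporter exporter;
    std::unordered_map<std::string, caffe2::TensorShape> shapes;  // per-blob shapes
    for (const auto& op : net.op()) {
      auto result = exporter.Caffe2OpToOnnxNodes(op, shapes);
      // result.first holds the converted ONNX NodeProtos; result.second holds
      // constant TensorProtos (e.g. shape inputs for Reshape) to be added to
      // the graph as initializers.
    }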