summaryrefslogtreecommitdiff
path: root/caffe2/onnx
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2018-04-17 21:23:27 -0700
committerGitHub <noreply@github.com>2018-04-17 21:23:27 -0700
commit6252706feb23e448bff944c3505092b18718d3ab (patch)
tree199e070bbbfc0a84287d3f4eb990dace343a8c48 /caffe2/onnx
parentdc94182db0f320f09894a99eb169fe83631e8440 (diff)
downloadpytorch-6252706feb23e448bff944c3505092b18718d3ab.tar.gz
pytorch-6252706feb23e448bff944c3505092b18718d3ab.tar.bz2
pytorch-6252706feb23e448bff944c3505092b18718d3ab.zip
[Caffe2] Workspace centric API for TensorRT transformation (#6678)
* Workspace centric API for trt transformation * Merge SSA rewrite code
Diffstat (limited to 'caffe2/onnx')
-rw-r--r--caffe2/onnx/backend.cc2
-rw-r--r--caffe2/onnx/onnx_exporter.cc33
-rw-r--r--caffe2/onnx/onnx_exporter.h9
-rw-r--r--caffe2/onnx/ssa_test.cc6
4 files changed, 33 insertions, 17 deletions
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 561c13d081..f835bece40 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -405,7 +405,7 @@ Caffe2Ops Caffe2Backend::CreateCast(OnnxNode* onnx_node, int opset_version) {
"' dtype is not supported");
CAFFE_ENFORCE_EQ(
- c2_op.ops[0].arg().size(),
+ c2_op.ops.Get(0).arg().size(),
1,
"Unexpected number of attributes in 'Cast'");
c2_op.ops.Mutable(0)->mutable_arg(0)->set_i(c2_dtype);
diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc
index 58f4dafe69..e1c1cc98c7 100644
--- a/caffe2/onnx/onnx_exporter.cc
+++ b/caffe2/onnx/onnx_exporter.cc
@@ -90,12 +90,10 @@ std::string SsaName(const std::string& n, int version) {
}
} // namespace
-std::pair<
- std::unordered_map<std::string, std::string>,
- std::unordered_map<std::string, std::string>>
-SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) {
+std::unordered_map<std::string, std::string> SsaRewrite(
+ caffe2::NetDef* init_net,
+ caffe2::NetDef* pred_net) {
std::unordered_map<std::string, std::string> input_mapping;
- std::unordered_map<std::string, std::string> output_mapping;
std::unordered_map<std::string, int> blob_versions;
#define REWRITE_EXTERNAL_IO(net, name) \
@@ -121,7 +119,6 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) {
blob_versions.emplace(output, 0);
}
REWRITE_EXTERNAL_IO(init_net, input);
- REWRITE_EXTERNAL_IO(init_net, output);
blob_versions.clear();
}
@@ -151,11 +148,31 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) {
}
}
}
- REWRITE_EXTERNAL_IO(pred_net, output);
+
+ // Fix the external output name back to original
+ std::unordered_set<std::string> external_outputs;
+ for (const auto& output : pred_net->external_output()) {
+ external_outputs.emplace(output);
+ }
+ for (auto& op : *pred_net->mutable_op()) {
+ for (auto& output : *op.mutable_output()) {
+ auto pos = output.find_last_of('_');
+ CAFFE_ENFORCE_NE(pos, 0);
+ auto basename = output.substr(0, pos);
+ if (!external_outputs.count(basename)) {
+ continue;
+ }
+ auto it = blob_versions.find(basename);
+ if (it != blob_versions.end() &&
+ SsaName(basename, it->second) == output) {
+ output = basename;
+ }
+ }
+ }
}
#undef REWRITE_EXTERNAL_IO
- return std::make_pair(std::move(input_mapping), std::move(output_mapping));
+ return input_mapping;
}
const std::unordered_map<std::string, std::string>&
diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h
index 8c997c065a..dd11006915 100644
--- a/caffe2/onnx/onnx_exporter.h
+++ b/caffe2/onnx/onnx_exporter.h
@@ -23,10 +23,11 @@ using ::ONNX_NAMESPACE::TensorProto;
using ConvertedResult =
std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
-std::pair<
- std::unordered_map<std::string, std::string>,
- std::unordered_map<std::string, std::string>>
-SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net);
+// Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external
+// output names for predict net.
+std::unordered_map<std::string, std::string> SsaRewrite(
+ caffe2::NetDef* init_net,
+ caffe2::NetDef* pred_net);
class OnnxExporter {
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
diff --git a/caffe2/onnx/ssa_test.cc b/caffe2/onnx/ssa_test.cc
index db3ba451b8..9f83b8f486 100644
--- a/caffe2/onnx/ssa_test.cc
+++ b/caffe2/onnx/ssa_test.cc
@@ -22,9 +22,7 @@ TEST(SsaTest, ConvReluInplace) {
net.add_external_input("X");
net.add_external_output("Y");
- std::unordered_map<std::string, std::string> input_mapping;
- std::unordered_map<std::string, std::string> output_mapping;
- std::tie(input_mapping, output_mapping) =
+ std::unordered_map<std::string, std::string> input_mapping =
caffe2::onnx::SsaRewrite(nullptr, &net);
for (const auto& op : net.op()) {
std::unordered_set<std::string> inputs;
@@ -37,5 +35,5 @@ TEST(SsaTest, ConvReluInplace) {
}
EXPECT_EQ(net.op(0).output(0), net.op(1).input(0));
EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
- EXPECT_EQ("Y", output_mapping.at(net.external_output(0)));
+ EXPECT_EQ("Y", net.external_output(0));
}