diff options
author | Yinghai Lu <yinghai@fb.com> | 2018-04-17 21:23:27 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-17 21:23:27 -0700 |
commit | 6252706feb23e448bff944c3505092b18718d3ab (patch) | |
tree | 199e070bbbfc0a84287d3f4eb990dace343a8c48 /caffe2/onnx | |
parent | dc94182db0f320f09894a99eb169fe83631e8440 (diff) | |
download | pytorch-6252706feb23e448bff944c3505092b18718d3ab.tar.gz pytorch-6252706feb23e448bff944c3505092b18718d3ab.tar.bz2 pytorch-6252706feb23e448bff944c3505092b18718d3ab.zip |
[Caffe2] Workspace centric API for TensorRT transformation (#6678)
* Workspace centric API for trt transformation
* Merge SSA rewrite code
Diffstat (limited to 'caffe2/onnx')
-rw-r--r-- | caffe2/onnx/backend.cc | 2 | ||||
-rw-r--r-- | caffe2/onnx/onnx_exporter.cc | 33 | ||||
-rw-r--r-- | caffe2/onnx/onnx_exporter.h | 9 | ||||
-rw-r--r-- | caffe2/onnx/ssa_test.cc | 6 |
4 files changed, 33 insertions, 17 deletions
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc index 561c13d081..f835bece40 100644 --- a/caffe2/onnx/backend.cc +++ b/caffe2/onnx/backend.cc @@ -405,7 +405,7 @@ Caffe2Ops Caffe2Backend::CreateCast(OnnxNode* onnx_node, int opset_version) { "' dtype is not supported"); CAFFE_ENFORCE_EQ( - c2_op.ops[0].arg().size(), + c2_op.ops.Get(0).arg().size(), 1, "Unexpected number of attributes in 'Cast'"); c2_op.ops.Mutable(0)->mutable_arg(0)->set_i(c2_dtype); diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc index 58f4dafe69..e1c1cc98c7 100644 --- a/caffe2/onnx/onnx_exporter.cc +++ b/caffe2/onnx/onnx_exporter.cc @@ -90,12 +90,10 @@ std::string SsaName(const std::string& n, int version) { } } // namespace -std::pair< - std::unordered_map<std::string, std::string>, - std::unordered_map<std::string, std::string>> -SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) { +std::unordered_map<std::string, std::string> SsaRewrite( + caffe2::NetDef* init_net, + caffe2::NetDef* pred_net) { std::unordered_map<std::string, std::string> input_mapping; - std::unordered_map<std::string, std::string> output_mapping; std::unordered_map<std::string, int> blob_versions; #define REWRITE_EXTERNAL_IO(net, name) \ @@ -121,7 +119,6 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) { blob_versions.emplace(output, 0); } REWRITE_EXTERNAL_IO(init_net, input); - REWRITE_EXTERNAL_IO(init_net, output); blob_versions.clear(); } @@ -151,11 +148,31 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) { } } } - REWRITE_EXTERNAL_IO(pred_net, output); + + // Fix the external output name back to original + std::unordered_set<std::string> external_outputs; + for (const auto& output : pred_net->external_output()) { + external_outputs.emplace(output); + } + for (auto& op : *pred_net->mutable_op()) { + for (auto& output : *op.mutable_output()) { + auto pos = output.find_last_of('_'); + CAFFE_ENFORCE_NE(pos, 0); + auto basename = output.substr(0, pos); + if (!external_outputs.count(basename)) { + continue; + } + auto it = blob_versions.find(basename); + if (it != blob_versions.end() && + SsaName(basename, it->second) == output) { + output = basename; + } + } + } } #undef REWRITE_EXTERNAL_IO - return std::make_pair(std::move(input_mapping), std::move(output_mapping)); + return input_mapping; } const std::unordered_map<std::string, std::string>& diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h index 8c997c065a..dd11006915 100644 --- a/caffe2/onnx/onnx_exporter.h +++ b/caffe2/onnx/onnx_exporter.h @@ -23,10 +23,11 @@ using ::ONNX_NAMESPACE::TensorProto; using ConvertedResult = std::pair<std::vector<NodeProto>, std::vector<TensorProto>>; -std::pair< - std::unordered_map<std::string, std::string>, - std::unordered_map<std::string, std::string>> -SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net); +// Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external +// output names for predict net. +std::unordered_map<std::string, std::string> SsaRewrite( + caffe2::NetDef* init_net, + caffe2::NetDef* pred_net); class OnnxExporter { using SpecialOpConverter = ConvertedResult (OnnxExporter::*)( diff --git a/caffe2/onnx/ssa_test.cc b/caffe2/onnx/ssa_test.cc index db3ba451b8..9f83b8f486 100644 --- a/caffe2/onnx/ssa_test.cc +++ b/caffe2/onnx/ssa_test.cc @@ -22,9 +22,7 @@ TEST(SsaTest, ConvReluInplace) { net.add_external_input("X"); net.add_external_output("Y"); - std::unordered_map<std::string, std::string> input_mapping; - std::unordered_map<std::string, std::string> output_mapping; - std::tie(input_mapping, output_mapping) = + std::unordered_map<std::string, std::string> input_mapping = caffe2::onnx::SsaRewrite(nullptr, &net); for (const auto& op : net.op()) { std::unordered_set<std::string> inputs; @@ -37,5 +35,5 @@ TEST(SsaTest, ConvReluInplace) { } EXPECT_EQ(net.op(0).output(0), net.op(1).input(0)); EXPECT_EQ("X", input_mapping.at(net.external_input(0))); - EXPECT_EQ("Y", output_mapping.at(net.external_output(0))); + EXPECT_EQ("Y", net.external_output(0)); } |