diff options
author | Yinghai Lu <yinghai@fb.com> | 2019-02-01 18:45:44 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-01 18:51:51 -0800 |
commit | 98b333d810a5e4863640033e4a37ae9aa9c22418 (patch) | |
tree | 8cb5cd5d0e5cbad8dc96e771975b2e17eb0cb4cd /caffe2/opt | |
parent | a4ac3cbb2f62aae6cfb2ca11fa67a5f2ea52bdf8 (diff) | |
download | pytorch-98b333d810a5e4863640033e4a37ae9aa9c22418.tar.gz pytorch-98b333d810a5e4863640033e4a37ae9aa9c22418.tar.bz2 pytorch-98b333d810a5e4863640033e4a37ae9aa9c22418.zip |
Tag model_id and onnxifi index in OnnxifiOp (#16648)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16648
We added onnxGraph sharing keyed on model id and net seq number but we forgot to supply these info to the Onnxifi. Therefore, we will only create ONE onnxGraph whatsoever... This diff adds necessary info to the OnnxifiOp to prevent this from happening.
Reviewed By: bertmaher, rdzhabarov
Differential Revision: D13912356
fbshipit-source-id: fe8982327287a35f32fe3b125d94b617d18c0ab5
Diffstat (limited to 'caffe2/opt')
-rw-r--r-- | caffe2/opt/onnxifi_transformer.cc | 28 | ||||
-rw-r--r-- | caffe2/opt/onnxifi_transformer.h | 6 |
2 files changed, 27 insertions, 7 deletions
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 3aeafcd19b..f1820c64fc 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -18,20 +18,26 @@ namespace caffe2 { namespace { const std::string kNetPos("net_pos"); -const std::string kNetId("model_id"); +const std::string kModelId("model_id"); constexpr size_t kBufferSize = 64; -// TODO: We probably don't want use protobuf as annotation in the future. -void AnnotateOpIndex(NetDef* net) { +void AnnotateOpIndex(NetDef* net, const std::string& net_id) { int i = 0; - auto net_id = - ArgumentHelper(*net).GetSingleArgument<std::string>("model_id", ""); for (auto& op : *(net->mutable_op())) { AddArgument(kNetPos, i++, &op); - AddArgument(kNetId, net_id, &op); } } +std::string GetModelId(const NetDef& net) { + static std::atomic<size_t> seq_id{0}; + auto model_id = + ArgumentHelper(net).GetSingleArgument<std::string>("model_id", ""); + if (model_id.empty()) { + model_id = "unnamed_" + c10::to_string(seq_id++); + } + return model_id; +} + // Wrap TensorShape into TensorProto TensorProto WrapShapeIntoTensorProto( const std::string& name, @@ -278,6 +284,10 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp( // Tell Onnxifi op which backend id to use AddArgument("backend_id", idx_, &op); + // Add model_id and net_pos to the onnxifi model + AddArgument(kModelId, model_id_, &op); + AddArgument(kNetPos, c10::to_string(onnxifi_op_id_++), &op); + return op; } @@ -482,7 +492,7 @@ CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames( const std::unordered_map<std::string, TensorShape>& input_shape_hints) { input_mapping_ = onnx::SsaRewrite(nullptr, pred_net); // Annote the ops with net position - AnnotateOpIndex(pred_net); + AnnotateOpIndex(pred_net, model_id_); std::vector<std::string> external_inputs; for (const auto kv : input_mapping_) { reverse_input_mapping_.emplace(kv.second, kv.first); @@ -723,6 +733,10 @@ void OnnxifiTransformer::Transform( CAFFE_ENFORCE(ws); CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr"); + // Get model id and reset Onnxifi op id to 0 + model_id_ = GetModelId(*pred_net); + onnxifi_op_id_ = 0; + // SSA Rewrite the net auto shape_hints_ordered = SsaRewriteAndMapNames(ws, pred_net, input_shape_hints); diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index dc2b9e9e81..4f08328710 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -103,6 +103,12 @@ class CAFFE2_API OnnxifiTransformer final { // backend idx int idx_{0}; + // Number of Onnxifi Ops we build so far + int onnxifi_op_id_{0}; + + // Model id + std::string model_id_; + // Backned IDs std::vector<onnxBackendID> backend_ids_; |