summaryrefslogtreecommitdiff
path: root/caffe2/opt
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2019-02-01 18:45:44 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-01 18:51:51 -0800
commit98b333d810a5e4863640033e4a37ae9aa9c22418 (patch)
tree8cb5cd5d0e5cbad8dc96e771975b2e17eb0cb4cd /caffe2/opt
parenta4ac3cbb2f62aae6cfb2ca11fa67a5f2ea52bdf8 (diff)
downloadpytorch-98b333d810a5e4863640033e4a37ae9aa9c22418.tar.gz
pytorch-98b333d810a5e4863640033e4a37ae9aa9c22418.tar.bz2
pytorch-98b333d810a5e4863640033e4a37ae9aa9c22418.zip
Tag model_id and onnxifi index in OnnxifiOp (#16648)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16648 We added onnxGraph sharing keyed on model id and net seq number but we forgot to supply these info to the Onnxifi. Therefore, we will only create ONE onnxGraph whatsoever... This diff adds necessary info to the OnnxifiOp to prevent this from happening. Reviewed By: bertmaher, rdzhabarov Differential Revision: D13912356 fbshipit-source-id: fe8982327287a35f32fe3b125d94b617d18c0ab5
Diffstat (limited to 'caffe2/opt')
-rw-r--r--caffe2/opt/onnxifi_transformer.cc28
-rw-r--r--caffe2/opt/onnxifi_transformer.h6
2 files changed, 27 insertions, 7 deletions
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 3aeafcd19b..f1820c64fc 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -18,20 +18,26 @@ namespace caffe2 {
namespace {
const std::string kNetPos("net_pos");
-const std::string kNetId("model_id");
+const std::string kModelId("model_id");
constexpr size_t kBufferSize = 64;
-// TODO: We probably don't want use protobuf as annotation in the future.
-void AnnotateOpIndex(NetDef* net) {
+void AnnotateOpIndex(NetDef* net, const std::string& net_id) {
int i = 0;
- auto net_id =
- ArgumentHelper(*net).GetSingleArgument<std::string>("model_id", "");
for (auto& op : *(net->mutable_op())) {
AddArgument(kNetPos, i++, &op);
- AddArgument(kNetId, net_id, &op);
}
}
+std::string GetModelId(const NetDef& net) {
+ static std::atomic<size_t> seq_id{0};
+ auto model_id =
+ ArgumentHelper(net).GetSingleArgument<std::string>("model_id", "");
+ if (model_id.empty()) {
+ model_id = "unnamed_" + c10::to_string(seq_id++);
+ }
+ return model_id;
+}
+
// Wrap TensorShape into TensorProto
TensorProto WrapShapeIntoTensorProto(
const std::string& name,
@@ -278,6 +284,10 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
// Tell Onnxifi op which backend id to use
AddArgument("backend_id", idx_, &op);
+ // Add model_id and net_pos to the onnxifi model
+ AddArgument(kModelId, model_id_, &op);
+ AddArgument(kNetPos, c10::to_string(onnxifi_op_id_++), &op);
+
return op;
}
@@ -482,7 +492,7 @@ CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames(
const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
input_mapping_ = onnx::SsaRewrite(nullptr, pred_net);
// Annote the ops with net position
- AnnotateOpIndex(pred_net);
+ AnnotateOpIndex(pred_net, model_id_);
std::vector<std::string> external_inputs;
for (const auto kv : input_mapping_) {
reverse_input_mapping_.emplace(kv.second, kv.first);
@@ -723,6 +733,10 @@ void OnnxifiTransformer::Transform(
CAFFE_ENFORCE(ws);
CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
+ // Get model id and reset Onnxifi op id to 0
+ model_id_ = GetModelId(*pred_net);
+ onnxifi_op_id_ = 0;
+
// SSA Rewrite the net
auto shape_hints_ordered =
SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h
index dc2b9e9e81..4f08328710 100644
--- a/caffe2/opt/onnxifi_transformer.h
+++ b/caffe2/opt/onnxifi_transformer.h
@@ -103,6 +103,12 @@ class CAFFE2_API OnnxifiTransformer final {
// backend idx
int idx_{0};
+ // Number of Onnxifi Ops we build so far
+ int onnxifi_op_id_{0};
+
+ // Model id
+ std::string model_id_;
+
// Backned IDs
std::vector<onnxBackendID> backend_ids_;