diff options
author | Kimish Patel <kimishpatel@fb.com> | 2019-02-11 14:32:30 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-11 14:55:31 -0800 |
commit | 4292d13240e23a4a343b4ccb153214ab11c8d255 (patch) | |
tree | d7a32939aec67de42ae4d2699fe0ce38f17e182f /caffe2/opt | |
parent | 917eac91f4d8040eb51480fdfe1f75e99c33ac1a (diff) | |
download | pytorch-4292d13240e23a4a343b4ccb153214ab11c8d255.tar.gz pytorch-4292d13240e23a4a343b4ccb153214ab11c8d255.tar.bz2 pytorch-4292d13240e23a4a343b4ccb153214ab11c8d255.zip |
Keep weights name unchanged during SsaRewrite (#16932)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16932
During onnxifi transformation net ssa is rewritten. At the last step the weight
names are changed back to what they were before. The diff keeps the weight
names unchanged thru the process.
Reviewed By: yinghai
Differential Revision: D13972597
fbshipit-source-id: 7c29857f788a674edf625c073b345f2b44267b33
Diffstat (limited to 'caffe2/opt')
-rw-r--r-- | caffe2/opt/onnxifi_transformer.cc | 36 | ||||
-rw-r--r-- | caffe2/opt/onnxifi_transformer.h | 2 |
2 files changed, 20 insertions, 18 deletions
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 607d6f07a1..99f9b39924 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -715,11 +715,23 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx( CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames( Workspace* ws, NetDef* pred_net, + const std::unordered_set<std::string>& weights, const std::unordered_map<std::string, TensorShape>& input_shape_hints) { - input_mapping_ = onnx::SsaRewrite(nullptr, pred_net); + // Make sure weights do not contain output of any op. + for (const auto& op : pred_net->op()) { + for (const auto& output : op.output()) { + CAFFE_ENFORCE_EQ(weights.count(output), 0); + } + } + input_mapping_ = onnx::SsaRewrite(nullptr, pred_net, weights); // Annote the ops with net position AnnotateOpIndex(pred_net); std::vector<std::string> external_inputs; + // Need to add mapping for weights. This will be used to create new workspace + // with mapped weights. + for (const auto& w : weights) { + input_mapping_.emplace(w, w); + } for (const auto kv : input_mapping_) { reverse_input_mapping_.emplace(kv.second, kv.first); if (!ws->HasBlob(kv.second)) { @@ -966,6 +978,7 @@ void OnnxifiTransformer::Transform( Workspace* ws, NetDef* pred_net, const std::vector<std::string>& external_inputs, + const std::vector<std::string>& weight_names, const std::unordered_map<std::string, TensorShape>& input_shape_hints, const std::unordered_set<int>& blacklisted_ops) { CAFFE_ENFORCE(ws); @@ -975,9 +988,12 @@ void OnnxifiTransformer::Transform( model_id_ = GetModelId(*pred_net); onnxifi_op_id_ = 0; + std::unordered_set<std::string> weights( + weight_names.begin(), weight_names.end()); + // SSA Rewrite the net auto shape_hints_ordered = - SsaRewriteAndMapNames(ws, pred_net, input_shape_hints); + SsaRewriteAndMapNames(ws, pred_net, weights, input_shape_hints); // Populate shape info Workspace mapped_ws(ws, input_mapping_); @@ -988,22 +1004,6 @@ void OnnxifiTransformer::Transform( opts_.infer_shapes, opts_.bound_shape_spec); - // Figure out what are the weights - std::unordered_set<std::string> weights; - std::unordered_set<std::string> input_set; - for (const auto& i : external_inputs) { - const auto it = reverse_input_mapping_.find(i); - if (it != reverse_input_mapping_.end()) { - input_set.emplace(it->second); - } - } - const std::vector<string>& ws_blobs = mapped_ws.Blobs(); - for (const auto& s : ws_blobs) { - if (!input_set.count(s)) { - weights.emplace(s); - } - } - // Transform the net NetDef net_opt = opts_.use_onnx ? TransformViaOnnx(ws, pred_net, weights, blacklisted_ops, &shape_hints) diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index a7ba90a0f9..a1a8cd8eea 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -42,6 +42,7 @@ class CAFFE2_API OnnxifiTransformer final { Workspace* ws, NetDef* pred_net, const std::vector<std::string>& external_inputs, + const std::vector<std::string>& weight_names, const std::unordered_map<std::string, TensorShape>& shape_hints, const std::unordered_set<int>& blacklisted_ops); @@ -85,6 +86,7 @@ class CAFFE2_API OnnxifiTransformer final { CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames( Workspace* ws, NetDef* pred_net, + const std::unordered_set<std::string>& weights, const std::unordered_map<std::string, TensorShape>& input_shape_hints); // Transform by passing C2 proto to backend |