summaryrefslogtreecommitdiff
path: root/caffe2/opt
diff options
context:
space:
mode:
authorKimish Patel <kimishpatel@fb.com>2019-02-11 14:32:30 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-11 14:55:31 -0800
commit4292d13240e23a4a343b4ccb153214ab11c8d255 (patch)
treed7a32939aec67de42ae4d2699fe0ce38f17e182f /caffe2/opt
parent917eac91f4d8040eb51480fdfe1f75e99c33ac1a (diff)
downloadpytorch-4292d13240e23a4a343b4ccb153214ab11c8d255.tar.gz
pytorch-4292d13240e23a4a343b4ccb153214ab11c8d255.tar.bz2
pytorch-4292d13240e23a4a343b4ccb153214ab11c8d255.zip
Keep weights name unchanged during SsaRewrite (#16932)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16932 During onnxifi transformation net ssa is rewritten. At the last step the weight names are changed back to what they were before. The diff keeps the weight names unchanged thru the process. Reviewed By: yinghai Differential Revision: D13972597 fbshipit-source-id: 7c29857f788a674edf625c073b345f2b44267b33
Diffstat (limited to 'caffe2/opt')
-rw-r--r--caffe2/opt/onnxifi_transformer.cc36
-rw-r--r--caffe2/opt/onnxifi_transformer.h2
2 files changed, 20 insertions, 18 deletions
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 607d6f07a1..99f9b39924 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -715,11 +715,23 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames(
Workspace* ws,
NetDef* pred_net,
+ const std::unordered_set<std::string>& weights,
const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
- input_mapping_ = onnx::SsaRewrite(nullptr, pred_net);
+ // Make sure weights do not contain output of any op.
+ for (const auto& op : pred_net->op()) {
+ for (const auto& output : op.output()) {
+ CAFFE_ENFORCE_EQ(weights.count(output), 0);
+ }
+ }
+ input_mapping_ = onnx::SsaRewrite(nullptr, pred_net, weights);
// Annote the ops with net position
AnnotateOpIndex(pred_net);
std::vector<std::string> external_inputs;
+ // Need to add mapping for weights. This will be used to create new workspace
+ // with mapped weights.
+ for (const auto& w : weights) {
+ input_mapping_.emplace(w, w);
+ }
for (const auto kv : input_mapping_) {
reverse_input_mapping_.emplace(kv.second, kv.first);
if (!ws->HasBlob(kv.second)) {
@@ -966,6 +978,7 @@ void OnnxifiTransformer::Transform(
Workspace* ws,
NetDef* pred_net,
const std::vector<std::string>& external_inputs,
+ const std::vector<std::string>& weight_names,
const std::unordered_map<std::string, TensorShape>& input_shape_hints,
const std::unordered_set<int>& blacklisted_ops) {
CAFFE_ENFORCE(ws);
@@ -975,9 +988,12 @@ void OnnxifiTransformer::Transform(
model_id_ = GetModelId(*pred_net);
onnxifi_op_id_ = 0;
+ std::unordered_set<std::string> weights(
+ weight_names.begin(), weight_names.end());
+
// SSA Rewrite the net
auto shape_hints_ordered =
- SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
+ SsaRewriteAndMapNames(ws, pred_net, weights, input_shape_hints);
// Populate shape info
Workspace mapped_ws(ws, input_mapping_);
@@ -988,22 +1004,6 @@ void OnnxifiTransformer::Transform(
opts_.infer_shapes,
opts_.bound_shape_spec);
- // Figure out what are the weights
- std::unordered_set<std::string> weights;
- std::unordered_set<std::string> input_set;
- for (const auto& i : external_inputs) {
- const auto it = reverse_input_mapping_.find(i);
- if (it != reverse_input_mapping_.end()) {
- input_set.emplace(it->second);
- }
- }
- const std::vector<string>& ws_blobs = mapped_ws.Blobs();
- for (const auto& s : ws_blobs) {
- if (!input_set.count(s)) {
- weights.emplace(s);
- }
- }
-
// Transform the net
NetDef net_opt = opts_.use_onnx
? TransformViaOnnx(ws, pred_net, weights, blacklisted_ops, &shape_hints)
diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h
index a7ba90a0f9..a1a8cd8eea 100644
--- a/caffe2/opt/onnxifi_transformer.h
+++ b/caffe2/opt/onnxifi_transformer.h
@@ -42,6 +42,7 @@ class CAFFE2_API OnnxifiTransformer final {
Workspace* ws,
NetDef* pred_net,
const std::vector<std::string>& external_inputs,
+ const std::vector<std::string>& weight_names,
const std::unordered_map<std::string, TensorShape>& shape_hints,
const std::unordered_set<int>& blacklisted_ops);
@@ -85,6 +86,7 @@ class CAFFE2_API OnnxifiTransformer final {
CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames(
Workspace* ws,
NetDef* pred_net,
+ const std::unordered_set<std::string>& weights,
const std::unordered_map<std::string, TensorShape>& input_shape_hints);
// Transform by passing C2 proto to backend