summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2019-02-14 14:22:51 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-14 15:12:20 -0800
commitb515ebc6f11400d904339275aa5df12f1f8121b1 (patch)
tree9696946487cdac7729b5359de2d0013173f6892f /caffe2
parent0a5de6e9720eb9600c6424fd2d7a5df0b36d9703 (diff)
downloadpytorch-b515ebc6f11400d904339275aa5df12f1f8121b1.tar.gz
pytorch-b515ebc6f11400d904339275aa5df12f1f8121b1.tar.bz2
pytorch-b515ebc6f11400d904339275aa5df12f1f8121b1.zip
Remove fake inference for shape info in ONNXIFI transform (#17046)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17046 As we are moving to use bound shape inference, we can remove the awkward fake inference run path and make the code cleaner. Reviewed By: ipiszy Differential Revision: D14061501 fbshipit-source-id: b3ace98b3dabef3c3359086a0bb1410518cefa26
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/opt/onnxifi_transformer.cc96
-rw-r--r--caffe2/opt/onnxifi_transformer.h5
-rw-r--r--caffe2/python/onnx/onnxifi.py17
-rw-r--r--caffe2/python/pybind_state.cc4
4 files changed, 39 insertions, 83 deletions
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 0fdbeb5b86..94e85a0c76 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -85,57 +85,36 @@ uint64_t OnnxifiDataType(caffe2::TensorProto::DataType t) {
#undef CAFFE2_TO_ONNXIFI_TYPE
}
-// TODO: Use ShapeInfo instead of shape
ShapeInfoMap InferShapes(
Workspace* ws,
NetDef* pred_net,
- CaffeMap<std::string, TensorShape>* shape_hints_ordered,
- bool infer_shapes,
+ std::unordered_map<std::string, TensorShape>* shape_hints_mapped,
const BoundShapeSpec& spec) {
ShapeInfoMap shape_map;
- if (infer_shapes) {
- // Populate shapes from workplace
- const std::vector<std::string> ws_blobs = ws->Blobs();
- for (const auto& s : ws_blobs) {
- auto shape_info = getShapeInfoFromBlob(ws->GetBlob(s));
- if (shape_info.dim_type != ShapeInfo::DimType::UNKNOWN) {
- shape_map[s] = shape_info;
- }
- }
- for (const auto& kv : *shape_hints_ordered) {
- shape_map.emplace(
- std::piecewise_construct,
- std::forward_as_tuple(kv.first),
- std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, kv.second));
- }
- BoundShapeInferencer eng(spec);
- eng.InferBoundShapeAndType(*pred_net, shape_map);
- const auto& out_map = eng.shape_info();
-
- for (const auto& kv : out_map) {
- shape_map.emplace(
- std::piecewise_construct,
- std::forward_as_tuple(kv.first),
- std::forward_as_tuple(kv.second.dim_type, kv.second.shape));
- }
- } else {
- // TODO: deprecate this path
- Workspace ws_local(ws);
- ws_local.RunNetOnce(*pred_net);
- const std::vector<std::string> ws_blobs = ws_local.Blobs();
- for (const auto& s : ws_blobs) {
- const Blob* b = ws_local.GetBlob(s);
- auto shape = GetTensorShapeOfBlob(b);
- if (!shape.unknown_shape()) {
- shape_map.emplace(
- std::piecewise_construct,
- std::forward_as_tuple(s),
- std::forward_as_tuple(
- ShapeInfo::DimType::CONSTANT, std::move(shape)));
- }
+  // Populate shapes from workspace
+ const std::vector<std::string> ws_blobs = ws->Blobs();
+ for (const auto& s : ws_blobs) {
+ auto shape_info = getShapeInfoFromBlob(ws->GetBlob(s));
+ if (shape_info.dim_type != ShapeInfo::DimType::UNKNOWN) {
+ shape_map[s] = shape_info;
}
}
+ for (const auto& kv : *shape_hints_mapped) {
+ shape_map.emplace(
+ std::piecewise_construct,
+ std::forward_as_tuple(kv.first),
+ std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, kv.second));
+ }
+ BoundShapeInferencer eng(spec);
+ eng.InferBoundShapeAndType(*pred_net, shape_map);
+ const auto& out_map = eng.shape_info();
+ for (const auto& kv : out_map) {
+ shape_map.emplace(
+ std::piecewise_construct,
+ std::forward_as_tuple(kv.first),
+ std::forward_as_tuple(kv.second.dim_type, kv.second.shape));
+ }
return shape_map;
}
@@ -724,7 +703,8 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
return net_opt;
}
-CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames(
+std::unordered_map<std::string, TensorShape>
+OnnxifiTransformer::SsaRewriteAndMapNames(
Workspace* ws,
NetDef* pred_net,
const std::unordered_set<std::string>& weights,
@@ -743,31 +723,36 @@ CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames(
input_mapping_ = onnx::SsaRewrite(nullptr, pred_net, weights);
// Annotate the ops with net position
AnnotateOpIndex(pred_net);
- std::vector<std::string> external_inputs;
+
// Need to add mapping for weights. This will be used to create new workspace
// with mapped weights.
for (const auto& w : weights) {
input_mapping_.emplace(w, w);
}
+
+ // Since we are going to create a mapped workspace, we need to make sure that
+ // the parent workspace has the mapped blob names. If the blobs don't exist
+ // (usually such blobs are input tensor names), we exclude them from mapping.
+ std::vector<std::string> exclude_mapping;
for (const auto kv : input_mapping_) {
reverse_input_mapping_.emplace(kv.second, kv.first);
if (!ws->HasBlob(kv.second)) {
- external_inputs.emplace_back(kv.first);
+ exclude_mapping.emplace_back(kv.first);
}
}
- for (const auto& i : external_inputs) {
+ for (const auto& i : exclude_mapping) {
input_mapping_.erase(i);
}
- CaffeMap<std::string, TensorShape> shape_hints_ordered;
+ std::unordered_map<std::string, TensorShape> shape_hints_mapped;
for (const auto& kv : input_shape_hints) {
const auto it = reverse_input_mapping_.find(kv.first);
if (it != reverse_input_mapping_.end()) {
- shape_hints_ordered.emplace(it->second, kv.second);
+ shape_hints_mapped.emplace(it->second, kv.second);
} else {
- shape_hints_ordered.emplace(kv.first, kv.second);
+ shape_hints_mapped.emplace(kv.first, kv.second);
}
}
- return shape_hints_ordered;
+ return shape_hints_mapped;
}
NetDef OnnxifiTransformer::TransformViaC2(
@@ -996,7 +981,6 @@ NetDef OnnxifiTransformer::TransformViaOnnx(
void OnnxifiTransformer::Transform(
Workspace* ws,
NetDef* pred_net,
- const std::vector<std::string>& external_inputs,
const std::vector<std::string>& weight_names,
const std::unordered_map<std::string, TensorShape>& input_shape_hints,
const std::unordered_set<int>& blacklisted_ops) {
@@ -1011,17 +995,13 @@ void OnnxifiTransformer::Transform(
weight_names.begin(), weight_names.end());
// SSA Rewrite the net
- auto shape_hints_ordered =
+ auto shape_hints_mapped =
SsaRewriteAndMapNames(ws, pred_net, weights, input_shape_hints);
// Populate shape info
Workspace mapped_ws(ws, input_mapping_);
ShapeInfoMap shape_hints = InferShapes(
- &mapped_ws,
- pred_net,
- &shape_hints_ordered,
- opts_.infer_shapes,
- opts_.bound_shape_spec);
+ &mapped_ws, pred_net, &shape_hints_mapped, opts_.bound_shape_spec);
// Transform the net
NetDef net_opt = opts_.use_onnx
diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h
index a1a8cd8eea..e037eefe69 100644
--- a/caffe2/opt/onnxifi_transformer.h
+++ b/caffe2/opt/onnxifi_transformer.h
@@ -22,8 +22,6 @@ class OnnxExporter;
struct OnnxifiTransformerOptions {
explicit OnnxifiTransformerOptions() : bound_shape_spec(0, 0) {}
- // Run bound shape inference
- bool infer_shapes{false};
// Dump onnx model for debugging
bool debug{false};
// Pass serialized onnx model if true, otherwise pass serialized c2 model
@@ -41,7 +39,6 @@ class CAFFE2_API OnnxifiTransformer final {
void Transform(
Workspace* ws,
NetDef* pred_net,
- const std::vector<std::string>& external_inputs,
const std::vector<std::string>& weight_names,
const std::unordered_map<std::string, TensorShape>& shape_hints,
const std::unordered_set<int>& blacklisted_ops);
@@ -83,7 +80,7 @@ class CAFFE2_API OnnxifiTransformer final {
const std::vector<std::string>& external_inputs,
const std::vector<std::string>& external_outputs);
- CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames(
+ std::unordered_map<std::string, TensorShape> SsaRewriteAndMapNames(
Workspace* ws,
NetDef* pred_net,
const std::unordered_set<std::string>& weights,
diff --git a/caffe2/python/onnx/onnxifi.py b/caffe2/python/onnx/onnxifi.py
index e76bf5d84a..9a859cbf60 100644
--- a/caffe2/python/onnx/onnxifi.py
+++ b/caffe2/python/onnx/onnxifi.py
@@ -19,7 +19,6 @@ import numpy as np
def onnxifi_caffe2_net(
pred_net,
input_shapes,
- infer_shapes=False,
max_batch_size=1,
max_seq_size=1,
debug=False,
@@ -27,27 +26,11 @@ def onnxifi_caffe2_net(
"""
Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops
"""
- # Inject an fake input tensor to help popluate the shape if we
- # do not do shape inference
shape_hints = {}
- external_inputs = []
- if not infer_shapes:
- for k, v in input_shapes.items():
- need_input_tensor = True
- if workspace.HasBlob(k):
- itensor = workspace.FetchBlob(k)
- if itensor.shape == v:
- need_input_tensor = False
- if need_input_tensor:
- workspace.FeedBlob(k, np.random.randn(*v).astype(np.float32))
- external_inputs.append(k)
-
for k, v in input_shapes.items():
shape_hints[k] = v
pred_net_str = C.onnxifi(pred_net.SerializeToString(),
- external_inputs,
shape_hints,
- infer_shapes,
max_batch_size,
max_seq_size,
debug,
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index d4f54d8a87..609ee01e69 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -1604,9 +1604,7 @@ void addGlobalMethods(py::module& m) {
m.def(
"onnxifi",
[](const py::bytes& pred_net_str,
- const std::vector<std::string>& external_inputs,
const std::unordered_map<std::string, std::vector<int>>& shapes,
- bool infer_shapes,
int max_batch_size,
int max_seq_size,
bool debug_builder,
@@ -1622,7 +1620,6 @@ void addGlobalMethods(py::module& m) {
it.first, CreateTensorShape(it.second, TensorProto::FLOAT));
}
OnnxifiTransformerOptions opts;
- opts.infer_shapes = infer_shapes;
opts.bound_shape_spec.max_batch_size = max_batch_size;
opts.bound_shape_spec.max_seq_size = max_seq_size;
opts.debug = debug_builder;
@@ -1633,7 +1630,6 @@ void addGlobalMethods(py::module& m) {
ts.Transform(
curr_ws,
&pred_net,
- external_inputs,
weight_names,
tensor_shapes,
std::unordered_set<int>());