summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYing Zhang <yingz@fb.com>2019-02-07 00:33:29 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-07 00:40:11 -0800
commit511f6fc2d5f3661e66ecc209574d568de7fa1cf3 (patch)
treeee179f1b5be5998f53bf4f1a9f14e57cb377b9e8
parentaa88c2c0b62518764c8f6d7cc91daf25fb876dea (diff)
downloadpytorch-511f6fc2d5f3661e66ecc209574d568de7fa1cf3.tar.gz
pytorch-511f6fc2d5f3661e66ecc209574d568de7fa1cf3.tar.bz2
pytorch-511f6fc2d5f3661e66ecc209574d568de7fa1cf3.zip
Insert AdjustBatchSizeOp into the predict_net. (#16811)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16811 As the title. The AdjustBatch ops will be inserted before and after the Onnxifi op to: 1) adjust batch/seq sizes to the ideal batch/seq size before these tensors are processed by the Onnxifi op; 2) adjust batch size to the original batch size for batches generated by the Onnxifi op. Reviewed By: yinghai Differential Revision: D13967711 fbshipit-source-id: 471b25ae6a60bf5b7ebee1de6449e0389b6cafff
-rw-r--r--caffe2/opt/onnxifi_transformer.cc184
1 file changed, 161 insertions, 23 deletions
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 0af59e9384..0eaea78578 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -232,6 +232,147 @@ void FillModelInfo(::ONNX_NAMESPACE::ModelProto* model) {
opset_id->set_domain("");
opset_id->set_version(7);
}
+
+// Returns the canonical blob name that carries the real (pre-adjustment)
+// batch size. NOTE(review): this is a single shared name, so every
+// BATCH-type tensor is assumed to have the same real batch size — confirm.
+string MkBatchSizeBlob() {
+  return "real_batch_size";
+}
+
+// Returns the per-blob name of the blob carrying the real sequence size for
+// a SEQ-type tensor (unlike the batch-size blob, this one is not shared).
+string MkSeqSizeBlob(const string& blob_name) {
+  return blob_name + "_real_seq_size";
+}
+
+// Names the output blob of the AdjustBatch op inserted *before* the Onnxifi
+// op for a given external input.
+string MkOutputForAdjustBatchOp(const string& input) {
+  return input + "_post_adjust_batch";
+}
+
+// Names the input blob of the AdjustBatch op inserted *after* the Onnxifi
+// op for a given external output.
+string MkInputForAdjustBatchOp(const string& output) {
+  return output + "_pre_adjust_batch";
+}
+
+// Builds an "AdjustBatch" OperatorDef mapping input_blob -> output_blob with
+// a "max_batch_size" integer argument.
+// When adjust_to_max_batch_size is true (input side of the Onnxifi op), the
+// op *emits* real_batch_size_blob as a second output, recording the original
+// size. When false (output side), the op *consumes* real_batch_size_blob as
+// a second input so the batch can be restored to its original size.
+OperatorDef MkAdjustBatchOp(
+    const string& input_blob,
+    const string& output_blob,
+    int max_batch_size,
+    const string& real_batch_size_blob,
+    bool adjust_to_max_batch_size) {
+  OperatorDef adjust_batch_op;
+  adjust_batch_op.set_type("AdjustBatch");
+  auto* arg = adjust_batch_op.add_arg();
+  arg->set_name("max_batch_size");
+  arg->set_i(max_batch_size);
+  adjust_batch_op.add_input(input_blob);
+  adjust_batch_op.add_output(output_blob);
+  if (adjust_to_max_batch_size) {
+    // Record the original batch size for later restoration.
+    adjust_batch_op.add_output(real_batch_size_blob);
+  } else {
+    // Read back the recorded original batch size.
+    adjust_batch_op.add_input(real_batch_size_blob);
+  }
+  return adjust_batch_op;
+}
+
+// Copies a protobuf repeated string field into an unordered_set for O(1)
+// membership tests.
+std::unordered_set<string> ToHashSet(
+    const ::google::protobuf::RepeatedPtrField<string>& strs) {
+  return std::unordered_set<string>(strs.begin(), strs.end());
+}
+
+// Returns dims(0) of the blob's inferred shape, enforcing that the first
+// dimension is known and positive. The result is used as the
+// "max_batch_size" argument of the inserted AdjustBatch ops.
+int64_t GetBlob1stDimSize(
+    const ShapeInfo& shape_info,
+    const string& blob_name) {
+  CAFFE_ENFORCE(
+      shape_info.shape.dims_size() > 0 && shape_info.shape.dims(0) > 0,
+      "Tensor " + blob_name +
+          " is type BATCH / SEQ, however the batch_size is unknown. " +
+          "Dims size: " + to_string(shape_info.shape.dims_size()) +
+          ", dim[0] = " + to_string(shape_info.shape.dims(0)));
+  return shape_info.shape.dims(0);
+}
+
+// Generates AdjustBatchOps for external inputs / outputs with type BATCH or
+// SEQ and adds them to input_ops and output_ops.
+// Meanwhile, modifies inputs / outputs of corresponding operators in the
+// wrapper_net to use the new inputs / outputs of AdjustBatchOps.
+// Inputs with no shape hint are skipped with a warning; outputs with no
+// shape hint are skipped silently; SEQ-typed outputs are a hard error.
+// NOTE(review): all BATCH-typed inputs write the same shared
+// "real_batch_size" blob (see MkBatchSizeBlob) — they are assumed to agree.
+void AddAdjustBatchOps(
+    const ShapeInfoMap& shape_hints,
+    NetDef* wrapper_net,
+    vector<OperatorDef>* input_ops,
+    vector<OperatorDef>* output_ops) {
+  const auto external_inputs = ToHashSet(wrapper_net->external_input());
+  const auto external_outputs = ToHashSet(wrapper_net->external_output());
+
+  for (auto& op : *(wrapper_net->mutable_op())) {
+    // Add AdjustBatchOp for all external inputs with type BATCH or SEQ.
+    // This will adjust the batch/seq size to the batch/seq size inferred by
+    // bound_shape_inference.
+    for (auto& input_blob : *(op.mutable_input())) {
+      if (external_inputs.count(input_blob)) {
+        auto shape_info_it = shape_hints.find(input_blob);
+        if (shape_info_it == shape_hints.end()) {
+          LOG(WARNING) << "Cannot find shape_info for external input blob: "
+                       << input_blob;
+          continue;
+        }
+        string real_batch_size_blob = "";
+        if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) {
+          real_batch_size_blob = MkBatchSizeBlob();
+        } else if (shape_info_it->second.dim_type == ShapeInfo::DimType::SEQ) {
+          real_batch_size_blob = MkSeqSizeBlob(input_blob);
+        } else {
+          // Neither BATCH nor SEQ: no adjustment needed.
+          continue;
+        }
+        auto output_blob = MkOutputForAdjustBatchOp(input_blob);
+        input_ops->push_back(MkAdjustBatchOp(
+            input_blob,
+            output_blob,
+            GetBlob1stDimSize(shape_info_it->second, input_blob),
+            real_batch_size_blob,
+            true /* adjust_to_max_batch_size */));
+        // Rewire this op to consume the adjusted blob instead.
+        input_blob = output_blob;
+      }
+    }
+    // Add AdjustBatchOp for all external outputs with type BATCH.
+    // This will adjust the batch size to the original batch size.
+    for (auto& output_blob : *(op.mutable_output())) {
+      if (external_outputs.count(output_blob)) {
+        auto shape_info_it = shape_hints.find(output_blob);
+        if (shape_info_it == shape_hints.end()) {
+          continue;
+        }
+        if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) {
+          auto input_blob = MkInputForAdjustBatchOp(output_blob);
+          output_ops->push_back(MkAdjustBatchOp(
+              input_blob,
+              output_blob,
+              GetBlob1stDimSize(shape_info_it->second, output_blob),
+              MkBatchSizeBlob(),
+              false /* adjust_to_max_batch_size */));
+          // Rewire this op to produce the pre-adjust blob; the AdjustBatch
+          // op then restores the original batch size under the public name.
+          output_blob = input_blob;
+        } else {
+          CAFFE_ENFORCE(
+              shape_info_it->second.dim_type != ShapeInfo::DimType::SEQ,
+              "Output tensor " + output_blob +
+                  " should never have dim_type SEQ.");
+        }
+      }
+    }
+  }
+}
+
+// Assembles the final NetDef in execution order: input-side AdjustBatch ops,
+// then the Onnxifi op, then output-side AdjustBatch ops.
+// NOTE(review): unlike the code this patch replaces, net_opt's
+// external_input / external_output lists are left empty here — confirm that
+// no downstream consumer relies on them being populated.
+NetDef ComposeResultNet(
+    const vector<OperatorDef>& input_ops,
+    const vector<OperatorDef>& output_ops,
+    const OperatorDef& onnxifi_op) {
+  NetDef net_opt;
+  for (const auto& op : input_ops) {
+    *(net_opt.add_op()) = op;
+  }
+  *(net_opt.add_op()) = onnxifi_op;
+  // Add AdjustBatch ops for output blobs to the net.
+  for (const auto& op : output_ops) {
+    *(net_opt.add_op()) = op;
+  }
+  return net_opt;
+}
+
} // namespace
OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts)
@@ -362,6 +503,10 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
}
}
+ vector<OperatorDef> input_ops;
+ vector<OperatorDef> output_ops;
+ AddAdjustBatchOps(shape_hints, &wrapper_net, &input_ops, &output_ops);
+
// Figure out weights and add it to external_inputs too
std::vector<std::string> extra_weights;
std::unordered_set<std::string> initialization_list;
@@ -381,23 +526,17 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
WrapShapeInfoIntoTensorProto(i, shape_hints.at(i)));
}
- // Debugging stuff
- if (opts_.debug) {
- WriteProtoToTextFile(wrapper_net, "debug.pb_txt");
- }
-
- // C2 model is ready. Build ONNXIFI Op
+ // Build ONNXIFI Op
std::string model_str;
wrapper_net.SerializeToString(&model_str);
- NetDef net_opt;
- auto* op = net_opt.add_op();
- *op = BuildOnnxifiOp(
+ auto onnxifi_op = BuildOnnxifiOp(
model_str, output_shape_hints, initialization_list, net_copy);
- for (const auto& i : op->input()) {
- net_opt.add_external_input(i);
- }
- for (const auto& o : op->output()) {
- net_opt.add_external_output(o);
+ NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op);
+
+ // Debugging stuff
+ if (opts_.debug) {
+ WriteProtoToTextFile(wrapper_net, "debug_wrapper_net.pb_txt");
+ WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt");
}
return net_opt;
}
@@ -412,6 +551,11 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
::ONNX_NAMESPACE::ModelProto onnx_model;
FillModelInfo(&onnx_model);
+ caffe2::NetDef wrapper_net(net);
+ vector<OperatorDef> input_ops;
+ vector<OperatorDef> output_ops;
+ AddAdjustBatchOps(*shape_hints, &wrapper_net, &input_ops, &output_ops);
+
// Convert c2 ops to onnx ops, add const weights if there are any
DeviceOption option;
CPUContext context(option);
@@ -511,15 +655,9 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
// Onnx model is ready. Build ONNXIFI Op
std::string model_str;
onnx_model.SerializeToString(&model_str);
- NetDef net_opt;
- auto* op = net_opt.add_op();
- *op = BuildOnnxifiOp(model_str, output_shape_hints, initialization_list, net);
- for (const auto& i : op->input()) {
- net_opt.add_external_input(i);
- }
- for (const auto& i : op->output()) {
- net_opt.add_external_output(i);
- }
+ auto onnxifi_op =
+ BuildOnnxifiOp(model_str, output_shape_hints, initialization_list, net);
+ NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op);
return net_opt;
}