diff options
author | Ying Zhang <yingz@fb.com> | 2019-02-07 00:33:29 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-07 00:40:11 -0800 |
commit | 511f6fc2d5f3661e66ecc209574d568de7fa1cf3 (patch) | |
tree | ee179f1b5be5998f53bf4f1a9f14e57cb377b9e8 | |
parent | aa88c2c0b62518764c8f6d7cc91daf25fb876dea (diff) | |
download | pytorch-511f6fc2d5f3661e66ecc209574d568de7fa1cf3.tar.gz pytorch-511f6fc2d5f3661e66ecc209574d568de7fa1cf3.tar.bz2 pytorch-511f6fc2d5f3661e66ecc209574d568de7fa1cf3.zip |
Insert AdjustBatchSizeOp into the predict_net. (#16811)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16811
As the title. The AdjustBatch ops will be inserted before and after the Onnxifi op to:
1) adjust batch/seq sizes to the ideal batch/seq size before these tensors are processed by the Onnxifi op;
2) adjust batch size to the original batch size for batches generated by the Onnxifi op.
Reviewed By: yinghai
Differential Revision: D13967711
fbshipit-source-id: 471b25ae6a60bf5b7ebee1de6449e0389b6cafff
-rw-r--r-- | caffe2/opt/onnxifi_transformer.cc | 184 |
1 files changed, 161 insertions, 23 deletions
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 0af59e9384..0eaea78578 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -232,6 +232,147 @@ void FillModelInfo(::ONNX_NAMESPACE::ModelProto* model) { opset_id->set_domain(""); opset_id->set_version(7); } + +string MkBatchSizeBlob() { + return "real_batch_size"; +} + +string MkSeqSizeBlob(const string& blob_name) { + return blob_name + "_real_seq_size"; +} + +string MkOutputForAdjustBatchOp(const string& input) { + return input + "_post_adjust_batch"; +} + +string MkInputForAdjustBatchOp(const string& output) { + return output + "_pre_adjust_batch"; +} + +OperatorDef MkAdjustBatchOp( + const string& input_blob, + const string& output_blob, + int max_batch_size, + const string& real_batch_size_blob, + bool adjust_to_max_batch_size) { + OperatorDef adjust_batch_op; + adjust_batch_op.set_type("AdjustBatch"); + auto* arg = adjust_batch_op.add_arg(); + arg->set_name("max_batch_size"); + arg->set_i(max_batch_size); + adjust_batch_op.add_input(input_blob); + adjust_batch_op.add_output(output_blob); + if (adjust_to_max_batch_size) { + adjust_batch_op.add_output(real_batch_size_blob); + } else { + adjust_batch_op.add_input(real_batch_size_blob); + } + return adjust_batch_op; +} + +std::unordered_set<string> ToHashSet( + const ::google::protobuf::RepeatedPtrField<string>& strs) { + return std::unordered_set<string>(strs.begin(), strs.end()); +} + +int64_t GetBlob1stDimSize( + const ShapeInfo& shape_info, + const string& blob_name) { + CAFFE_ENFORCE( + shape_info.shape.dims_size() > 0 && shape_info.shape.dims(0) > 0, + "Tensor " + blob_name + + " is type BATCH / SEQ, however the batch_size is unknown. " + + "Dims size: " + to_string(shape_info.shape.dims_size()) + + ", dim[0] = " + to_string(shape_info.shape.dims(0))); + return shape_info.shape.dims(0); +} + +// Generates AdjustBatchOps for external inputs / outputs with type BATCH or +// SEQ and adds them to input_ops and output_ops. +// Meanwhile, modifies inputs / outputs of corresponding operators in the +// wrapper_net to use the new inputs / outputs of AdjustBatchOps. +void AddAdjustBatchOps( + const ShapeInfoMap& shape_hints, + NetDef* wrapper_net, + vector<OperatorDef>* input_ops, + vector<OperatorDef>* output_ops) { + const auto external_inputs = ToHashSet(wrapper_net->external_input()); + const auto external_outputs = ToHashSet(wrapper_net->external_output()); + + for (auto& op : *(wrapper_net->mutable_op())) { + // Add AdjustBatchOp for all external inputs with type BATCH or SEQ. + // This will adjust the batch/seq size to the batch/seq size inferred by + // bound_shape_inference. + for (auto& input_blob : *(op.mutable_input())) { + if (external_inputs.count(input_blob)) { + auto shape_info_it = shape_hints.find(input_blob); + if (shape_info_it == shape_hints.end()) { + LOG(WARNING) << "Cannot find shape_info for external input blob: " + << input_blob; + continue; + } + string real_batch_size_blob = ""; + if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) { + real_batch_size_blob = MkBatchSizeBlob(); + } else if (shape_info_it->second.dim_type == ShapeInfo::DimType::SEQ) { + real_batch_size_blob = MkSeqSizeBlob(input_blob); + } else { + continue; + } + auto output_blob = MkOutputForAdjustBatchOp(input_blob); + input_ops->push_back(MkAdjustBatchOp( + input_blob, + output_blob, + GetBlob1stDimSize(shape_info_it->second, input_blob), + real_batch_size_blob, + true /* adjust_to_max_batch_size */)); + input_blob = output_blob; + } + } + // Add AdjustBatchOp for all external outputs with type BATCH. + // This will adjust the batch size to the original batch size. + for (auto& output_blob : *(op.mutable_output())) { + if (external_outputs.count(output_blob)) { + auto shape_info_it = shape_hints.find(output_blob); + if (shape_info_it == shape_hints.end()) { + continue; + } + if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) { + auto input_blob = MkInputForAdjustBatchOp(output_blob); + output_ops->push_back(MkAdjustBatchOp( + input_blob, + output_blob, + GetBlob1stDimSize(shape_info_it->second, output_blob), + MkBatchSizeBlob(), + false /* adjust_to_max_batch_size */)); + output_blob = input_blob; + } else { + CAFFE_ENFORCE( + shape_info_it->second.dim_type != ShapeInfo::DimType::SEQ, + "Output tensor " + output_blob + + " should never have dim_type SEQ."); + } + } + } + } +} + +NetDef ComposeResultNet( + const vector<OperatorDef>& input_ops, + const vector<OperatorDef>& output_ops, + const OperatorDef& onnxifi_op) { + NetDef net_opt; + for (const auto& op : input_ops) { + *(net_opt.add_op()) = op; + } + *(net_opt.add_op()) = onnxifi_op; + // Add AdjustBatch ops for output blobs to the net. + for (const auto& op : output_ops) { + *(net_opt.add_op()) = op; + } + return net_opt; +} + } // namespace OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts) @@ -362,6 +503,10 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( } } + vector<OperatorDef> input_ops; + vector<OperatorDef> output_ops; + AddAdjustBatchOps(shape_hints, &wrapper_net, &input_ops, &output_ops); + // Figure out weights and add it to external_inputs too std::vector<std::string> extra_weights; std::unordered_set<std::string> initialization_list; @@ -381,23 +526,17 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( WrapShapeInfoIntoTensorProto(i, shape_hints.at(i))); } - // Debugging stuff - if (opts_.debug) { - WriteProtoToTextFile(wrapper_net, "debug.pb_txt"); - } - - // C2 model is ready. Build ONNXIFI Op + // Build ONNXIFI Op std::string model_str; wrapper_net.SerializeToString(&model_str); - NetDef net_opt; - auto* op = net_opt.add_op(); - *op = BuildOnnxifiOp( + auto onnxifi_op = BuildOnnxifiOp( model_str, output_shape_hints, initialization_list, net_copy); - for (const auto& i : op->input()) { - net_opt.add_external_input(i); - } - for (const auto& o : op->output()) { - net_opt.add_external_output(o); + NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op); + + // Debugging stuff + if (opts_.debug) { + WriteProtoToTextFile(wrapper_net, "debug_wrapper_net.pb_txt"); + WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt"); } return net_opt; } @@ -412,6 +551,11 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx( ::ONNX_NAMESPACE::ModelProto onnx_model; FillModelInfo(&onnx_model); + caffe2::NetDef wrapper_net(net); + vector<OperatorDef> input_ops; + vector<OperatorDef> output_ops; + AddAdjustBatchOps(*shape_hints, &wrapper_net, &input_ops, &output_ops); + // Convert c2 ops to onnx ops, add const weights if there are any DeviceOption option; CPUContext context(option); @@ -511,15 +655,9 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx( // Onnx model is ready. Build ONNXIFI Op std::string model_str; onnx_model.SerializeToString(&model_str); - NetDef net_opt; - auto* op = net_opt.add_op(); - *op = BuildOnnxifiOp(model_str, output_shape_hints, initialization_list, net); - for (const auto& i : op->input()) { - net_opt.add_external_input(i); - } - for (const auto& i : op->output()) { - net_opt.add_external_output(i); - } + auto onnxifi_op = + BuildOnnxifiOp(model_str, output_shape_hints, initialization_list, net); + NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op); return net_opt; } |