author     | Yinghai Lu <yinghai@fb.com>                                         | 2019-04-05 10:09:14 -0700
committer  | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-05 10:11:58 -0700
commit     | 1d263ed92a11941bbb856114a68d2dbfa7a95e3f (patch)
tree       | df2e07e2828c27f115c2721d907e2eaba6f61338
parent     | 0c5d444b2895857e8d10a85de0fef8cfc32b8c42 (diff)
Add backward pass to infer single missing input shape for Concat opportunistically (#18911)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18911
Att.
Reviewed By: bddppq
Differential Revision: D14791295
fbshipit-source-id: 4b7a775924f0eadb0cb73aa6c434a6a5be8b92be
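
For context, the rule this change implements: when the shape of Concat's output and of all but one input are known, the missing input's size along the concat axis must equal the output size minus the sum of the known input sizes along that axis. Below is a minimal standalone C++ sketch of that arithmetic, not code from this diff; the function and variable names (infer_missing_concat_dim, output_dim, known_dims) are illustrative only.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Given the Concat output size along the concat axis and the sizes of the
// inputs whose shapes are already known, return the size of the single
// missing input. Returns nullopt if the known sizes already exceed the
// output, i.e. there is nothing consistent to infer.
std::optional<int64_t> infer_missing_concat_dim(
    int64_t output_dim,
    const std::vector<int64_t>& known_dims) {
  int64_t acc = 0;
  for (int64_t d : known_dims) {
    acc += d;
  }
  if (acc > output_dim) {
    return std::nullopt;
  }
  return output_dim - acc;
}

int main() {
  // Mirrors the numbers in the new ConcatInferInputBackwards test:
  // Cout has 101 channels along axis 1 and I0 has 60, so I1 must have 41.
  auto missing = infer_missing_concat_dim(101, {60});
  if (missing) {
    std::cout << "Inferred missing input channels: " << *missing << "\n";
  }
  return 0;
}

With the test's numbers this yields 41 for I1, matching the expected {spec.max_batch_size, 101 - 60} in the new unit test.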
-rw-r--r-- | caffe2/operators/concat_split_op.h       | 16
-rw-r--r-- | caffe2/opt/bound_shape_inference_test.cc | 34
-rw-r--r-- | caffe2/opt/bound_shape_inferencer.cc     | 60
-rw-r--r-- | caffe2/opt/bound_shape_inferencer.h      |  2
-rw-r--r-- | caffe2/opt/onnxifi_transformer.cc        |  4
-rw-r--r-- | caffe2/utils/string_utils.h              | 14
6 files changed, 114 insertions, 16 deletions
diff --git a/caffe2/operators/concat_split_op.h b/caffe2/operators/concat_split_op.h
index 47ed663f25..74d74e45d2 100644
--- a/caffe2/operators/concat_split_op.h
+++ b/caffe2/operators/concat_split_op.h
@@ -5,24 +5,10 @@
 #include "caffe2/core/operator.h"
 #include "caffe2/core/types.h"
 #include "caffe2/utils/math.h"
+#include "caffe2/utils/string_utils.h"
 
 namespace caffe2 {
 
-namespace {
-inline int GetDimFromOrderString(const string& str) {
-  auto order = StringToStorageOrder(str);
-  switch (order) {
-    case StorageOrder::NHWC:
-      return 3;
-    case StorageOrder::NCHW:
-      return 1;
-    default:
-      CAFFE_THROW("Unsupported storage order: ", str);
-      return -1;
-  }
-}
-} // namespace
-
 template <class Context>
 class SplitOp final : public Operator<Context> {
  public:
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
index a148b0e82a..0efa8788cc 100644
--- a/caffe2/opt/bound_shape_inference_test.cc
+++ b/caffe2/opt/bound_shape_inference_test.cc
@@ -214,6 +214,40 @@ TEST(BoundShapeInference, ConcatMissingInput) {
       {spec.max_batch_size, 2, 60});
 }
 
+TEST(BoundShapeInference, ConcatInferInputBackwards) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Concat",
+      "",
+      {"I0", "I1"},
+      {"Cout", "split_info"},
+      {MakeArgument<int>("axis", 1)}));
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("FCTransposed", "", {"Cout", "W0", "B0"}, {"Y"}, {}));
+  BoundShapeSpec spec(20, 1000);
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "I0",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60}));
+  shape_map.emplace(
+      "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {101, 16}));
+  shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "I0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60});
+  verifyShapeInfo(
+      out_shape, "Cout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 101});
+  verifyShapeInfo(
+      out_shape, "Y", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+  verifyShapeInfo(
+      out_shape,
+      "I1",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 101 - 60});
+}
+
 TEST(BoundShapeInference, Split) {
   NetDef net;
   net.add_op()->CopyFrom(CreateOperatorDef(
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index b7c20d592c..1d2f940e19 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -79,6 +79,14 @@ void BoundShapeInferencer::InferBoundShapeAndType(
     }
   }
 
+  // Doing a reverse pass to infer the input shapes if applicable
+  for (int i = net.op_size() - 1; i >= 0; --i) {
+    const auto& op = net.op(i);
+    if (op.type() == "Concat") {
+      InferConcatInputs(op);
+    }
+  }
+
   // Make sure shape has name
   EnsureShapeNames(&shape_info_);
 }
@@ -251,6 +259,55 @@ void BoundShapeInferencer::InferReshape(const OperatorDef& op) {
     shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
   }
 }
+
+void BoundShapeInferencer::InferConcatInputs(const OperatorDef& op) {
+  ArgumentHelper helper(op);
+  const auto add_axis = helper.GetSingleArgument<int32_t>("add_axis", 0);
+  if (add_axis) {
+    return;
+  } else if (op.output_size() == 0 || !shape_info_.count(op.output(0))) {
+    return;
+  }
+
+  const auto axis = helper.HasArgument("axis")
+      ? helper.GetSingleArgument<int32_t>("axis", -1)
+      : GetDimFromOrderString(
+            helper.GetSingleArgument<string>("order", "NCHW"));
+
+  const auto& shape_info = shape_info_.at(op.output(0));
+  int output_channel = shape_info.shape.dims(axis);
+  int missing_shape_infos = 0;
+  int channel_acc = 0;
+  std::string input_to_infer;
+  for (const auto& i : op.input()) {
+    const auto it = shape_info_.find(i);
+    if (it != shape_info_.end()) {
+      const auto& current_input_shape = it->second;
+      channel_acc += current_input_shape.shape.dims(axis);
+    } else if (missing_shape_infos) {
+      LOG(INFO) << "More than one missing shapes, previous one: "
+                << input_to_infer;
+      // We can only infer one missing input shape info
+      return;
+    } else {
+      ++missing_shape_infos;
+      input_to_infer = i;
+    }
+  }
+
+  if (missing_shape_infos && !input_to_infer.empty()) {
+    auto input_shape_info = shape_info;
+    input_shape_info.shape.set_dims(axis, output_channel - channel_acc);
+    shape_info_.emplace(input_to_infer, std::move(input_shape_info));
+
+    // Infer the shape of the second output of Concat
+    InferCommonOp(op);
+    if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
+      shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
+    }
+  }
+}
+
 // For concat net, if some inputs are missing and we have add_axis argument, it
 // means that all the inputs should be of the same dimension. In this case, we
 // can infer the shape of the missing inputs
@@ -399,7 +456,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
       !(op.type().compare(0, 4, "Int8")) && (op.type() != "Int8Dequantize");
   TensorProto::DataType infered_data_type = TensorProto::UNDEFINED;
   if (is_quantized) {
-    const static std::map<string, int> type_info_from_input = {
+    const static std::map<std::string, int> type_info_from_input = {
         {"Int8Quantize", -1}, // Force this op's output to be uint8
         {"Int8ConvRelu", 1},
         {"Int8MaxPool", 0},
@@ -420,6 +477,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
   } else if (op.type() == "Int8Dequantize") {
     infered_data_type = TensorProto::FLOAT;
   }
+
   for (const auto& shape : output_shapes) {
     if (infered_data_type == TensorProto::UNDEFINED) {
       infered_data_type = shape.data_type();
diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h
index ee1c670375..216534ecb6 100644
--- a/caffe2/opt/bound_shape_inferencer.h
+++ b/caffe2/opt/bound_shape_inferencer.h
@@ -64,6 +64,8 @@ class CAFFE2_API BoundShapeInferencer {
       TensorProto::DataType type,
       bool is_quantized);
 
+  void InferConcatInputs(const OperatorDef& op);
+
   void InferGivenTensorFill(const OperatorDef& op);
   void InferSparseLengthsSum(const OperatorDef& op);
   void InferFC(const OperatorDef& op);
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 797c3f454e..8ec572b4ea 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -826,6 +826,8 @@ bool OnnxifiTransformer::supportOpC2(
   for (const auto& i : op.input()) {
     const auto it = shape_hints.find(i);
     if (it == shape_hints.end()) {
+      VLOG(1) << "Skipping " << op.type() << " (" << pos
+              << ") due to missing shape info for input " << i;
       return false;
     }
     if ((it->second).is_quantized == false) {
@@ -844,6 +846,8 @@ bool OnnxifiTransformer::supportOpC2(
   for (const auto& i : op.output()) {
     const auto it = shape_hints.find(i);
     if (it == shape_hints.end()) {
+      VLOG(1) << "Skipping " << op.type() << " (" << pos
+              << ") due to missing shape info for output " << i;
       return false;
     }
     if ((it->second).is_quantized == false) {
diff --git a/caffe2/utils/string_utils.h b/caffe2/utils/string_utils.h
index 359186607a..ada947ec11 100644
--- a/caffe2/utils/string_utils.h
+++ b/caffe2/utils/string_utils.h
@@ -6,6 +6,7 @@
 #include <vector>
 
 #include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -33,6 +34,19 @@ CAFFE2_API inline bool EndsWith(
   }
 }
 
+CAFFE2_API inline int32_t GetDimFromOrderString(const std::string& str) {
+  auto order = StringToStorageOrder(str);
+  switch (order) {
+    case StorageOrder::NHWC:
+      return 3;
+    case StorageOrder::NCHW:
+      return 1;
+    default:
+      CAFFE_THROW("Unsupported storage order: ", str);
+      return -1;
+  }
+}
+
 CAFFE2_API int32_t editDistanceHelper(const char* s1,
                                       size_t s1_len,
                                       const char* s2,