summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2019-04-05 10:09:14 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-05 10:11:58 -0700
commit1d263ed92a11941bbb856114a68d2dbfa7a95e3f (patch)
treedf2e07e2828c27f115c2721d907e2eaba6f61338
parent0c5d444b2895857e8d10a85de0fef8cfc32b8c42 (diff)
downloadpytorch-1d263ed92a11941bbb856114a68d2dbfa7a95e3f.tar.gz
pytorch-1d263ed92a11941bbb856114a68d2dbfa7a95e3f.tar.bz2
pytorch-1d263ed92a11941bbb856114a68d2dbfa7a95e3f.zip
Add backward pass to infer single missing input shape for Concat opportunistically (#18911)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18911 Att. Reviewed By: bddppq Differential Revision: D14791295 fbshipit-source-id: 4b7a775924f0eadb0cb73aa6c434a6a5be8b92be
-rw-r--r--caffe2/operators/concat_split_op.h16
-rw-r--r--caffe2/opt/bound_shape_inference_test.cc34
-rw-r--r--caffe2/opt/bound_shape_inferencer.cc60
-rw-r--r--caffe2/opt/bound_shape_inferencer.h2
-rw-r--r--caffe2/opt/onnxifi_transformer.cc4
-rw-r--r--caffe2/utils/string_utils.h14
6 files changed, 114 insertions, 16 deletions
diff --git a/caffe2/operators/concat_split_op.h b/caffe2/operators/concat_split_op.h
index 47ed663f25..74d74e45d2 100644
--- a/caffe2/operators/concat_split_op.h
+++ b/caffe2/operators/concat_split_op.h
@@ -5,24 +5,10 @@
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
+#include "caffe2/utils/string_utils.h"
namespace caffe2 {
-namespace {
-inline int GetDimFromOrderString(const string& str) {
- auto order = StringToStorageOrder(str);
- switch (order) {
- case StorageOrder::NHWC:
- return 3;
- case StorageOrder::NCHW:
- return 1;
- default:
- CAFFE_THROW("Unsupported storage order: ", str);
- return -1;
- }
-}
-} // namespace
-
template <class Context>
class SplitOp final : public Operator<Context> {
public:
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
index a148b0e82a..0efa8788cc 100644
--- a/caffe2/opt/bound_shape_inference_test.cc
+++ b/caffe2/opt/bound_shape_inference_test.cc
@@ -214,6 +214,40 @@ TEST(BoundShapeInference, ConcatMissingInput) {
{spec.max_batch_size, 2, 60});
}
+TEST(BoundShapeInference, ConcatInferInputBackwards) { // Concat input I1 is given no initial shape; the test expects inference to fill it in.
+ NetDef net;
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "Concat",
+ "",
+ {"I0", "I1"},
+ {"Cout", "split_info"},
+ {MakeArgument<int>("axis", 1)})); // concatenation is along axis 1
+ net.add_op()->CopyFrom(
+ CreateOperatorDef("FCTransposed", "", {"Cout", "W0", "B0"}, {"Y"}, {})); // second op consumes the Concat output Cout
+ BoundShapeSpec spec(20, 1000);
+ ShapeInfoMap shape_map;
+ shape_map.emplace(
+ "I0",
+ makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60}));
+ shape_map.emplace(
+ "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {101, 16}));
+ shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16})); // only I0, W0 and B0 are seeded; I1 is deliberately absent
+ BoundShapeInferencer eng(spec);
+ eng.InferBoundShapeAndType(net, shape_map);
+ const auto& out_shape = eng.shape_info();
+ verifyShapeInfo(
+ out_shape, "I0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60}); // the seeded input shape is expected unchanged
+ verifyShapeInfo(
+ out_shape, "Cout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 101}); // presumably 101 is derived from W0's first dim by the FCTransposed inference -- TODO confirm
+ verifyShapeInfo(
+ out_shape, "Y", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+ verifyShapeInfo(
+ out_shape,
+ "I1",
+ ShapeInfo::DimType::BATCH,
+ {spec.max_batch_size, 101 - 60}); // expected I1 width: Cout width (101) minus I0 width (60)
+}
+
TEST(BoundShapeInference, Split) {
NetDef net;
net.add_op()->CopyFrom(CreateOperatorDef(
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index b7c20d592c..1d2f940e19 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -79,6 +79,14 @@ void BoundShapeInferencer::InferBoundShapeAndType(
}
}
+ // Doing a reverse pass to infer the input shapes if applicable
+ for (int i = net.op_size() - 1; i >= 0; --i) {
+ const auto& op = net.op(i);
+ if (op.type() == "Concat") {
+ InferConcatInputs(op);
+ }
+ }
+
// Make sure shape has name
EnsureShapeNames(&shape_info_);
}
@@ -251,6 +259,55 @@ void BoundShapeInferencer::InferReshape(const OperatorDef& op) {
shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
}
}
+
+void BoundShapeInferencer::InferConcatInputs(const OperatorDef& op) { // Derives the shape of at most one input of `op` that has no entry in shape_info_, using the known shape of output(0).
+ ArgumentHelper helper(op);
+ const auto add_axis = helper.GetSingleArgument<int32_t>("add_axis", 0);
+ if (add_axis) {
+ return; // ops with a non-zero add_axis argument are not handled by this function
+ } else if (op.output_size() == 0 || !shape_info_.count(op.output(0))) {
+ return; // nothing can be derived without a recorded shape for output(0)
+ }
+
+ const auto axis = helper.HasArgument("axis") // an explicit "axis" argument takes precedence
+ ? helper.GetSingleArgument<int32_t>("axis", -1)
+ : GetDimFromOrderString( // otherwise the axis comes from the "order" argument, defaulting to "NCHW"
+ helper.GetSingleArgument<string>("order", "NCHW"));
+
+ const auto& shape_info = shape_info_.at(op.output(0));
+ int output_channel = shape_info.shape.dims(axis); // NOTE(review): axis is not range-checked (or checked for negative values) before indexing dims -- confirm callers guarantee validity
+ int missing_shape_infos = 0;
+ int channel_acc = 0; // running sum of the axis extents of the inputs whose shapes are known
+ std::string input_to_infer;
+ for (const auto& i : op.input()) {
+ const auto it = shape_info_.find(i);
+ if (it != shape_info_.end()) {
+ const auto& current_input_shape = it->second;
+ channel_acc += current_input_shape.shape.dims(axis);
+ } else if (missing_shape_infos) {
+ LOG(INFO) << "More than one missing shapes, previous one: "
+ << input_to_infer;
+ // Only a single missing input shape can be inferred, so give up on a second one
+ return;
+ } else {
+ ++missing_shape_infos; // first input without shape info: remember its name
+ input_to_infer = i;
+ }
+ }
+
+ if (missing_shape_infos && !input_to_infer.empty()) {
+ auto input_shape_info = shape_info; // start from a copy of the output's shape info
+ input_shape_info.shape.set_dims(axis, output_channel - channel_acc); // missing extent = output extent minus the known inputs' sum; NOTE(review): the result is not checked for being negative
+ shape_info_.emplace(input_to_infer, std::move(input_shape_info));
+
+ // With every input now recorded, re-run the common inference to get the shape of Concat's second output
+ InferCommonOp(op);
+ if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
+ shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT; // the second output, when present, is marked CONSTANT
+ }
+ }
+}
+
// For concat net, if some inputs are missing and we have add_axis argument, it
// means that all the inputs should be of the same dimension. In this case, we
// can infer the shape of the missing inputs
@@ -399,7 +456,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
!(op.type().compare(0, 4, "Int8")) && (op.type() != "Int8Dequantize");
TensorProto::DataType infered_data_type = TensorProto::UNDEFINED;
if (is_quantized) {
- const static std::map<string, int> type_info_from_input = {
+ const static std::map<std::string, int> type_info_from_input = {
{"Int8Quantize", -1}, // Force this op's output to be uint8
{"Int8ConvRelu", 1},
{"Int8MaxPool", 0},
@@ -420,6 +477,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
} else if (op.type() == "Int8Dequantize") {
infered_data_type = TensorProto::FLOAT;
}
+
for (const auto& shape : output_shapes) {
if (infered_data_type == TensorProto::UNDEFINED) {
infered_data_type = shape.data_type();
diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h
index ee1c670375..216534ecb6 100644
--- a/caffe2/opt/bound_shape_inferencer.h
+++ b/caffe2/opt/bound_shape_inferencer.h
@@ -64,6 +64,8 @@ class CAFFE2_API BoundShapeInferencer {
TensorProto::DataType type,
bool is_quantized);
+ void InferConcatInputs(const OperatorDef& op);
+
void InferGivenTensorFill(const OperatorDef& op);
void InferSparseLengthsSum(const OperatorDef& op);
void InferFC(const OperatorDef& op);
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 797c3f454e..8ec572b4ea 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -826,6 +826,8 @@ bool OnnxifiTransformer::supportOpC2(
for (const auto& i : op.input()) {
const auto it = shape_hints.find(i);
if (it == shape_hints.end()) {
+ VLOG(1) << "Skipping " << op.type() << " (" << pos
+ << ") due to missing shape info for input " << i;
return false;
}
if ((it->second).is_quantized == false) {
@@ -844,6 +846,8 @@ bool OnnxifiTransformer::supportOpC2(
for (const auto& i : op.output()) {
const auto it = shape_hints.find(i);
if (it == shape_hints.end()) {
+ VLOG(1) << "Skipping " << op.type() << " (" << pos
+ << ") due to missing shape info for output " << i;
return false;
}
if ((it->second).is_quantized == false) {
diff --git a/caffe2/utils/string_utils.h b/caffe2/utils/string_utils.h
index 359186607a..ada947ec11 100644
--- a/caffe2/utils/string_utils.h
+++ b/caffe2/utils/string_utils.h
@@ -6,6 +6,7 @@
#include <vector>
#include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
namespace caffe2 {
@@ -33,6 +34,19 @@ CAFFE2_API inline bool EndsWith(
}
}
+CAFFE2_API inline int32_t GetDimFromOrderString(const std::string& str) { // Maps a storage-order string to a dimension index: 3 for NHWC, 1 for NCHW.
+ auto order = StringToStorageOrder(str);
+ switch (order) {
+ case StorageOrder::NHWC:
+ return 3; // presumably the channel dimension in NHWC layout -- TODO confirm
+ case StorageOrder::NCHW:
+ return 1; // presumably the channel dimension in NCHW layout -- TODO confirm
+ default:
+ CAFFE_THROW("Unsupported storage order: ", str); // any other storage order is rejected
+ return -1; // presumably unreachable if CAFFE_THROW always throws; likely kept to satisfy the compiler -- TODO confirm
+ }
+}
+
CAFFE2_API int32_t editDistanceHelper(const char* s1,
size_t s1_len,
const char* s2,