diff options
author | Yinghai Lu <yinghai@fb.com> | 2019-03-21 15:28:20 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-21 15:43:05 -0700 |
commit | 979db037221d91f9dbd08df58a261d7448e6240d (patch) | |
tree | 70c11f8c170dc2e29b41499d78e59e9d98a74366 /caffe2 | |
parent | 104773c715acd79909bd7ef90376c61c25839793 (diff) | |
download | pytorch-979db037221d91f9dbd08df58a261d7448e6240d.tar.gz pytorch-979db037221d91f9dbd08df58a261d7448e6240d.tar.bz2 pytorch-979db037221d91f9dbd08df58a261d7448e6240d.zip |
Blacklist certain op types when doing bound shape inference (#18290)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18290
Some ops, such as `Tile`, will mess up our tracking of batch size, and for now it makes sense to stop the shape inference on these ops so that we don't lower them and their downstream ops without proper batch info.
Reviewed By: zrphercule
Differential Revision: D14463550
fbshipit-source-id: 2792481efa540f2a7dd310e677c213860c3053ca
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/opt/bound_shape_inference_test.cc | 10 | ||||
-rw-r--r-- | caffe2/opt/bound_shape_inferencer.cc | 31 |
2 files changed, 19 insertions, 22 deletions
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index ee59fe763a..d8f77cf2b9 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -244,8 +244,7 @@ TEST(BoundShapeInference, FC) { {spec.max_batch_size, 1024}); } -// We don't support inference input shape when Weight is not 2D -TEST(BoundShapeInference, UnsupportedFC) { +TEST(BoundShapeInference, FC3D) { NetDef net; net.add_op()->CopyFrom( CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {})); @@ -255,7 +254,12 @@ TEST(BoundShapeInference, UnsupportedFC) { shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16})); BoundShapeSpec spec(20, 1000); BoundShapeInferencer eng(spec); - EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet); + eng.InferBoundShapeAndType(net, shape_map); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024}); + verifyShapeInfo( + out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16}); } TEST(BoundShapeInference, Combo0) { diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index dea1eeb14d..717f8432e8 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -44,10 +44,15 @@ void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) { void BoundShapeInferencer::InferBoundShapeAndType( const NetDef& net, const std::unordered_map<std::string, ShapeInfo>& info) { + const static std::unordered_set<std::string> unsupported{"Tile"}; shape_info_ = info; for (const auto& op : net.op()) { VLOG(1) << op.type(); + if (unsupported.count(op.type())) { + continue; + } + if (op.type() == "SparseLengthsSum" || op.type() == "SparseLengthsSumFused8BitRowwise" || op.type() == "SparseLengthsWeightedSum" || @@ -316,34 +321,22 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) { 
ArgumentHelper helper(op); auto axis = helper.GetSingleArgument<int32_t>("axis", 1); auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1); - CAFFE_ENFORCE_EQ( - axis, - 1, - "Don't know how to deduce input of FC with axis not equal to 1: ", - op.input(0)); - CAFFE_ENFORCE_EQ( - axis_w, - 1, - "Don't know how to deduce input of FC with axis_w not equal to 1: ", - op.input(0)); const TensorShape w_shape = w_shape_info.shape; - CAFFE_ENFORCE_EQ( - w_shape.dims_size(), - 2, - "Don't know how to deduce input of FC other than of dim size 2: ", - op.input(0)); bool transposed = (op.type() == "FC") ? false : true; const int canonical_axis_w = canonical_axis_index_(axis_w, w_shape.dims().size()); const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w) : SizeFromDim(w_shape, canonical_axis_w); + std::vector<int64_t> dims; + for (int i = 0; i < axis - 1; ++i) { + dims.push_back(1); + } + dims.push_back(spec_.max_batch_size); + dims.push_back(K); current_dim_type_ = ShapeInfo::DimType::BATCH; current_max_batch_size_ = spec_.max_batch_size; CheckAndSetTensorShapeAndType( - op.input(0), - ShapeInfo::DimType::BATCH, - {spec_.max_batch_size, K}, - w_shape.data_type()); + op.input(0), ShapeInfo::DimType::BATCH, dims, w_shape.data_type()); } else { ShapeInfo& x_shape_info = x_it->second; if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) { |