summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2019-03-21 15:28:20 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-21 15:43:05 -0700
commit979db037221d91f9dbd08df58a261d7448e6240d (patch)
tree70c11f8c170dc2e29b41499d78e59e9d98a74366 /caffe2
parent104773c715acd79909bd7ef90376c61c25839793 (diff)
downloadpytorch-979db037221d91f9dbd08df58a261d7448e6240d.tar.gz
pytorch-979db037221d91f9dbd08df58a261d7448e6240d.tar.bz2
pytorch-979db037221d91f9dbd08df58a261d7448e6240d.zip
Blacklist certain op types when doing bound shape inference (#18290)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18290 Some ops, such as `Tile`, will mess up our tracking of batch size, so for now it makes sense to stop the shape inference at these ops so that we don't lower them and their downstream ops without proper batch info. Reviewed By: zrphercule Differential Revision: D14463550 fbshipit-source-id: 2792481efa540f2a7dd310e677c213860c3053ca
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/opt/bound_shape_inference_test.cc10
-rw-r--r--caffe2/opt/bound_shape_inferencer.cc31
2 files changed, 19 insertions, 22 deletions
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
index ee59fe763a..d8f77cf2b9 100644
--- a/caffe2/opt/bound_shape_inference_test.cc
+++ b/caffe2/opt/bound_shape_inference_test.cc
@@ -244,8 +244,7 @@ TEST(BoundShapeInference, FC) {
{spec.max_batch_size, 1024});
}
-// We don't support inference input shape when Weight is not 2D
-TEST(BoundShapeInference, UnsupportedFC) {
+TEST(BoundShapeInference, FC3D) {
NetDef net;
net.add_op()->CopyFrom(
CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
@@ -255,7 +254,12 @@ TEST(BoundShapeInference, UnsupportedFC) {
shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
- EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet);
+ eng.InferBoundShapeAndType(net, shape_map);
+ const auto& out_shape = eng.shape_info();
+ verifyShapeInfo(
+ out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
+ verifyShapeInfo(
+ out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
}
TEST(BoundShapeInference, Combo0) {
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index dea1eeb14d..717f8432e8 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -44,10 +44,15 @@ void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) {
void BoundShapeInferencer::InferBoundShapeAndType(
const NetDef& net,
const std::unordered_map<std::string, ShapeInfo>& info) {
+ const static std::unordered_set<std::string> unsupported{"Tile"};
shape_info_ = info;
for (const auto& op : net.op()) {
VLOG(1) << op.type();
+ if (unsupported.count(op.type())) {
+ continue;
+ }
+
if (op.type() == "SparseLengthsSum" ||
op.type() == "SparseLengthsSumFused8BitRowwise" ||
op.type() == "SparseLengthsWeightedSum" ||
@@ -316,34 +321,22 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
ArgumentHelper helper(op);
auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
- CAFFE_ENFORCE_EQ(
- axis,
- 1,
- "Don't know how to deduce input of FC with axis not equal to 1: ",
- op.input(0));
- CAFFE_ENFORCE_EQ(
- axis_w,
- 1,
- "Don't know how to deduce input of FC with axis_w not equal to 1: ",
- op.input(0));
const TensorShape w_shape = w_shape_info.shape;
- CAFFE_ENFORCE_EQ(
- w_shape.dims_size(),
- 2,
- "Don't know how to deduce input of FC other than of dim size 2: ",
- op.input(0));
bool transposed = (op.type() == "FC") ? false : true;
const int canonical_axis_w =
canonical_axis_index_(axis_w, w_shape.dims().size());
const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w)
: SizeFromDim(w_shape, canonical_axis_w);
+ std::vector<int64_t> dims;
+ for (int i = 0; i < axis - 1; ++i) {
+ dims.push_back(1);
+ }
+ dims.push_back(spec_.max_batch_size);
+ dims.push_back(K);
current_dim_type_ = ShapeInfo::DimType::BATCH;
current_max_batch_size_ = spec_.max_batch_size;
CheckAndSetTensorShapeAndType(
- op.input(0),
- ShapeInfo::DimType::BATCH,
- {spec_.max_batch_size, K},
- w_shape.data_type());
+ op.input(0), ShapeInfo::DimType::BATCH, dims, w_shape.data_type());
} else {
ShapeInfo& x_shape_info = x_it->second;
if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) {