summaryrefslogtreecommitdiff
path: root/caffe2/opt
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2019-03-08 13:15:05 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-08 13:18:28 -0800
commitefed875b3f930a4caf8db584b3cf169c8f7925fd (patch)
treef0a5a4ba53c43bba997c70bb7682c5a5def64bcc /caffe2/opt
parent4a7c549e8f0ec15544a34c7e5505836951d678fb (diff)
downloadpytorch-efed875b3f930a4caf8db584b3cf169c8f7925fd.tar.gz
pytorch-efed875b3f930a4caf8db584b3cf169c8f7925fd.tar.bz2
pytorch-efed875b3f930a4caf8db584b3cf169c8f7925fd.zip
Catch exceptions in bound_shape_inference (#17775)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17775 Handles use input shape hint properly. Reviewed By: zrphercule Differential Revision: D14368735 fbshipit-source-id: 504cd96589e47aa432617e56362aa6b01a25ba9b
Diffstat (limited to 'caffe2/opt')
-rw-r--r--caffe2/opt/backend_transformer_base.cc6
-rw-r--r--caffe2/opt/bound_shape_inferencer.cc26
-rw-r--r--caffe2/opt/bound_shape_inferencer.h3
3 files changed, 29 insertions, 6 deletions
diff --git a/caffe2/opt/backend_transformer_base.cc b/caffe2/opt/backend_transformer_base.cc
index f23a7a035f..a4db8b630d 100644
--- a/caffe2/opt/backend_transformer_base.cc
+++ b/caffe2/opt/backend_transformer_base.cc
@@ -78,11 +78,15 @@ ShapeInfoMap BackendTransformerBase::inferShapes(
shape_map[s] = shape_info;
}
}
+ // We treat hinted shapes as BATCH. If there are shape hints on blobs in the
+ // workspace, since they are already inserted as CONSTANT, it will take effect
+ // here. For SEQ typed tensors, there are only a few of them and they will be
+ // handled by BoundShapeInferencer.
for (const auto& kv : shape_hints_mapped) {
shape_map.emplace(
std::piecewise_construct,
std::forward_as_tuple(kv.first),
- std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, kv.second));
+ std::forward_as_tuple(ShapeInfo::DimType::BATCH, kv.second));
}
BoundShapeInferencer eng(spec);
eng.InferBoundShapeAndType(*pred_net, shape_map);
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index 990220a3d3..e56c8d2dd4 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -2,6 +2,7 @@
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/tensor_impl.h"
#include "caffe2/utils/proto_utils.h"
+#include "caffe2/utils/string_utils.h"
namespace caffe2 {
@@ -60,6 +61,10 @@ void BoundShapeInferencer::InferBoundShapeAndType(
InferReshape(op);
} else if (op.type() == "LengthsRangeFill") {
InferLengthsRangeFill(op);
+ } else if (
+ caffe2::StartsWith(op.type(), "GivenTensor") &&
+ caffe2::EndsWith(op.type(), "Fill")) {
+ InferGivenTensorFill(op);
} else {
InferCommonOp(op);
}
@@ -122,6 +127,15 @@ std::vector<TensorShape> InferOutput(
return schema->InferTensor(op, input_shapes);
}
+void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) {
+ CAFFE_ENFORCE_EQ(op.output_size(), 1, op.type(), " must have 1 output");
+ InferCommonOp(op);
+ auto it = shape_info_.find(op.output(0));
+ if (it != shape_info_.end()) {
+ it->second.dim_type = ShapeInfo::DimType::CONSTANT;
+ }
+}
+
void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) {
CAFFE_ENFORCE_EQ(op.input_size(), 1, "LengthsRangeFill must have 1 input");
CAFFE_ENFORCE_EQ(op.output_size(), 1, "LengthsRangeFill must have 1 output");
@@ -342,6 +356,7 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
// First, we need to check that all the input shape/types are already
// presented
+ try {
std::vector<TensorShape> input_shapes;
for (const auto& input : op.input()) {
const auto it = shape_info_.find(input);
@@ -356,11 +371,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
CAFFE_ENFORCE(schema);
std::vector<TensorShape> output_shapes;
- try {
output_shapes = schema->InferTensor(op, input_shapes);
- } catch (const std::exception& e) {
- LOG(WARNING) << "Caught exception while inferring shapes for " << op.type();
- }
int i = 0;
for (const auto& shape : output_shapes) {
if (shape.unknown_shape()) {
@@ -373,6 +384,13 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
ConvertToVec(shape.dims()),
shape.data_type());
}
+ } catch (const caffe2::EnforceNotMet& e) {
+ LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type()
+ << ": " << e.msg();
+ } catch (const std::exception& e) {
+ LOG(WARNING) << "Caught exception while inferring shapes for " << op.type()
+ << ": " << e.what();
+ }
}
} // namespace caffe2
diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h
index dafac5b4d8..ef6fa0494e 100644
--- a/caffe2/opt/bound_shape_inferencer.h
+++ b/caffe2/opt/bound_shape_inferencer.h
@@ -63,6 +63,7 @@ class CAFFE2_API BoundShapeInferencer {
std::vector<int64_t> bound_dims,
TensorProto::DataType type);
+ void InferGivenTensorFill(const OperatorDef& op);
void InferSparseLengthsSum(const OperatorDef& op);
void InferFC(const OperatorDef& op);
void InferConcat(const OperatorDef& op);
@@ -74,7 +75,7 @@ class CAFFE2_API BoundShapeInferencer {
void InferCommonOp(const OperatorDef& op);
const BoundShapeSpec spec_;
- ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::UNKNOWN};
+ ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::BATCH};
int64_t current_max_batch_size_{0};
std::unordered_map<std::string, ShapeInfo> shape_info_;
};