diff options
author | Yinghai Lu <yinghai@fb.com> | 2019-03-08 13:15:05 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-08 13:18:28 -0800 |
commit | efed875b3f930a4caf8db584b3cf169c8f7925fd (patch) | |
tree | f0a5a4ba53c43bba997c70bb7682c5a5def64bcc /caffe2/opt | |
parent | 4a7c549e8f0ec15544a34c7e5505836951d678fb (diff) | |
download | pytorch-efed875b3f930a4caf8db584b3cf169c8f7925fd.tar.gz pytorch-efed875b3f930a4caf8db584b3cf169c8f7925fd.tar.bz2 pytorch-efed875b3f930a4caf8db584b3cf169c8f7925fd.zip |
Catch exceptions in bound_shape_inference (#17775)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17775
Handles use input shape hint properly.
Reviewed By: zrphercule
Differential Revision: D14368735
fbshipit-source-id: 504cd96589e47aa432617e56362aa6b01a25ba9b
Diffstat (limited to 'caffe2/opt')
-rw-r--r-- | caffe2/opt/backend_transformer_base.cc | 6 | ||||
-rw-r--r-- | caffe2/opt/bound_shape_inferencer.cc | 26 | ||||
-rw-r--r-- | caffe2/opt/bound_shape_inferencer.h | 3 |
3 files changed, 29 insertions, 6 deletions
diff --git a/caffe2/opt/backend_transformer_base.cc b/caffe2/opt/backend_transformer_base.cc index f23a7a035f..a4db8b630d 100644 --- a/caffe2/opt/backend_transformer_base.cc +++ b/caffe2/opt/backend_transformer_base.cc @@ -78,11 +78,15 @@ ShapeInfoMap BackendTransformerBase::inferShapes( shape_map[s] = shape_info; } } + // We treat hinted shapes as BATCH. If there are shape hints on blobs in the + // workspace, since they are already inserted as CONSTANT, it will take effect + // here. For SEQ typed tensors, there are only a few of them and they will be + // handled by BoundShapeInferencer. for (const auto& kv : shape_hints_mapped) { shape_map.emplace( std::piecewise_construct, std::forward_as_tuple(kv.first), - std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, kv.second)); + std::forward_as_tuple(ShapeInfo::DimType::BATCH, kv.second)); } BoundShapeInferencer eng(spec); eng.InferBoundShapeAndType(*pred_net, shape_map); diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index 990220a3d3..e56c8d2dd4 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -2,6 +2,7 @@ #include "caffe2/core/operator_schema.h" #include "caffe2/core/tensor_impl.h" #include "caffe2/utils/proto_utils.h" +#include "caffe2/utils/string_utils.h" namespace caffe2 { @@ -60,6 +61,10 @@ void BoundShapeInferencer::InferBoundShapeAndType( InferReshape(op); } else if (op.type() == "LengthsRangeFill") { InferLengthsRangeFill(op); + } else if ( + caffe2::StartsWith(op.type(), "GivenTensor") && + caffe2::EndsWith(op.type(), "Fill")) { + InferGivenTensorFill(op); } else { InferCommonOp(op); } @@ -122,6 +127,15 @@ std::vector<TensorShape> InferOutput( return schema->InferTensor(op, input_shapes); } +void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) { + CAFFE_ENFORCE_EQ(op.output_size(), 1, op.type(), " must have 1 output"); + InferCommonOp(op); + auto it = shape_info_.find(op.output(0)); + if (it != shape_info_.end()) { + it->second.dim_type = ShapeInfo::DimType::CONSTANT; + } +} + void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) { CAFFE_ENFORCE_EQ(op.input_size(), 1, "LengthsRangeFill must have 1 input"); CAFFE_ENFORCE_EQ(op.output_size(), 1, "LengthsRangeFill must have 1 output"); @@ -342,6 +356,7 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) { void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { // First, we need to check that all the input shape/types are already // presented + try { std::vector<TensorShape> input_shapes; for (const auto& input : op.input()) { const auto it = shape_info_.find(input); @@ -356,11 +371,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { const OpSchema* schema = OpSchemaRegistry::Schema(op.type()); CAFFE_ENFORCE(schema); std::vector<TensorShape> output_shapes; - try { output_shapes = schema->InferTensor(op, input_shapes); - } catch (const std::exception& e) { - LOG(WARNING) << "Caught exception while inferring shapes for " << op.type(); - } int i = 0; for (const auto& shape : output_shapes) { if (shape.unknown_shape()) { @@ -373,6 +384,13 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { ConvertToVec(shape.dims()), shape.data_type()); } + } catch (const caffe2::EnforceNotMet& e) { + LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type() + << ": " << e.msg(); + } catch (const std::exception& e) { + LOG(WARNING) << "Caught exception while inferring shapes for " << op.type() + << ": " << e.what(); + } } } // namespace caffe2 diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h index dafac5b4d8..ef6fa0494e 100644 --- a/caffe2/opt/bound_shape_inferencer.h +++ b/caffe2/opt/bound_shape_inferencer.h @@ -63,6 +63,7 @@ class CAFFE2_API BoundShapeInferencer { std::vector<int64_t> bound_dims, TensorProto::DataType type); + void InferGivenTensorFill(const OperatorDef& op); void InferSparseLengthsSum(const OperatorDef& op); void InferFC(const OperatorDef& op); void InferConcat(const OperatorDef& op); @@ -74,7 +75,7 @@ class CAFFE2_API BoundShapeInferencer { void InferCommonOp(const OperatorDef& op); const BoundShapeSpec spec_; - ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::UNKNOWN}; + ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::BATCH}; int64_t current_max_batch_size_{0}; std::unordered_map<std::string, ShapeInfo> shape_info_; }; |