From 5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00 Mon Sep 17 00:00:00 2001 From: bddppq Date: Mon, 9 Apr 2018 22:03:56 -0700 Subject: Guard couple shape inference functions for unknown input shapes (#6379) --- caffe2/operators/concat_split_op.cc | 7 ++++--- caffe2/operators/conv_pool_op_base.h | 10 ++++++++++ caffe2/operators/dropout_op.cc | 3 +-- caffe2/operators/fully_connected_op.cc | 7 ++++++- 4 files changed, 21 insertions(+), 6 deletions(-) (limited to 'caffe2/operators') diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc index 32b1ee0e64..6686d23237 100644 --- a/caffe2/operators/concat_split_op.cc +++ b/caffe2/operators/concat_split_op.cc @@ -94,8 +94,9 @@ OPERATOR_SCHEMA(Concat) "add_axis", "Pass 1 to add the axis specified in arg 'axis' to all " "input tensors") - .TensorInferenceFunction([](const OperatorDef& def, - const vector& in) { + .TensorInferenceFunction(OpSchema::NeedsAllInputShapes( + [](const OperatorDef& def, + const vector& in) { ArgumentHelper helper(def); const int axis = helper.HasArgument("axis") ? 
helper.GetSingleArgument("axis", -1) @@ -164,7 +165,7 @@ OPERATOR_SCHEMA(Concat) return vector{ CreateTensorShape(out_shape, in[0].data_type()), CreateTensorShape(split_shape, TensorProto::INT32)}; - }) + })) .CostInferenceFunction(CostInferenceForConcat) .DeviceInferenceFunction(concatOpDevInfer) .SetDoc("Concatenate a list of tensors into a single tensor") diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h index d94e4dbe47..e0b6541fc0 100644 --- a/caffe2/operators/conv_pool_op_base.h +++ b/caffe2/operators/conv_pool_op_base.h @@ -547,12 +547,22 @@ class ConvPoolOpBase : public Operator { static vector TensorInferenceForConv( const OperatorDef& def, const vector& in) { + if (in[0].unknown_shape()) { + vector out(1); + out[0].set_unknown_shape(true); + return out; + } return TensorInferenceForSchema(def, in, in[1].dims(0)); } static vector TensorInferenceForPool( const OperatorDef& def, const vector& in) { + if (in[0].unknown_shape()) { + vector out(1); + out[0].set_unknown_shape(true); + return out; + } ArgumentHelper helper(def); auto order = StringToStorageOrder(helper.GetSingleArgument("order", "NCHW")); diff --git a/caffe2/operators/dropout_op.cc b/caffe2/operators/dropout_op.cc index cf9680c296..d696978ff4 100644 --- a/caffe2/operators/dropout_op.cc +++ b/caffe2/operators/dropout_op.cc @@ -70,8 +70,7 @@ OPERATOR_SCHEMA(Dropout) vector out; ArgumentHelper argsHelper(def); out.push_back(in[0]); - auto output_mask = !argsHelper.GetSingleArgument("is_test", 0); - if (output_mask) { + if (def.output().size() == 2) { out.push_back(in[0]); out[1].set_data_type(TensorProto_DataType_BOOL); } diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc index 0332fb66af..4499789b0a 100644 --- a/caffe2/operators/fully_connected_op.cc +++ b/caffe2/operators/fully_connected_op.cc @@ -26,11 +26,16 @@ std::vector FCShapeInference( const vector& in, bool pretransposed_weight) { vector out(1); + + if 
(in[0].unknown_shape() || in[1].unknown_shape()) { + out[0].set_unknown_shape(true); + return out; + } + ArgumentHelper helper(def); auto axis = helper.GetSingleArgument("axis", 1); const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size()); - const int M = size_to_dim_(canonical_axis, GetDimsVector(in[0])); auto axis_w = helper.GetSingleArgument("axis_w", 1); const int canonical_axis_w = canonical_axis_index_(axis_w, in[1].dims().size()); -- cgit v1.2.3