diff options
author | bddppq <bai@in.tum.de> | 2018-04-09 22:03:56 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-09 22:03:56 -0700 |
commit | 5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00 (patch) | |
tree | 27892afcbc74db45af31beb9543a3ab932370694 /caffe2/operators | |
parent | ce37cf79146dd6737f4a4aabf8854c76e1622eba (diff) | |
download | pytorch-5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00.tar.gz pytorch-5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00.tar.bz2 pytorch-5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00.zip |
Guard couple shape inference functions for unkown input shapes (#6379)
Diffstat (limited to 'caffe2/operators')
-rw-r--r-- | caffe2/operators/concat_split_op.cc | 7 | ||||
-rw-r--r-- | caffe2/operators/conv_pool_op_base.h | 10 | ||||
-rw-r--r-- | caffe2/operators/dropout_op.cc | 3 | ||||
-rw-r--r-- | caffe2/operators/fully_connected_op.cc | 7 |
4 files changed, 21 insertions, 6 deletions
diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc index 32b1ee0e64..6686d23237 100644 --- a/caffe2/operators/concat_split_op.cc +++ b/caffe2/operators/concat_split_op.cc @@ -94,8 +94,9 @@ OPERATOR_SCHEMA(Concat) "add_axis", "Pass 1 to add the axis specified in arg 'axis' to all " "input tensors") - .TensorInferenceFunction([](const OperatorDef& def, - const vector<TensorShape>& in) { + .TensorInferenceFunction(OpSchema::NeedsAllInputShapes( + [](const OperatorDef& def, + const vector<TensorShape>& in) { ArgumentHelper helper(def); const int axis = helper.HasArgument("axis") ? helper.GetSingleArgument<int>("axis", -1) @@ -164,7 +165,7 @@ OPERATOR_SCHEMA(Concat) return vector<TensorShape>{ CreateTensorShape(out_shape, in[0].data_type()), CreateTensorShape(split_shape, TensorProto::INT32)}; - }) + })) .CostInferenceFunction(CostInferenceForConcat) .DeviceInferenceFunction(concatOpDevInfer) .SetDoc("Concatenate a list of tensors into a single tensor") diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h index d94e4dbe47..e0b6541fc0 100644 --- a/caffe2/operators/conv_pool_op_base.h +++ b/caffe2/operators/conv_pool_op_base.h @@ -547,12 +547,22 @@ class ConvPoolOpBase : public Operator<Context> { static vector<TensorShape> TensorInferenceForConv( const OperatorDef& def, const vector<TensorShape>& in) { + if (in[0].unknown_shape()) { + vector<TensorShape> out(1); + out[0].set_unknown_shape(true); + return out; + } return TensorInferenceForSchema(def, in, in[1].dims(0)); } static vector<TensorShape> TensorInferenceForPool( const OperatorDef& def, const vector<TensorShape>& in) { + if (in[0].unknown_shape()) { + vector<TensorShape> out(1); + out[0].set_unknown_shape(true); + return out; + } ArgumentHelper helper(def); auto order = StringToStorageOrder(helper.GetSingleArgument<string>("order", "NCHW")); diff --git a/caffe2/operators/dropout_op.cc b/caffe2/operators/dropout_op.cc index cf9680c296..d696978ff4 100644 --- a/caffe2/operators/dropout_op.cc +++ b/caffe2/operators/dropout_op.cc @@ -70,8 +70,7 @@ OPERATOR_SCHEMA(Dropout) vector<TensorShape> out; ArgumentHelper argsHelper(def); out.push_back(in[0]); - auto output_mask = !argsHelper.GetSingleArgument<bool>("is_test", 0); - if (output_mask) { + if (def.output().size() == 2) { out.push_back(in[0]); out[1].set_data_type(TensorProto_DataType_BOOL); } diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc index 0332fb66af..4499789b0a 100644 --- a/caffe2/operators/fully_connected_op.cc +++ b/caffe2/operators/fully_connected_op.cc @@ -26,11 +26,16 @@ std::vector<TensorShape> FCShapeInference( const vector<TensorShape>& in, bool pretransposed_weight) { vector<TensorShape> out(1); + + if (in[0].unknown_shape() || in[1].unknown_shape()) { + out[0].set_unknown_shape(true); + return out; + } + ArgumentHelper helper(def); auto axis = helper.GetSingleArgument<int32_t>("axis", 1); const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size()); - const int M = size_to_dim_(canonical_axis, GetDimsVector(in[0])); auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1); const int canonical_axis_w = canonical_axis_index_(axis_w, in[1].dims().size()); |