summaryrefslogtreecommitdiff
path: root/caffe2/operators
diff options
context:
space:
mode:
authorbddppq <bai@in.tum.de>2018-04-09 22:03:56 -0700
committerGitHub <noreply@github.com>2018-04-09 22:03:56 -0700
commit5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00 (patch)
tree27892afcbc74db45af31beb9543a3ab932370694 /caffe2/operators
parentce37cf79146dd6737f4a4aabf8854c76e1622eba (diff)
downloadpytorch-5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00.tar.gz
pytorch-5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00.tar.bz2
pytorch-5e12ba92dc2dabcf979f1c5f1d2d202cd796dc00.zip
Guard couple shape inference functions for unkown input shapes (#6379)
Diffstat (limited to 'caffe2/operators')
-rw-r--r--caffe2/operators/concat_split_op.cc7
-rw-r--r--caffe2/operators/conv_pool_op_base.h10
-rw-r--r--caffe2/operators/dropout_op.cc3
-rw-r--r--caffe2/operators/fully_connected_op.cc7
4 files changed, 21 insertions, 6 deletions
diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc
index 32b1ee0e64..6686d23237 100644
--- a/caffe2/operators/concat_split_op.cc
+++ b/caffe2/operators/concat_split_op.cc
@@ -94,8 +94,9 @@ OPERATOR_SCHEMA(Concat)
"add_axis",
"Pass 1 to add the axis specified in arg 'axis' to all "
"input tensors")
- .TensorInferenceFunction([](const OperatorDef& def,
- const vector<TensorShape>& in) {
+ .TensorInferenceFunction(OpSchema::NeedsAllInputShapes(
+ [](const OperatorDef& def,
+ const vector<TensorShape>& in) {
ArgumentHelper helper(def);
const int axis = helper.HasArgument("axis")
? helper.GetSingleArgument<int>("axis", -1)
@@ -164,7 +165,7 @@ OPERATOR_SCHEMA(Concat)
return vector<TensorShape>{
CreateTensorShape(out_shape, in[0].data_type()),
CreateTensorShape(split_shape, TensorProto::INT32)};
- })
+ }))
.CostInferenceFunction(CostInferenceForConcat)
.DeviceInferenceFunction(concatOpDevInfer)
.SetDoc("Concatenate a list of tensors into a single tensor")
diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h
index d94e4dbe47..e0b6541fc0 100644
--- a/caffe2/operators/conv_pool_op_base.h
+++ b/caffe2/operators/conv_pool_op_base.h
@@ -547,12 +547,22 @@ class ConvPoolOpBase : public Operator<Context> {
static vector<TensorShape> TensorInferenceForConv(
const OperatorDef& def,
const vector<TensorShape>& in) {
+ if (in[0].unknown_shape()) {
+ vector<TensorShape> out(1);
+ out[0].set_unknown_shape(true);
+ return out;
+ }
return TensorInferenceForSchema(def, in, in[1].dims(0));
}
static vector<TensorShape> TensorInferenceForPool(
const OperatorDef& def,
const vector<TensorShape>& in) {
+ if (in[0].unknown_shape()) {
+ vector<TensorShape> out(1);
+ out[0].set_unknown_shape(true);
+ return out;
+ }
ArgumentHelper helper(def);
auto order =
StringToStorageOrder(helper.GetSingleArgument<string>("order", "NCHW"));
diff --git a/caffe2/operators/dropout_op.cc b/caffe2/operators/dropout_op.cc
index cf9680c296..d696978ff4 100644
--- a/caffe2/operators/dropout_op.cc
+++ b/caffe2/operators/dropout_op.cc
@@ -70,8 +70,7 @@ OPERATOR_SCHEMA(Dropout)
vector<TensorShape> out;
ArgumentHelper argsHelper(def);
out.push_back(in[0]);
- auto output_mask = !argsHelper.GetSingleArgument<bool>("is_test", 0);
- if (output_mask) {
+ if (def.output().size() == 2) {
out.push_back(in[0]);
out[1].set_data_type(TensorProto_DataType_BOOL);
}
diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc
index 0332fb66af..4499789b0a 100644
--- a/caffe2/operators/fully_connected_op.cc
+++ b/caffe2/operators/fully_connected_op.cc
@@ -26,11 +26,16 @@ std::vector<TensorShape> FCShapeInference(
const vector<TensorShape>& in,
bool pretransposed_weight) {
vector<TensorShape> out(1);
+
+ if (in[0].unknown_shape() || in[1].unknown_shape()) {
+ out[0].set_unknown_shape(true);
+ return out;
+ }
+
ArgumentHelper helper(def);
auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
- const int M = size_to_dim_(canonical_axis, GetDimsVector(in[0]));
auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
const int canonical_axis_w =
canonical_axis_index_(axis_w, in[1].dims().size());