diff options
author | Yinghai Lu <yinghai@fb.com> | 2019-04-04 00:19:21 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-04 00:22:22 -0700 |
commit | e5e2110a8ead028c863a7f449273bf6ee90bc423 (patch) | |
tree | 6bb667e33b825feae1ea56eb8cefadd8c4f529ea | |
parent | 0c237f1383abdddcfe7dcaa00af7fc075ddb8d67 (diff) | |
download | pytorch-e5e2110a8ead028c863a7f449273bf6ee90bc423.tar.gz pytorch-e5e2110a8ead028c863a7f449273bf6ee90bc423.tar.bz2 pytorch-e5e2110a8ead028c863a7f449273bf6ee90bc423.zip |
Add shape inference function for Split (#18838)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18838
It turns out that we don't have shape inference function of `Split` op at all. This diff adds that.
Reviewed By: bertmaher
Differential Revision: D14766871
fbshipit-source-id: 535cb4f24bdada603c76579e00e7a39aee93e19f
-rw-r--r-- | caffe2/operators/concat_split_op.cc | 74 | ||||
-rw-r--r-- | caffe2/opt/bound_shape_inference_test.cc | 54 |
2 files changed, 128 insertions, 0 deletions
diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc index ff665783c3..3b4bf971c2 100644 --- a/caffe2/operators/concat_split_op.cc +++ b/caffe2/operators/concat_split_op.cc @@ -16,6 +16,76 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer( } return std::make_pair(in_dev, out_dev); } + +vector<TensorShape> TensorInferenceForSplit( + const OperatorDef& def, + const vector<TensorShape>& in) { + auto ret_invalid_shape = [&def]() { + vector<TensorShape> out(def.output().size()); + for (auto& out_ts : out) { + out_ts.set_unknown_shape(true); + } + return out; + }; + // We only support shape inference of Split with 1 input + if (def.input_size() != 1 || in.empty() || in.front().unknown_shape()) { + return ret_invalid_shape(); + } else if (def.output_size() == 0) { + return vector<TensorShape>(); + } + ArgumentHelper helper(def); + const int axis = helper.HasArgument("axis") + ? helper.GetSingleArgument<int>("axis", -1) + : GetDimFromOrderString( + helper.GetSingleArgument<string>("order", "NCHW")); + const int add_axis = helper.HasArgument("axis") + ? helper.GetSingleArgument<int>("add_axis", 0) + : 0; + const auto& input = in[0]; + const int canonical_axis = canonical_axis_index_(axis, input.dims_size()); + const int input_channels = input.dims(canonical_axis); + auto split = helper.GetRepeatedArgument<int>("split"); + // Equally split the input into outputs + const int output_size = def.output_size(); + if (split.empty()) { + if (!input_channels % output_size) { + LOG(WARNING) << "Input channels (" << input_channels + << ") should be divisible by number of outputs (" + << output_size << ")"; + return ret_invalid_shape(); + } + split.resize(output_size, input_channels / output_size); + } else if (split.size() != output_size) { + LOG(WARNING) << "`split` size (" << split.size() + << ") should be equal to output size (" << output_size << ")"; + return ret_invalid_shape(); + } + + // Check validity of the split + const int total_channels = add_axis + ? def.output_size() + : std::accumulate(split.begin(), split.begin() + output_size, 0); + if (total_channels != input_channels) { + LOG(WARNING) << "Input channels (" << input_channels + << ") is not equal to total output channels (" + << total_channels << ")"; + return ret_invalid_shape(); + } + + vector<int> output_dims(input.dims().begin(), input.dims().end()); + if (add_axis) { + output_dims.erase(output_dims.begin() + canonical_axis); + } + vector<TensorShape> output_shapes; + for (int i = 0; i < output_size; ++i) { + if (!add_axis) { + output_dims[canonical_axis] = split[i]; + } + output_shapes.emplace_back( + CreateTensorShape(output_dims, input.data_type())); + } + return output_shapes; +} } // namespace. REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>); @@ -29,11 +99,15 @@ OPERATOR_SCHEMA(Split) "split", "(*Tensor`<int>`*): [OPTIONAL] list of output lengths (see also arg `split`)") .Arg("axis", "(*int*): axis to split on") + .Arg( + "add_axis", + "*(type: int)* Pass non-zero integer to remove the axis specified in `axis` to all input tensors.") .Arg("split", "(*Tuple(int)*): length of each output") .Arg( "order", "(*string*): order of dimensions of input and output blobs; either \"NCHW\" or \"NHWC\"") .Output(0, "[output_0, output_1, ...]", "(*Tensor*): output tensor") + .TensorInferenceFunction(TensorInferenceForSplit) .DeviceInferenceFunction(splitOpDevInfer) .SetDoc(R"DOC( Split an `input` tensor into a list of tensors, along the axis specified by the `axis` dimension. The lengths of the split can be specified using argument `split` or optional second input blob to the operator. Otherwise, the tensor is split to equal sized parts. diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index d8f77cf2b9..a148b0e82a 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -214,6 +214,60 @@ TEST(BoundShapeInference, ConcatMissingInput) { {spec.max_batch_size, 2, 60}); } +TEST(BoundShapeInference, Split) { + NetDef net; + net.add_op()->CopyFrom(CreateOperatorDef( + "Split", "", {"X"}, {"Y0", "Y1"}, {MakeArgument<int>("axis", 1)})); + net.add_op()->CopyFrom(CreateOperatorDef( + "Split", + "", + {"X"}, + {"Y2", "Y3", "Y4"}, + {MakeArgument<int>("axis", 1), + MakeArgument<std::vector<int>>("split", {4, 30, 14})})); + net.add_op()->CopyFrom(CreateOperatorDef( + "Split", + "", + {"X1"}, + {"Y5", "Y6"}, + {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)})); + BoundShapeSpec spec(20, 1000); + ShapeInfoMap shape_map; + shape_map.emplace( + "X", + makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48})); + shape_map.emplace( + "X1", + makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48})); + BoundShapeInferencer eng(spec); + eng.InferBoundShapeAndType(net, shape_map); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, "X", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48}); + verifyShapeInfo( + out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48}); + verifyShapeInfo( + out_shape, + "Y0", + ShapeInfo::DimType::BATCH, + {spec.max_batch_size, 48 / 2}); + verifyShapeInfo( + out_shape, + "Y1", + ShapeInfo::DimType::BATCH, + {spec.max_batch_size, 48 / 2}); + verifyShapeInfo( + out_shape, "Y2", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 4}); + verifyShapeInfo( + out_shape, "Y3", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 30}); + verifyShapeInfo( + out_shape, "Y4", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 14}); + verifyShapeInfo( + out_shape, "Y5", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48}); + verifyShapeInfo( + out_shape, "Y6", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48}); +} + TEST(BoundShapeInference, FC) { NetDef net; net.add_op()->CopyFrom( |