From e5e2110a8ead028c863a7f449273bf6ee90bc423 Mon Sep 17 00:00:00 2001
From: Yinghai Lu <yinghai@fb.com>
Date: Thu, 4 Apr 2019 00:19:21 -0700
Subject: Add shape inference function for Split (#18838)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18838

It turns out that we don't have shape inference function of `Split` op at all. This diff adds that.

Reviewed By: bertmaher

Differential Revision: D14766871

fbshipit-source-id: 535cb4f24bdada603c76579e00e7a39aee93e19f
---
 caffe2/operators/concat_split_op.cc      | 74 ++++++++++++++++++++++++++++++++
 caffe2/opt/bound_shape_inference_test.cc | 54 +++++++++++++++++++++++
 2 files changed, 128 insertions(+)

diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc
index ff665783c3..3b4bf971c2 100644
--- a/caffe2/operators/concat_split_op.cc
+++ b/caffe2/operators/concat_split_op.cc
@@ -16,6 +16,76 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer(
   }
   return std::make_pair(in_dev, out_dev);
 }
+
+vector<TensorShape> TensorInferenceForSplit(
+    const OperatorDef& def,
+    const vector<TensorShape>& in) {
+  auto ret_invalid_shape = [&def]() {
+    vector<TensorShape> out(def.output().size());
+    for (auto& out_ts : out) {
+      out_ts.set_unknown_shape(true);
+    }
+    return out;
+  };
+  // We only support shape inference of Split with 1 input
+  if (def.input_size() != 1 || in.empty() || in.front().unknown_shape()) {
+    return ret_invalid_shape();
+  } else if (def.output_size() == 0) {
+    return vector<TensorShape>();
+  }
+  ArgumentHelper helper(def);
+  const int axis = helper.HasArgument("axis")
+      ? helper.GetSingleArgument<int>("axis", -1)
+      : GetDimFromOrderString(
+            helper.GetSingleArgument<string>("order", "NCHW"));
+  const int add_axis = helper.HasArgument("axis")
+      ? helper.GetSingleArgument<int>("add_axis", 0)
+      : 0;
+  const auto& input = in[0];
+  const int canonical_axis = canonical_axis_index_(axis, input.dims_size());
+  const int input_channels = input.dims(canonical_axis);
+  auto split = helper.GetRepeatedArgument<int>("split");
+  // Equally split the input into outputs
+  const int output_size = def.output_size();
+  if (split.empty()) {
+    if (!input_channels % output_size) {
+      LOG(WARNING) << "Input channels (" << input_channels
+                   << ") should be divisible by number of outputs ("
+                   << output_size << ")";
+      return ret_invalid_shape();
+    }
+    split.resize(output_size, input_channels / output_size);
+  } else if (split.size() != output_size) {
+    LOG(WARNING) << "`split` size (" << split.size()
+                 << ") should be equal to output size (" << output_size << ")";
+    return ret_invalid_shape();
+  }
+
+  // Check validity of the split
+  const int total_channels = add_axis
+      ? def.output_size()
+      : std::accumulate(split.begin(), split.begin() + output_size, 0);
+  if (total_channels != input_channels) {
+    LOG(WARNING) << "Input channels (" << input_channels
+                 << ") is not equal to total output channels ("
+                 << total_channels << ")";
+    return ret_invalid_shape();
+  }
+
+  vector<int64_t> output_dims(input.dims().begin(), input.dims().end());
+  if (add_axis) {
+    output_dims.erase(output_dims.begin() + canonical_axis);
+  }
+  vector<TensorShape> output_shapes;
+  for (int i = 0; i < output_size; ++i) {
+    if (!add_axis) {
+      output_dims[canonical_axis] = split[i];
+    }
+    output_shapes.emplace_back(
+        CreateTensorShape(output_dims, input.data_type()));
+  }
+  return output_shapes;
+}
 } // namespace.
 
 REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
@@ -29,11 +99,15 @@ OPERATOR_SCHEMA(Split)
         "split",
         "(*Tensor``<int>``*): [OPTIONAL] list of output lengths (see also arg `split`)")
     .Arg("axis", "(*int*): axis to split on")
+    .Arg(
+        "add_axis",
+        "*(type: int)* Pass non-zero integer to remove the axis specified in `axis` to all input tensors.")
     .Arg("split", "(*Tuple(int)*): length of each output")
     .Arg(
         "order",
         "(*string*): order of dimensions of input and output blobs; either \"NCHW\" or \"NHWC\"")
     .Output(0, "[output_0, output_1, ...]", "(*Tensor*): output tensor")
+    .TensorInferenceFunction(TensorInferenceForSplit)
     .DeviceInferenceFunction(splitOpDevInfer)
     .SetDoc(R"DOC(
 Split an `input` tensor into a list of tensors, along the axis specified by the `axis` dimension. The lengths of the split can be specified using argument `split` or optional second input blob to the operator. Otherwise, the tensor is split to equal sized parts.
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
index d8f77cf2b9..a148b0e82a 100644
--- a/caffe2/opt/bound_shape_inference_test.cc
+++ b/caffe2/opt/bound_shape_inference_test.cc
@@ -214,6 +214,60 @@ TEST(BoundShapeInference, ConcatMissingInput) {
       {spec.max_batch_size, 2, 60});
 }
 
+TEST(BoundShapeInference, Split) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Split", "", {"X"}, {"Y0", "Y1"}, {MakeArgument<int>("axis", 1)}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Split",
+      "",
+      {"X"},
+      {"Y2", "Y3", "Y4"},
+      {MakeArgument<int>("axis", 1),
+       MakeArgument<std::vector<int>>("split", {4, 30, 14})}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Split",
+      "",
+      {"X1"},
+      {"Y5", "Y6"},
+      {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)}));
+  BoundShapeSpec spec(20, 1000);
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "X",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48}));
+  shape_map.emplace(
+      "X1",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48}));
+
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "X", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+  verifyShapeInfo(
+      out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48});
+  verifyShapeInfo(
+      out_shape,
+      "Y0",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 48 / 2});
+  verifyShapeInfo(
+      out_shape,
+      "Y1",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 48 / 2});
+  verifyShapeInfo(
+      out_shape, "Y2", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 4});
+  verifyShapeInfo(
+      out_shape, "Y3", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 30});
+  verifyShapeInfo(
+      out_shape, "Y4", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 14});
+  verifyShapeInfo(
+      out_shape, "Y5", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+  verifyShapeInfo(
+      out_shape, "Y6", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+}
+
 TEST(BoundShapeInference, FC) {
   NetDef net;
   net.add_op()->CopyFrom(
-- 
cgit v1.2.3