author     Yinghai Lu <yinghai@fb.com>  2019-04-04 00:19:21 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-04 00:22:22 -0700
commit     e5e2110a8ead028c863a7f449273bf6ee90bc423 (patch)
tree       6bb667e33b825feae1ea56eb8cefadd8c4f529ea
parent     0c237f1383abdddcfe7dcaa00af7fc075ddb8d67 (diff)
Add shape inference function for Split (#18838)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18838

It turns out that we don't have a shape inference function for the `Split` op at all. This diff adds one.

Reviewed By: bertmaher

Differential Revision: D14766871

fbshipit-source-id: 535cb4f24bdada603c76579e00e7a39aee93e19f
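For reference, the shape arithmetic that the new inference function performs can be illustrated with a minimal standalone C++ sketch (the helper name SplitOutputDims is hypothetical; this is not code from the diff): with no `split` argument the size along the split axis is divided evenly across the outputs, and an explicit `split` list must sum to the input's size along that axis.

#include <cassert>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical standalone helper mirroring the even-split / explicit-split
// arithmetic of the new TensorInferenceForSplit (illustrative only).
std::vector<std::vector<int>> SplitOutputDims(
    const std::vector<int>& input_dims,
    int axis,
    int num_outputs,
    std::vector<int> split /* empty => split evenly */) {
  const int channels = input_dims[axis];
  if (split.empty()) {
    // Even split: the split axis must be divisible by the number of outputs.
    assert(channels % num_outputs == 0);
    split.assign(num_outputs, channels / num_outputs);
  }
  // The split sizes must add up to the input's size along the split axis.
  assert(std::accumulate(split.begin(), split.end(), 0) == channels);
  std::vector<std::vector<int>> out;
  for (int s : split) {
    std::vector<int> dims = input_dims;
    dims[axis] = s;  // each output keeps every dim except the split axis
    out.push_back(dims);
  }
  return out;
}

int main() {
  // Even split of {20, 48} on axis 1 into two outputs: {20, 24} twice.
  for (const auto& dims : SplitOutputDims({20, 48}, 1, 2, {})) {
    std::cout << dims[0] << " x " << dims[1] << "\n";
  }
  // Explicit split {4, 30, 14}: outputs {20, 4}, {20, 30}, {20, 14}.
  for (const auto& dims : SplitOutputDims({20, 48}, 1, 3, {4, 30, 14})) {
    std::cout << dims[0] << " x " << dims[1] << "\n";
  }
  return 0;
}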
-rw-r--r--  caffe2/operators/concat_split_op.cc       74
-rw-r--r--  caffe2/opt/bound_shape_inference_test.cc   54
2 files changed, 128 insertions, 0 deletions
diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc
index ff665783c3..3b4bf971c2 100644
--- a/caffe2/operators/concat_split_op.cc
+++ b/caffe2/operators/concat_split_op.cc
@@ -16,6 +16,76 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer(
}
return std::make_pair(in_dev, out_dev);
}
+
+vector<TensorShape> TensorInferenceForSplit(
+ const OperatorDef& def,
+ const vector<TensorShape>& in) {
+ auto ret_invalid_shape = [&def]() {
+ vector<TensorShape> out(def.output().size());
+ for (auto& out_ts : out) {
+ out_ts.set_unknown_shape(true);
+ }
+ return out;
+ };
+ // We only support shape inference of Split with 1 input
+ if (def.input_size() != 1 || in.empty() || in.front().unknown_shape()) {
+ return ret_invalid_shape();
+ } else if (def.output_size() == 0) {
+ return vector<TensorShape>();
+ }
+ ArgumentHelper helper(def);
+ const int axis = helper.HasArgument("axis")
+ ? helper.GetSingleArgument<int>("axis", -1)
+ : GetDimFromOrderString(
+ helper.GetSingleArgument<string>("order", "NCHW"));
+ const int add_axis = helper.HasArgument("axis")
+ ? helper.GetSingleArgument<int>("add_axis", 0)
+ : 0;
+ const auto& input = in[0];
+ const int canonical_axis = canonical_axis_index_(axis, input.dims_size());
+ const int input_channels = input.dims(canonical_axis);
+ auto split = helper.GetRepeatedArgument<int>("split");
+ // Equally split the input into outputs
+ const int output_size = def.output_size();
+ if (split.empty()) {
+ if (input_channels % output_size != 0) {
+ LOG(WARNING) << "Input channels (" << input_channels
+ << ") should be divisible by number of outputs ("
+ << output_size << ")";
+ return ret_invalid_shape();
+ }
+ split.resize(output_size, input_channels / output_size);
+ } else if (split.size() != output_size) {
+ LOG(WARNING) << "`split` size (" << split.size()
+ << ") should be equal to output size (" << output_size << ")";
+ return ret_invalid_shape();
+ }
+
+ // Check validity of the split
+ const int total_channels = add_axis
+ ? def.output_size()
+ : std::accumulate(split.begin(), split.begin() + output_size, 0);
+ if (total_channels != input_channels) {
+ LOG(WARNING) << "Input channels (" << input_channels
+ << ") is not equal to total output channels ("
+ << total_channels << ")";
+ return ret_invalid_shape();
+ }
+
+ vector<int> output_dims(input.dims().begin(), input.dims().end());
+ if (add_axis) {
+ output_dims.erase(output_dims.begin() + canonical_axis);
+ }
+ vector<TensorShape> output_shapes;
+ for (int i = 0; i < output_size; ++i) {
+ if (!add_axis) {
+ output_dims[canonical_axis] = split[i];
+ }
+ output_shapes.emplace_back(
+ CreateTensorShape(output_dims, input.data_type()));
+ }
+ return output_shapes;
+}
} // namespace.
REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
@@ -29,11 +99,15 @@ OPERATOR_SCHEMA(Split)
"split",
"(*Tensor`<int>`*): [OPTIONAL] list of output lengths (see also arg `split`)")
.Arg("axis", "(*int*): axis to split on")
+ .Arg(
+ "add_axis",
+ "*(type: int)* Pass non-zero integer to remove the axis specified in `axis` to all input tensors.")
.Arg("split", "(*Tuple(int)*): length of each output")
.Arg(
"order",
"(*string*): order of dimensions of input and output blobs; either \"NCHW\" or \"NHWC\"")
.Output(0, "[output_0, output_1, ...]", "(*Tensor*): output tensor")
+ .TensorInferenceFunction(TensorInferenceForSplit)
.DeviceInferenceFunction(splitOpDevInfer)
.SetDoc(R"DOC(
Split an `input` tensor into a list of tensors along the axis specified by the `axis` argument. The lengths of the split can be specified using the `split` argument or the optional second input blob to the operator. Otherwise, the tensor is split into equal-sized parts.
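The effect of the newly documented `add_axis` argument on the inferred shapes can be sketched with a small illustrative program (shapes are assumptions borrowed from the test added below; this is not code from the diff): when `add_axis` is non-zero, the split dimension is erased and every output keeps the remaining dims unchanged.

#include <iostream>
#include <vector>

// Illustrative sketch of the add_axis branch (assumed shapes from the test
// added below): when add_axis is non-zero, the split axis is erased from the
// input dims and each output gets the remaining dims unchanged.
int main() {
  std::vector<int> input_dims = {20, 2, 48};  // like X1 in the new test
  const int canonical_axis = 1;               // dim of size 2 == #outputs
  const int num_outputs = 2;                  // Y5, Y6

  std::vector<int> output_dims = input_dims;
  output_dims.erase(output_dims.begin() + canonical_axis);  // drop axis 1

  for (int i = 0; i < num_outputs; ++i) {
    std::cout << "output " << i << ": { ";
    for (int d : output_dims) {
      std::cout << d << " ";
    }
    std::cout << "}\n";  // prints "{ 20 48 }" for each output
  }
  return 0;
}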
diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
index d8f77cf2b9..a148b0e82a 100644
--- a/caffe2/opt/bound_shape_inference_test.cc
+++ b/caffe2/opt/bound_shape_inference_test.cc
@@ -214,6 +214,60 @@ TEST(BoundShapeInference, ConcatMissingInput) {
{spec.max_batch_size, 2, 60});
}
+TEST(BoundShapeInference, Split) {
+ NetDef net;
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "Split", "", {"X"}, {"Y0", "Y1"}, {MakeArgument<int>("axis", 1)}));
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "Split",
+ "",
+ {"X"},
+ {"Y2", "Y3", "Y4"},
+ {MakeArgument<int>("axis", 1),
+ MakeArgument<std::vector<int>>("split", {4, 30, 14})}));
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "Split",
+ "",
+ {"X1"},
+ {"Y5", "Y6"},
+ {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)}));
+ BoundShapeSpec spec(20, 1000);
+ ShapeInfoMap shape_map;
+ shape_map.emplace(
+ "X",
+ makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48}));
+ shape_map.emplace(
+ "X1",
+ makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48}));
+ BoundShapeInferencer eng(spec);
+ eng.InferBoundShapeAndType(net, shape_map);
+ const auto& out_shape = eng.shape_info();
+ verifyShapeInfo(
+ out_shape, "X", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+ verifyShapeInfo(
+ out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48});
+ verifyShapeInfo(
+ out_shape,
+ "Y0",
+ ShapeInfo::DimType::BATCH,
+ {spec.max_batch_size, 48 / 2});
+ verifyShapeInfo(
+ out_shape,
+ "Y1",
+ ShapeInfo::DimType::BATCH,
+ {spec.max_batch_size, 48 / 2});
+ verifyShapeInfo(
+ out_shape, "Y2", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 4});
+ verifyShapeInfo(
+ out_shape, "Y3", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 30});
+ verifyShapeInfo(
+ out_shape, "Y4", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 14});
+ verifyShapeInfo(
+ out_shape, "Y5", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+ verifyShapeInfo(
+ out_shape, "Y6", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+}
+
TEST(BoundShapeInference, FC) {
NetDef net;
net.add_op()->CopyFrom(