diff options
author | Junjie Bai <bai@in.tum.de> | 2017-10-13 12:10:34 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-10-13 12:28:22 -0700 |
commit | 4c3b02f31447037b009e80a0251c0ebd0ee95b7c (patch) | |
tree | 5fbf0d5ac9df238d6a9d304c9be86ba5ba454587 /caffe2/operators/utility_ops.cc | |
parent | c3a9423c7f339bbfd16260f7615a94de51615bd0 (diff) | |
download | pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.tar.gz pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.tar.bz2 pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.zip |
Enable Flatten operator to take an arbitrary axis arguemnt
Summary:
input dimensions up to "axis" will be flattened to the outer dim of output and the remaining input dims will be the inner dim
Closes https://github.com/caffe2/caffe2/pull/1330
Reviewed By: dzhulgakov
Differential Revision: D6039560
Pulled By: bddppq
fbshipit-source-id: e92c30b49a9288feeefc4a639522406e97e149e1
Diffstat (limited to 'caffe2/operators/utility_ops.cc')
-rw-r--r-- | caffe2/operators/utility_ops.cc | 38 |
1 files changed, 24 insertions, 14 deletions
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index 5763e64a35..9a2e82e09f 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -135,34 +135,44 @@ OPERATOR_SCHEMA(LengthsToShape).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Flatten) .NumInputs(1) .NumOutputs(1) - .TensorInferenceFunction([](const OperatorDef&, + .TensorInferenceFunction([](const OperatorDef& def, const vector<TensorShape>& in) { + ArgumentHelper helper(def); + const int axis = helper.GetSingleArgument<int>("axis", 1); vector<TensorShape> out(1); - int total = 1; + TIndex outer = 1; + TIndex inner = 1; std::size_t index = 0; for (auto d : in[0].dims()) { - // skip the first element - if (index++ == 0) { - continue; + if (index < axis) { + outer *= d; + } else { + inner *= d; } - total *= d; + ++index; } out[0].set_data_type(in[0].data_type()); - out[0].add_dims(in[0].dims(0)); - out[0].add_dims(total); + out[0].add_dims(outer); + out[0].add_dims(inner); return out; }) .SetDoc(R"DOC( -Flattens the input tensor into a 2D matrix, keeping the first dimension -unchanged. +Flattens the input tensor into a 2D matrix. If input tensor has shape +(d_0, d_1, ... d_n) then the output will have shape +(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn) )DOC") - .Input(0, "input", "A tensor of rank >= 2.") + .Input(0, "input", "A tensor of rank >= axis.") .Output( 0, "output", - "A tensor of rank 2 with the contents of the input tensor, " - "with first dimension equal first dimension of input, and remaining " - "input dimensions flattened into the inner dimension of the output."); + "A 2D tensor with the contents of the input tensor, " + "with input dimensions up to axis flattened to the outer dimension " + "of the output and remaining input dimensions flattened into the inner " + "dimension of the output.") + .Arg( + "axis", + "(Default to 1) Indicate up to which input dimensions " + "(exclusive) should be flattened to the outer dimension of the output"); OPERATOR_SCHEMA(FlattenToVec) .NumInputs(1) |