summaryrefslogtreecommitdiff
path: root/caffe2/operators/utility_ops.cc
diff options
context:
space:
mode:
authorJunjie Bai <bai@in.tum.de>2017-10-13 12:10:34 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-10-13 12:28:22 -0700
commit4c3b02f31447037b009e80a0251c0ebd0ee95b7c (patch)
tree5fbf0d5ac9df238d6a9d304c9be86ba5ba454587 /caffe2/operators/utility_ops.cc
parentc3a9423c7f339bbfd16260f7615a94de51615bd0 (diff)
downloadpytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.tar.gz
pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.tar.bz2
pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.zip
Enable Flatten operator to take an arbitrary axis arguemnt
Summary: input dimensions up to "axis" will be flattened to the outer dim of output and the remaining input dims will be the inner dim Closes https://github.com/caffe2/caffe2/pull/1330 Reviewed By: dzhulgakov Differential Revision: D6039560 Pulled By: bddppq fbshipit-source-id: e92c30b49a9288feeefc4a639522406e97e149e1
Diffstat (limited to 'caffe2/operators/utility_ops.cc')
-rw-r--r--caffe2/operators/utility_ops.cc38
1 files changed, 24 insertions, 14 deletions
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 5763e64a35..9a2e82e09f 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -135,34 +135,44 @@ OPERATOR_SCHEMA(LengthsToShape).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Flatten)
.NumInputs(1)
.NumOutputs(1)
- .TensorInferenceFunction([](const OperatorDef&,
+ .TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
+ ArgumentHelper helper(def);
+ const int axis = helper.GetSingleArgument<int>("axis", 1);
vector<TensorShape> out(1);
- int total = 1;
+ TIndex outer = 1;
+ TIndex inner = 1;
std::size_t index = 0;
for (auto d : in[0].dims()) {
- // skip the first element
- if (index++ == 0) {
- continue;
+ if (index < axis) {
+ outer *= d;
+ } else {
+ inner *= d;
}
- total *= d;
+ ++index;
}
out[0].set_data_type(in[0].data_type());
- out[0].add_dims(in[0].dims(0));
- out[0].add_dims(total);
+ out[0].add_dims(outer);
+ out[0].add_dims(inner);
return out;
})
.SetDoc(R"DOC(
-Flattens the input tensor into a 2D matrix, keeping the first dimension
-unchanged.
+Flattens the input tensor into a 2D matrix. If input tensor has shape
+(d_0, d_1, ... d_n) then the output will have shape
+(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn)
)DOC")
- .Input(0, "input", "A tensor of rank >= 2.")
+ .Input(0, "input", "A tensor of rank >= axis.")
.Output(
0,
"output",
- "A tensor of rank 2 with the contents of the input tensor, "
- "with first dimension equal first dimension of input, and remaining "
- "input dimensions flattened into the inner dimension of the output.");
+ "A 2D tensor with the contents of the input tensor, "
+ "with input dimensions up to axis flattened to the outer dimension "
+ "of the output and remaining input dimensions flattened into the inner "
+ "dimension of the output.")
+ .Arg(
+ "axis",
+ "(Default to 1) Indicate up to which input dimensions "
+ "(exclusive) should be flattened to the outer dimension of the output");
OPERATOR_SCHEMA(FlattenToVec)
.NumInputs(1)