author     Junjie Bai <bai@in.tum.de>  2017-10-13 12:10:34 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2017-10-13 12:28:22 -0700
commit     4c3b02f31447037b009e80a0251c0ebd0ee95b7c (patch)
tree       5fbf0d5ac9df238d6a9d304c9be86ba5ba454587 /caffe2
parent     c3a9423c7f339bbfd16260f7615a94de51615bd0 (diff)
Enable Flatten operator to take an arbitrary axis argument
Summary: Input dimensions up to "axis" (exclusive) will be flattened into the outer dimension of the output, and the remaining input dimensions will become the inner dimension.

Closes https://github.com/caffe2/caffe2/pull/1330

Reviewed By: dzhulgakov

Differential Revision: D6039560

Pulled By: bddppq

fbshipit-source-id: e92c30b49a9288feeefc4a639522406e97e149e1
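As a rough illustration of that rule, here is a minimal NumPy sketch (an independent reconstruction for this page, not code taken from the commit): dimensions before axis are multiplied into the outer output dimension, and the remaining dimensions into the inner one.

    import numpy as np

    def flatten_ref(x, axis=1):
        # Collapse dims [0, axis) into the outer dimension and
        # dims [axis, ndim) into the inner dimension, mirroring the
        # semantics this commit gives the Flatten operator.
        outer = int(np.prod(x.shape[:axis]))
        inner = int(np.prod(x.shape[axis:]))
        return x.reshape(outer, inner)

    x = np.zeros((17, 5, 13))
    print(flatten_ref(x, axis=0).shape)  # (1, 1105)
    print(flatten_ref(x, axis=1).shape)  # (17, 65)
    print(flatten_ref(x, axis=2).shape)  # (85, 13)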
Diffstat (limited to 'caffe2')
-rw-r--r--  caffe2/operators/utility_ops.cc                       38
-rw-r--r--  caffe2/operators/utility_ops.h                        12
-rw-r--r--  caffe2/python/operator_test/flatten_op_test.py        38
-rw-r--r--  caffe2/python/operator_test/shape_inference_test.py   16
4 files changed, 86 insertions, 18 deletions
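For context, a hedged sketch of exercising the new argument through the caffe2 Python frontend (blob names and shapes here are arbitrary; this mirrors what the new flatten_op_test.py below checks via hypothesis):

    import numpy as np
    from caffe2.python import core, workspace

    # Feed a 3D blob and flatten it at axis=2; expected output shape is (17*5, 13).
    workspace.FeedBlob("X", np.random.randn(17, 5, 13).astype(np.float32))
    op = core.CreateOperator("Flatten", ["X"], ["Y"], axis=2)
    workspace.RunOperatorOnce(op)
    print(workspace.FetchBlob("Y").shape)  # (85, 13)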
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 5763e64a35..9a2e82e09f 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -135,34 +135,44 @@ OPERATOR_SCHEMA(LengthsToShape).NumInputs(1).NumOutputs(1);
OPERATOR_SCHEMA(Flatten)
.NumInputs(1)
.NumOutputs(1)
- .TensorInferenceFunction([](const OperatorDef&,
+ .TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
+ ArgumentHelper helper(def);
+ const int axis = helper.GetSingleArgument<int>("axis", 1);
vector<TensorShape> out(1);
- int total = 1;
+ TIndex outer = 1;
+ TIndex inner = 1;
std::size_t index = 0;
for (auto d : in[0].dims()) {
- // skip the first element
- if (index++ == 0) {
- continue;
+ if (index < axis) {
+ outer *= d;
+ } else {
+ inner *= d;
}
- total *= d;
+ ++index;
}
out[0].set_data_type(in[0].data_type());
- out[0].add_dims(in[0].dims(0));
- out[0].add_dims(total);
+ out[0].add_dims(outer);
+ out[0].add_dims(inner);
return out;
})
.SetDoc(R"DOC(
-Flattens the input tensor into a 2D matrix, keeping the first dimension
-unchanged.
+Flattens the input tensor into a 2D matrix. If input tensor has shape
+(d_0, d_1, ... d_n) then the output will have shape
+(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X d_n)
)DOC")
- .Input(0, "input", "A tensor of rank >= 2.")
+ .Input(0, "input", "A tensor of rank >= axis.")
.Output(
0,
"output",
- "A tensor of rank 2 with the contents of the input tensor, "
- "with first dimension equal first dimension of input, and remaining "
- "input dimensions flattened into the inner dimension of the output.");
+ "A 2D tensor with the contents of the input tensor, "
+ "with input dimensions up to axis flattened to the outer dimension "
+ "of the output and remaining input dimensions flattened into the inner "
+ "dimension of the output.")
+ .Arg(
+ "axis",
+ "(Default to 1) Indicate up to which input dimensions "
+ "(exclusive) should be flattened to the outer dimension of the output");
OPERATOR_SCHEMA(FlattenToVec)
.NumInputs(1)
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index da080a6cfc..0b130d9655 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -222,14 +222,17 @@ template <class Context>
class FlattenOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- USE_SIMPLE_CTOR_DTOR(FlattenOp);
+
+ FlattenOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<Context>(operator_def, ws),
+ axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
bool RunOnDevice() override {
auto& input = Input(0);
auto* output = Output(0);
CAFFE_ENFORCE_GE(
- input.dims().size(), 2, "The rank of the tensor must be >= 2.");
- output->Resize(input.dim(0), input.size_from_dim(1));
+ input.dims().size(), axis_, "The rank of the tensor must be >= axis.");
+ output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
context_.template CopyItems<Context, Context>(
input.meta(),
input.size(),
@@ -237,6 +240,9 @@ class FlattenOp : public Operator<Context> {
output->raw_mutable_data(input.meta()));
return true;
}
+
+ private:
+ int axis_;
};
template <class Context>
diff --git a/caffe2/python/operator_test/flatten_op_test.py b/caffe2/python/operator_test/flatten_op_test.py
new file mode 100644
index 0000000000..19d204e0bd
--- /dev/null
+++ b/caffe2/python/operator_test/flatten_op_test.py
@@ -0,0 +1,38 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from hypothesis import given
+import numpy as np
+
+from caffe2.python import core
+import caffe2.python.hypothesis_test_util as hu
+
+
+class TestFlatten(hu.HypothesisTestCase):
+ @given(X=hu.tensor(min_dim=2, max_dim=4),
+ **hu.gcs)
+ def test_flatten(self, X, gc, dc):
+ for axis in range(X.ndim + 1):
+ op = core.CreateOperator(
+ "Flatten",
+ ["X"],
+ ["Y"],
+ axis=axis)
+
+ def flatten_ref(X):
+ shape = X.shape
+ outer = np.prod(shape[:axis]).astype(int)
+ inner = np.prod(shape[axis:]).astype(int)
+ return np.copy(X).reshape(outer, inner),
+
+ self.assertReferenceChecks(gc, op, [X], flatten_ref)
+
+ # Check over multiple devices
+ self.assertDeviceChecks(dc, op, [X], [0])
+
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff --git a/caffe2/python/operator_test/shape_inference_test.py b/caffe2/python/operator_test/shape_inference_test.py
index db7520a179..e903e8ef2e 100644
--- a/caffe2/python/operator_test/shape_inference_test.py
+++ b/caffe2/python/operator_test/shape_inference_test.py
@@ -347,7 +347,7 @@ class TestShapeInference(test_util.TestCase):
self.InferTensorRunAndCompare(model)
-
+ # test Flatten with default axis (=1)
model = model_helper.ModelHelper(name="test_model")
model.Flatten("X", "Flat")
model.Flatten("empty", "EmptyFlat")
@@ -356,6 +356,20 @@ class TestShapeInference(test_util.TestCase):
self.InferTensorRunAndCompare(model)
+ # test Flatten with axis
+ model = model_helper.ModelHelper(name="test_model")
+ x = np.random.randn(17, 5, 13)
+ for axis in range(x.ndim + 1):
+ model.Flatten("x", "Flat", axis=axis)
+ workspace.FeedBlob("x", x)
+ self.InferTensorRunAndCompare(model)
+
+ empty = np.random.randn(0, 5, 13)
+ for axis in range(empty.ndim + 1):
+ model.Flatten("empty", "Flat", axis=axis)
+ workspace.FeedBlob("empty", empty)
+ self.InferTensorRunAndCompare(model)
+
def testShapeInferenceReshape(self):
model = model_helper.ModelHelper(name="test_model")
model.Reshape("X", ["Reshaped", "Old_Shape"], shape=[8, 0, -1, 2])