diff options
author | Junjie Bai <bai@in.tum.de> | 2017-10-13 12:10:34 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-10-13 12:28:22 -0700 |
commit | 4c3b02f31447037b009e80a0251c0ebd0ee95b7c (patch) | |
tree | 5fbf0d5ac9df238d6a9d304c9be86ba5ba454587 /caffe2 | |
parent | c3a9423c7f339bbfd16260f7615a94de51615bd0 (diff) | |
download | pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.tar.gz pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.tar.bz2 pytorch-4c3b02f31447037b009e80a0251c0ebd0ee95b7c.zip |
Enable Flatten operator to take an arbitrary axis argument
Summary:
Input dimensions up to "axis" (exclusive) will be flattened into the outer dimension of the output, and the remaining input dimensions will be flattened into the inner dimension.
Closes https://github.com/caffe2/caffe2/pull/1330
Reviewed By: dzhulgakov
Differential Revision: D6039560
Pulled By: bddppq
fbshipit-source-id: e92c30b49a9288feeefc4a639522406e97e149e1
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/operators/utility_ops.cc | 38 | ||||
-rw-r--r-- | caffe2/operators/utility_ops.h | 12 | ||||
-rw-r--r-- | caffe2/python/operator_test/flatten_op_test.py | 38 | ||||
-rw-r--r-- | caffe2/python/operator_test/shape_inference_test.py | 16 |
4 files changed, 86 insertions, 18 deletions
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index 5763e64a35..9a2e82e09f 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -135,34 +135,44 @@ OPERATOR_SCHEMA(LengthsToShape).NumInputs(1).NumOutputs(1); OPERATOR_SCHEMA(Flatten) .NumInputs(1) .NumOutputs(1) - .TensorInferenceFunction([](const OperatorDef&, + .TensorInferenceFunction([](const OperatorDef& def, const vector<TensorShape>& in) { + ArgumentHelper helper(def); + const int axis = helper.GetSingleArgument<int>("axis", 1); vector<TensorShape> out(1); - int total = 1; + TIndex outer = 1; + TIndex inner = 1; std::size_t index = 0; for (auto d : in[0].dims()) { - // skip the first element - if (index++ == 0) { - continue; + if (index < axis) { + outer *= d; + } else { + inner *= d; } - total *= d; + ++index; } out[0].set_data_type(in[0].data_type()); - out[0].add_dims(in[0].dims(0)); - out[0].add_dims(total); + out[0].add_dims(outer); + out[0].add_dims(inner); return out; }) .SetDoc(R"DOC( -Flattens the input tensor into a 2D matrix, keeping the first dimension -unchanged. +Flattens the input tensor into a 2D matrix. If input tensor has shape +(d_0, d_1, ... d_n) then the output will have shape +(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... 
X dn) )DOC") - .Input(0, "input", "A tensor of rank >= 2.") + .Input(0, "input", "A tensor of rank >= axis.") .Output( 0, "output", - "A tensor of rank 2 with the contents of the input tensor, " - "with first dimension equal first dimension of input, and remaining " - "input dimensions flattened into the inner dimension of the output."); + "A 2D tensor with the contents of the input tensor, " + "with input dimensions up to axis flattened to the outer dimension " + "of the output and remaining input dimensions flattened into the inner " + "dimension of the output.") + .Arg( + "axis", + "(Default to 1) Indicate up to which input dimensions " + "(exclusive) should be flattened to the outer dimension of the output"); OPERATOR_SCHEMA(FlattenToVec) .NumInputs(1) diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index da080a6cfc..0b130d9655 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -222,14 +222,17 @@ template <class Context> class FlattenOp : public Operator<Context> { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - USE_SIMPLE_CTOR_DTOR(FlattenOp); + + FlattenOp(const OperatorDef& operator_def, Workspace* ws) + : Operator<Context>(operator_def, ws), + axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {} bool RunOnDevice() override { auto& input = Input(0); auto* output = Output(0); CAFFE_ENFORCE_GE( - input.dims().size(), 2, "The rank of the tensor must be >= 2."); - output->Resize(input.dim(0), input.size_from_dim(1)); + input.dims().size(), axis_, "The rank of the tensor must be >= axis."); + output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_)); context_.template CopyItems<Context, Context>( input.meta(), input.size(), @@ -237,6 +240,9 @@ class FlattenOp : public Operator<Context> { output->raw_mutable_data(input.meta())); return true; } + + private: + int axis_; }; template <class Context> diff --git a/caffe2/python/operator_test/flatten_op_test.py 
b/caffe2/python/operator_test/flatten_op_test.py new file mode 100644 index 0000000000..19d204e0bd --- /dev/null +++ b/caffe2/python/operator_test/flatten_op_test.py @@ -0,0 +1,38 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from hypothesis import given +import numpy as np + +from caffe2.python import core +import caffe2.python.hypothesis_test_util as hu + + +class TestFlatten(hu.HypothesisTestCase): + @given(X=hu.tensor(min_dim=2, max_dim=4), + **hu.gcs) + def test_flatten(self, X, gc, dc): + for axis in range(X.ndim + 1): + op = core.CreateOperator( + "Flatten", + ["X"], + ["Y"], + axis=axis) + + def flatten_ref(X): + shape = X.shape + outer = np.prod(shape[:axis]).astype(int) + inner = np.prod(shape[axis:]).astype(int) + return np.copy(X).reshape(outer, inner), + + self.assertReferenceChecks(gc, op, [X], flatten_ref) + + # Check over multiple devices + self.assertDeviceChecks(dc, op, [X], [0]) + + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/caffe2/python/operator_test/shape_inference_test.py b/caffe2/python/operator_test/shape_inference_test.py index db7520a179..e903e8ef2e 100644 --- a/caffe2/python/operator_test/shape_inference_test.py +++ b/caffe2/python/operator_test/shape_inference_test.py @@ -347,7 +347,7 @@ class TestShapeInference(test_util.TestCase): self.InferTensorRunAndCompare(model) - + # test Flatten with default axis (=1) model = model_helper.ModelHelper(name="test_model") model.Flatten("X", "Flat") model.Flatten("empty", "EmptyFlat") @@ -356,6 +356,20 @@ class TestShapeInference(test_util.TestCase): self.InferTensorRunAndCompare(model) + # test Flatten with axis + model = model_helper.ModelHelper(name="test_model") + x = np.random.randn(17, 5, 13) + for axis in range(x.ndim + 1): + model.Flatten("x", "Flat", axis=axis) + workspace.FeedBlob("x", x) + self.InferTensorRunAndCompare(model) + + empty = 
np.random.randn(0, 5, 13) + for axis in range(empty.ndim + 1): + model.Flatten("empty", "Flat", axis=axis) + workspace.FeedBlob("empty", empty) + self.InferTensorRunAndCompare(model) + def testShapeInferenceReshape(self): model = model_helper.ModelHelper(name="test_model") model.Reshape("X", ["Reshaped", "Old_Shape"], shape=[8, 0, -1, 2]) |