path: root/compiler/moco-tf/src/BroadcastHelper.cpp
Diffstat (limited to 'compiler/moco-tf/src/BroadcastHelper.cpp')
-rw-r--r-- compiler/moco-tf/src/BroadcastHelper.cpp | 226
1 file changed, 226 insertions(+), 0 deletions(-)
diff --git a/compiler/moco-tf/src/BroadcastHelper.cpp b/compiler/moco-tf/src/BroadcastHelper.cpp
new file mode 100644
index 000000000..fc058c141
--- /dev/null
+++ b/compiler/moco-tf/src/BroadcastHelper.cpp
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BroadcastHelper.h"
+
+#include <loco/IR/Nodes.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+
+namespace
+{
+
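+/**
+ * @brief Pair of a loco::Node and its known tensor shape
+ *
+ * The broadcasting steps below are written as functors over this pair so
+ * that each step can read the current shape without re-running inference.
+ */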
+class NodeWithTensorShape
+{
+public:
+ NodeWithTensorShape() = default;
+
+public:
+ NodeWithTensorShape(loco::Node *node, const loco::TensorShape &shape) : _node{node}, _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ loco::Node *node(void) const { return _node; }
+ const loco::TensorShape &shape(void) const { return _shape; }
+
+private:
+ loco::Node *_node = nullptr;
+ loco::TensorShape _shape;
+};
+
+NodeWithTensorShape glue(loco::Node *node, const loco::TensorShape &shape)
+{
+ return NodeWithTensorShape(node, shape);
+}
+
+/**
+ * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
+ *
+ * HOW TO USE:
+ *
+ * auto expanded_tensor_shape = expand(tensor_shape).to(N);
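+ *
+ * For example, expanding a rank-2 shape [3, 4] to rank 4 yields
+ * [1, 1, 3, 4]: missing leading axes are filled with 1, following
+ * NumPy broadcasting semantics.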
+ */
+class TensorShapeExpander
+{
+public:
+ TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ loco::TensorShape to(uint32_t output_rank)
+ {
+ auto const &input_shape = _shape;
+ uint32_t const input_rank = input_shape.rank();
+
+ assert(input_rank <= output_rank && "Cannot shrink rank");
+ uint32_t const axis_shift = output_rank - input_rank;
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(output_rank);
+ for (uint32_t axis = 0; axis < output_rank; ++axis)
+ {
+ output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
+ }
+
+ return output_shape;
+ }
+
+private:
+ const loco::TensorShape _shape;
+};
+
+TensorShapeExpander expand(const loco::TensorShape &shape) { return TensorShapeExpander{shape}; }
+
+/**
+ * @brief Create a rank-expanded node (if required)
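+ *
+ * Given a node of shape [3, 4] and a target rank of 4, this functor inserts
+ * a loco::FixedReshape to [1, 1, 3, 4] in front of the node. If the rank
+ * already matches, the input is returned unchanged.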
+ */
+class ExpandRankFunctor final
+{
+public:
+ ExpandRankFunctor(uint32_t rank) : _rank{rank}
+ {
+ // DO NOTHING
+ }
+
+public:
+ NodeWithTensorShape operator()(const NodeWithTensorShape &in) const
+ {
+ auto const input_node = in.node();
+ auto const input_shape = in.shape();
+ auto const input_rank = input_shape.rank();
+
+ uint32_t const expected_rank = _rank;
+
+ assert(input_rank <= expected_rank);
+ if (input_rank == expected_rank)
+ {
+ // Nothing to expand
+ return in;
+ }
+
+ auto g = input_node->graph();
+ assert(g != nullptr);
+
+ auto output_shape = expand(input_shape).to(expected_rank);
+ auto output_node = g->nodes()->create<loco::FixedReshape>();
+
+ output_node->input(input_node);
+ output_node->rank(expected_rank);
+ for (uint32_t axis = 0; axis < expected_rank; ++axis)
+ {
+ output_node->dim(axis) = output_shape.dim(axis);
+ }
+
+ return glue(output_node, output_shape);
+ }
+
+private:
+ uint32_t _rank;
+};
+
+ExpandRankFunctor expand_rank_to(uint32_t rank) { return ExpandRankFunctor{rank}; }
+
+/**
+ * @brief Create a dimension-expanded node (if required)
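+ *
+ * Given a node of shape [1, 1, 3, 4] and a target shape [2, 5, 3, 4], this
+ * functor inserts a loco::TensorBroadcast whose mapping records the broadcast
+ * extents (2 for axis 0, 5 for axis 1). Axes that already match are skipped.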
+ */
+class ExpandDimsFunctor final
+{
+public:
+ ExpandDimsFunctor(const loco::TensorShape &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ NodeWithTensorShape operator()(const NodeWithTensorShape &in) const
+ {
+ auto const input_node = in.node();
+ auto const input_shape = in.shape();
+ const auto &output_shape = _shape;
+
+ assert(input_shape.rank() == output_shape.rank());
+
+ if (input_shape == output_shape)
+ {
+ // Nothing to expand
+ return in;
+ }
+
+ uint32_t const rank = output_shape.rank();
+
+ auto g = input_node->graph();
+ assert(g != nullptr);
+
+ auto output_node = g->nodes()->create<loco::TensorBroadcast>();
+
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ auto input_dim = input_shape.dim(axis);
+ auto output_dim = output_shape.dim(axis);
+
+ assert(input_dim.known() and output_dim.known());
+
+ if (!(input_dim == output_dim))
+ {
+ assert(input_dim == 1);
+ output_node->mapping()->dim(axis) = output_dim;
+ }
+ }
+
+ output_node->input(input_node);
+
+ return glue(output_node, output_shape);
+ }
+
+private:
+ loco::TensorShape _shape;
+};
+
+ExpandDimsFunctor expand_dims_as(const loco::TensorShape &shape)
+{
+ return ExpandDimsFunctor{shape};
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+loco::Node *BroadcastFunctor::build(loco::Node *node, const loco::TensorShape &shape) const
+{
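+  // Broadcasting runs in two steps: first the rank of "node" is raised to
+  // that of the target shape "_shape" (prepending 1-sized axes), then every
+  // remaining 1-sized axis is replicated to the target extent. The trailing
+  // lambda unwraps the resulting loco::Node from the (node, shape) pair.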
+ // clang-format off
+ return glue(node, shape)
+ | expand_rank_to(_shape.rank())
+ | expand_dims_as(_shape)
+ | [] (const NodeWithTensorShape &in) { return in.node(); };
+ // clang-format on
+}
+
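+// Convenience overload: the node's current shape is taken from loco shape
+// inference, so shape inference must have been run on the graph beforehand.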
+loco::Node *BroadcastFunctor::build(loco::Node *node) const
+{
+ return build(node, loco::shape_get(node).as<loco::TensorShape>());
+}
+
+} // namespace tf
+} // namespace moco
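
A minimal usage sketch, assuming BroadcastHelper.h declares a BroadcastFunctor
constructor that takes the target loco::TensorShape (that constructor and the
helper function name below are illustrative, not part of this commit):

    #include "BroadcastHelper.h"

    #include <loco.h>

    loco::Node *broadcast_to_2x5x3x4(loco::Node *node)
    {
      // Target shape [2, 5, 3, 4]; "node" is assumed to carry a broadcastable
      // shape such as [3, 4] or [1, 5, 3, 4], with shape inference already run.
      loco::TensorShape target;
      target.rank(4);
      target.dim(0) = 2;
      target.dim(1) = 5;
      target.dim(2) = 3;
      target.dim(3) = 4;

      moco::tf::BroadcastFunctor functor{target}; // hypothetical ctor; see BroadcastHelper.h
      return functor.build(node);
    }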