summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/FuseInstanceNormPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/FuseInstanceNormPass.cpp')
-rw-r--r--compiler/luci/pass/src/FuseInstanceNormPass.cpp401
1 files changed, 401 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.cpp
new file mode 100644
index 000000000..180b5bbef
--- /dev/null
+++ b/compiler/luci/pass/src/FuseInstanceNormPass.cpp
@@ -0,0 +1,401 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseInstanceNormPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+#include <set>
+
+// Helper to find commutative node's arguments
+namespace
+{
+
+/**
+ * INTRODUCTION
+ * Binary operation f(x,y) is 'commutative' when
+ * f(x,y) == f(y,x) holds for all x, y.
+ * For example, ADD, MUL and SQUARED_DIFFERENCE are commutative.
+ * These helpers make it easy to find the arguments of a commutative node
+ * regardless of the order in which they are connected.
+ *
+ * HOW TO USE
+ * COMM_NODE *node;
+ * ARG_TYPE_1 *arg1;
+ * ARG_TYPE_2 *arg2;
+ *
+ * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node);
+ *
+ * Result
+ * If the argument types of 'node' are actually {ARG_TYPE_1, ARG_TYPE_2}
+ * (as a set), 'arg1' and 'arg2' are set to the actual arguments of 'node'
+ * with matching type, and return value 'ok' is true.
+ * Otherwise, 'arg1' and 'arg2' are not changed and 'ok' is false.
+ */
+
+/**
+ * @brief Binds two output pointers to the x()/y() arguments of a commutative node
+ *
+ * The filler only stores the addresses of the two output pointers; they are
+ * written by 'with_commutative_args_of' and only when the match succeeds.
+ */
+template <typename ARG_TYPE_1, typename ARG_TYPE_2> class NodeFiller final
+{
+public:
+  NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1{arg_1}, _arg_2{arg_2}
+  {
+    // Nothing else to set up
+  }
+
+  /**
+   * @brief Match x() and y() of 'node' against {ARG_TYPE_1, ARG_TYPE_2} in either order
+   *
+   * @return true  One argument is 'ARG_TYPE_1' and the other is 'ARG_TYPE_2';
+   *               the bound pointers now refer to those arguments
+   * @return false No such assignment exists; the bound pointers are left untouched
+   *
+   * @require COMM_NODE has member x() and y()
+   */
+  template <typename COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
+
+private:
+  ARG_TYPE_1 **_arg_1;
+  ARG_TYPE_2 **_arg_2;
+};
+
+/// @brief Create a NodeFiller with its argument types deduced from the output pointers
+template <typename ARG_TYPE_1, typename ARG_TYPE_2>
+inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
+{
+  return {arg_1, arg_2};
+}
+
+template <typename ARG_TYPE_1, typename ARG_TYPE_2>
+template <typename COMM_NODE>
+bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
+{
+  // Straight order: x() is ARG_TYPE_1 and y() is ARG_TYPE_2
+  auto first_from_x = dynamic_cast<ARG_TYPE_1 *>(node->x());
+  auto second_from_y = dynamic_cast<ARG_TYPE_2 *>(node->y());
+  if (first_from_x != nullptr && second_from_y != nullptr)
+  {
+    *_arg_1 = first_from_x;
+    *_arg_2 = second_from_y;
+    return true;
+  }
+
+  // Swapped order: x() is ARG_TYPE_2 and y() is ARG_TYPE_1
+  auto second_from_x = dynamic_cast<ARG_TYPE_2 *>(node->x());
+  auto first_from_y = dynamic_cast<ARG_TYPE_1 *>(node->y());
+  if (second_from_x == nullptr || first_from_y == nullptr)
+    return false;
+
+  *_arg_1 = first_from_y;
+  *_arg_2 = second_from_x;
+  return true;
+}
+
+} // namespace
+
+// Helper to check detail
+namespace
+{
+
+/**
+ * @brief Check whether 'node' has shape '1 x .. x 1 x depth'
+ *
+ * @param node  Constant whose rank and dimensions are inspected
+ * @param depth Expected value of the last dimension
+ *
+ * @return true  When every dimension but the last is 1 and the last one equals 'depth'
+ * @return false Otherwise, including when 'node' has rank 0 (there is no last dimension)
+ */
+bool is_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
+{
+  auto rank = node->rank();
+  // 'rank - 1' below would wrap around for an unsigned rank of 0, so reject it up front
+  if (rank == 0)
+    return false;
+
+  const uint32_t last_axis = rank - 1;
+  for (uint32_t axis = 0; axis < last_axis; ++axis)
+  {
+    if (node->dim(axis).value() != 1)
+      return false;
+  }
+  return node->dim(last_axis).value() == depth;
+}
+
+/**
+ * @brief Check whether 'mean' reduces a rank-4 input over axes {1, 2} with keep_dims
+ *
+ * @return true  When all of CHECK 1) to 3) below pass
+ * @return false When the input shape is unknown or is not a rank-4 tensor shape, when the
+ *               reduction indices are not a rank-1 S32 CircleConst whose set of values
+ *               is {1, 2}, or when keep_dims is false
+ */
+bool is_instance_mean(luci::CircleMean *mean)
+{
+  //
+  // CHECK 1) input is rank 4
+  //
+  auto input = mean->input();
+  if (not loco::shape_known(input))
+    return false;
+  auto input_shape = loco::shape_get(input);
+  // Make sure the shape is in tensor domain before viewing it as loco::TensorShape
+  if (input_shape.domain() != loco::Domain::Tensor)
+    return false;
+  if (input_shape.as<loco::TensorShape>().rank() != 4)
+    return false;
+
+  //
+  // CHECK 2) 'reduction indices' is CircleConst of value [1,2], that is HW of NHWC
+  //
+  // TODO Support equivalent case, like [-3,-2]
+  // TODO Support non-Const case?
+  // TODO What if input is NCHW format in Circle?
+  auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
+  if (not red_indices)
+    return false;
+  if (red_indices->rank() != 1)
+    return false;
+  // TODO Currently only support S32, support other types
+  // Decline (rather than assert) so that builds with NDEBUG never read other dtypes as S32
+  if (red_indices->dtype() != loco::DataType::S32)
+    return false;
+  std::set<int32_t> red_indices_set;
+  for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
+    red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
+  if (red_indices_set.size() != 2)
+    return false;
+  if (red_indices_set.find(1) == red_indices_set.end())
+    return false;
+  if (red_indices_set.find(2) == red_indices_set.end())
+    return false;
+
+  //
+  // CHECK 3) keep_dims == true (?)
+  //
+  // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
+  // TODO Check this fact, and if true, return true regardless of keep_dims
+  return mean->keep_dims();
+}
+
+} // namespace
+
+// Helper to fuse Instance Norm
+namespace
+{
+
+/**
+ * SUBGRAPH PATTERN
+ *
+ * - The diagram below shows the Instance Norm pattern to fuse.
+ * - Execution dependency order is from top to bottom.
+ * - Node names match the variable names of the InstanceNormPattern class.
+ * - Usually, the first word of a node name (variable name) is the node type. E.g.
+ *   variable 'mean_as_variance' is a pointer to CircleMean.
+ * - (Item in parentheses) actually exists in the graph, but has no name and
+ *   is not a variable of the InstanceNormPattern class.
+ *
+ * TODO support other semantically same patterns for instance norm
+ *
+ *                 [In]
+ *                   |
+ *                   V
+ *     +----------- ifm -----+   (reduction indices)
+ *     |             |       |       |
+ *     |             |       V       V
+ *     |             |      mean_of_ifm ----------------+
+ *     |             V       |                          |
+ *     |           sqdiff <--+   (reduction indices)    |
+ *     |             |            |                     |
+ *     |             V            |                     |
+ *     |     mean_as_variance <---+  const_as_epsilon   |
+ *     |             |                      |           |
+ *     |             V                      |           |
+ *     |      add_as_variance <-------------+           |
+ *     |             |                                  |
+ *     |             V                                  |
+ *     |           rsqrt   const_as_gamma               |
+ *     |             |           |                      |
+ *     |             V           |                      |
+ *     |         mul_gamma <-----+                      |
+ *     |          |     |                               |
+ *     V          V     V                               |
+ * mul_as_scaled_ifm   mul_as_scaled_mean <-------------+
+ *         |                   |
+ *         |   const_as_beta   |
+ *         |         |         V
+ *         |         +------> sub
+ *         V                   |
+ *  add_as_terminal <----------+
+ *         |
+ *         V
+ *       [Out]
+ */
+/// @brief Holds the nodes of one instance norm candidate that ends at 'add_as_terminal'
+///
+/// All node pointers start as nullptr except 'add_as_terminal'; the others are
+/// assigned by the non-const 'matched()' while it walks the graph.
+class InstanceNormPattern final
+{
+public:
+  /// @param candidate CircleAdd to use as 'add_as_terminal'; must not be null
+  // 'explicit' prevents a CircleAdd pointer from converting to a pattern implicitly
+  explicit InstanceNormPattern(luci::CircleAdd *candidate) : add_as_terminal(candidate)
+  {
+    assert(candidate);
+  }
+
+public:
+  /// @brief Run the match (or reuse a previous success) and return whether it succeeded
+  bool matched();
+  /// @brief Return the stored match result without running the match
+  bool matched() const { return _matched; }
+
+public:
+  // Context
+  loco::Node *ifm = nullptr;
+  luci::CircleMean *mean_of_ifm = nullptr;
+  luci::CircleSquaredDifference *sqdiff = nullptr;
+  luci::CircleMean *mean_as_variance = nullptr;
+  luci::CircleConst *const_as_epsilon = nullptr;
+  luci::CircleAdd *add_as_variance = nullptr;
+  luci::CircleRsqrt *rsqrt = nullptr;
+  luci::CircleConst *const_as_gamma = nullptr;
+  luci::CircleMul *mul_gamma = nullptr;
+  luci::CircleMul *mul_as_scaled_ifm = nullptr;
+  luci::CircleMul *mul_as_scaled_mean = nullptr;
+  luci::CircleConst *const_as_beta = nullptr;
+  luci::CircleSub *sub = nullptr;
+  luci::CircleAdd *add_as_terminal = nullptr;
+
+private:
+  // Set to true only after every check in the non-const matched() has passed
+  bool _matched = false;
+};
+
+/**
+ * @brief Try to match the instance norm subgraph that ends at 'add_as_terminal'
+ *
+ * Walks from 'add_as_terminal' towards the inputs and assigns the node members
+ * of this class while doing so.
+ *
+ * @return true  When every check passes; the result is remembered in '_matched',
+ *               so later calls return immediately
+ * @return false At the first failing check; members assigned before that check
+ *               keep whatever value they were given
+ */
+bool InstanceNormPattern::matched()
+{
+ if (_matched)
+ return true;
+
+// Local helper: leave matched() with 'false' as soon as 'condition' does not hold
+#define CHECK_OR_FALSE(condition) \
+ if (not(condition)) \
+ return false;
+
+ // Check order is DFS
+
+ // add_as_terminal takes mul_as_scaled_ifm and sub, in either argument order
+ CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ // mul_as_scaled_ifm takes ifm (any loco::Node) and mul_gamma, in either argument order
+ CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
+
+ // ifm needs a known rank-4 tensor shape; dimension 3 is used as the channel depth
+ CHECK_OR_FALSE(loco::shape_known(ifm));
+ auto ifm_shape = loco::shape_get(ifm);
+ CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor);
+ auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>();
+ CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4);
+ uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value();
+
+ // mul_gamma takes rsqrt and const_as_gamma; gamma must be shaped '1 x .. x 1 x depth'
+ CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
+
+ // The input of rsqrt must be a CircleAdd
+ add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
+ CHECK_OR_FALSE(add_as_variance);
+
+ // add_as_variance takes mean_as_variance and const_as_epsilon, in either argument order
+ CHECK_OR_FALSE(
+ fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+
+ // epsilon must be a FLOAT32 constant holding exactly one element
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ // mean_as_variance must pass is_instance_mean and take a CircleSquaredDifference as input
+ CHECK_OR_FALSE(is_instance_mean(mean_as_variance));
+ sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
+ CHECK_OR_FALSE(sqdiff);
+
+ // sqdiff takes the very same ifm and mean_of_ifm; mean_of_ifm must itself
+ // pass is_instance_mean and read from that ifm
+ loco::Node *ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(is_instance_mean(mean_of_ifm));
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ // The arguments of sub are taken positionally (no order swap is tried here):
+ // x() must be const_as_beta, shaped '1 x .. x 1 x depth'
+ const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
+ CHECK_OR_FALSE(const_as_beta);
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
+
+ // ... and y() must be a CircleMul
+ mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
+ CHECK_OR_FALSE(mul_as_scaled_mean);
+
+ // mul_as_scaled_mean must take the same mul_gamma and mean_of_ifm nodes found above
+ luci::CircleMul *mul_gamma_should_be = nullptr;
+ luci::CircleMean *mean_of_ifm_should_be = nullptr;
+ CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
+ CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
+#undef CHECK_OR_FALSE
+ // Every check passed; remember the result for later calls
+ _matched = true;
+ return true;
+}
+
+/**
+ * Replace a matched instance norm pattern with a single CircleInstanceNorm:
+ *
+ *   [In] ------------------------------> CircleInstanceNorm ---> [Out]
+ *                                            ^       ^
+ *   const_as_gamma ---> CircleReshape -------+       |
+ *                                                    |
+ *   const_as_beta ----> CircleReshape ---------------+
+ *
+ * Note
+ * - 'const_as_gamma' and 'const_as_beta' are the nodes of the original graph
+ * - Both are reshaped to a 1-D tensor whose length is dimension 3 of 'ifm'
+ *   (per the original note, CircleInstanceNorm only accepts 1-D gamma/beta)
+ * - The first element of 'const_as_epsilon' becomes the epsilon attribute
+ * - The fused activation function of 'add_as_terminal' is carried over
+ * - 'CircleConst --- CircleReshape' is expected to be folded later by constant
+ *   folding for Reshape
+ *
+ * @param p Pattern that has already been matched (asserted)
+ */
+void fuse_instance_norm(const InstanceNormPattern &p)
+{
+  assert(p.matched());
+
+  auto *owner = p.add_as_terminal->graph();
+
+  // Reshape nodes that flatten gamma and beta
+  auto *gamma_1d = owner->nodes()->create<luci::CircleReshape>();
+  auto *beta_1d = owner->nodes()->create<luci::CircleReshape>();
+
+  const uint32_t channel_depth = loco::shape_get(p.ifm).as<loco::TensorShape>().dim(3).value();
+  int32_t flat_shape[1] = {static_cast<int32_t>(channel_depth)};
+
+  gamma_1d->tensor(p.const_as_gamma);
+  beta_1d->tensor(p.const_as_beta);
+
+  luci::set_new_shape(gamma_1d, flat_shape, 1);
+  luci::set_new_shape(beta_1d, flat_shape, 1);
+
+  // The fused node that takes over the role of 'add_as_terminal'
+  auto *fused = owner->nodes()->create<luci::CircleInstanceNorm>();
+  fused->input(p.ifm);
+  fused->gamma(gamma_1d);
+  fused->beta(beta_1d);
+  fused->epsilon(p.const_as_epsilon->at<loco::DataType::FLOAT32>(0));
+  fused->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
+
+  replace(p.add_as_terminal).with(fused);
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * @brief Fuse every instance norm pattern that ends at an active CircleAdd of 'g'
+ *
+ * @return true when at least one pattern has been fused, false otherwise
+ */
+bool FuseInstanceNormPass::run(loco::Graph *g)
+{
+  bool fused_any = false;
+
+  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  {
+    // Only a CircleAdd can be the terminal node of the pattern
+    if (auto terminal = dynamic_cast<luci::CircleAdd *>(node))
+    {
+      InstanceNormPattern pattern(terminal);
+      if (pattern.matched())
+      {
+        fuse_instance_norm(pattern);
+        fused_any = true;
+      }
+    }
+  }
+
+  return fused_any;
+}
+
+} // namespace luci