summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp')
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp223
1 files changed, 223 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
new file mode 100644
index 000000000..7096c2591
--- /dev/null
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
@@ -0,0 +1,223 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
+{
+ assert(gamma->rank() == 1);
+ auto channel_size = gamma->dim(0).value();
+
+ // Channel-wise MUL is the same as DEPTHWISE_CONV2D with filter shape (1,1,1,channel_size)
+ auto weights = gamma->graph()->nodes()->create<luci::CircleConst>();
+ weights->dtype(loco::DataType::FLOAT32);
+ weights->rank(4);
+ weights->dim(0).set(1);
+ weights->dim(1).set(1);
+ weights->dim(2).set(1);
+ weights->dim(3).set(channel_size);
+ weights->shape_status(luci::ShapeStatus::VALID);
+ weights->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ weights->at<loco::DataType::FLOAT32>(i) = gamma->at<loco::DataType::FLOAT32>(i);
+ }
+
+ return weights;
+}
+
+luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
+{
+ assert(beta->rank() == 1);
+ auto channel_size = beta->dim(0).value();
+
+ // Channel-wise ADD is the same as bias (shape = (channel_size)) of DEPTHWISE_CONV2D
+ auto bias = beta->graph()->nodes()->create<luci::CircleConst>();
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->rank(1);
+ bias->dim(0).set(channel_size);
+ bias->size<loco::DataType::FLOAT32>(channel_size);
+ bias->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ bias->at<loco::DataType::FLOAT32>(i) = beta->at<loco::DataType::FLOAT32>(i);
+ }
+
+ return bias;
+}
+
+bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta)
+{
+ auto x = loco::must_cast<luci::CircleNode *>(add->x());
+ auto y = loco::must_cast<luci::CircleNode *>(add->y());
+
+ luci::CircleMul *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(y);
+ constant = loco::must_cast<luci::CircleConst *>(x);
+ }
+ else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(x);
+ constant = loco::must_cast<luci::CircleConst *>(y);
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ // Assumption: Layout is channel-last
+ if (!(channel_dim == add->dim(add->rank() - 1)))
+ return false;
+
+ mul = pred;
+ beta = constant;
+ return true;
+}
+
+// Check if mul is batchnorm mul
+bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
+ luci::CircleConst *&gamma)
+{
+ auto x = dynamic_cast<luci::CircleConst *>(mul->x());
+ auto y = dynamic_cast<luci::CircleConst *>(mul->y());
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x != nullptr && y == nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->y());
+ constant = x;
+ }
+ else if (x == nullptr && y != nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->x());
+ constant = y;
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ if (!(channel_dim == mul->dim(mul->rank() - 1)))
+ return false;
+
+ pred_node = pred;
+ gamma = constant;
+ return true;
+}
+
+/**
+ * Replace channel-wise Mul/Add with DepthwiseConv2D
+ *
+ * BEFORE
+ *
+ * [Node] [gamma]
+ * | /
+ * [Mul] [beta]
+ * | /
+ * [Add]
+ *
+ * AFTER
+ *
+ * [Node] [weights] [bias]
+ * \ / /
+ * [DepthwiseConv2D]
+ */
+bool replace_mul_add_with_dwconv(luci::CircleAdd *add)
+{
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ if (!is_batchnorm_add(add, mul, beta))
+ return false;
+
+ if (loco::succs(mul).size() != 1)
+ return false;
+
+ if (!is_batchnorm_mul(mul, pred_node, gamma))
+ return false;
+
+ if (pred_node->rank() != 4)
+ return false;
+
+ if (pred_node->dtype() != loco::DataType::FLOAT32 || beta->dtype() != loco::DataType::FLOAT32 ||
+ gamma->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto weights = create_weights_from_gamma(gamma);
+ auto bias = create_bias_from_beta(beta);
+
+ auto dwconv = add->graph()->nodes()->create<luci::CircleDepthwiseConv2D>();
+ dwconv->input(pred_node);
+ dwconv->filter(weights);
+ dwconv->bias(bias);
+ dwconv->padding(luci::Padding::SAME);
+ dwconv->stride()->w(1);
+ dwconv->stride()->h(1);
+ dwconv->depthMultiplier(1);
+ dwconv->dilation()->w(1);
+ dwconv->dilation()->h(1);
+ dwconv->fusedActivationFunction(add->fusedActivationFunction());
+
+ loco::replace(add).with(dwconv);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool ReplaceMulAddWithDepthwiseConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto add = dynamic_cast<luci::CircleAdd *>(node);
+ if (not add)
+ continue;
+
+ if (replace_mul_add_with_dwconv(add))
+ {
+ changed = true;
+ break;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci