summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/FuseActivationFunctionPass.cpp')
-rw-r--r--compiler/luci/pass/src/FuseActivationFunctionPass.cpp10
1 files changed, 9 insertions, 1 deletions
diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
index 844541d2d..66e341518 100644
--- a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
+++ b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
@@ -17,7 +17,9 @@
#include "luci/Pass/FuseActivationFunctionPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeMixins.h>
#include <luci/IR/CircleOpcode.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace luci
{
@@ -32,10 +34,15 @@ bool fuse_activation_function(luci::CircleNode *node)
return false;
auto node_with_fused_act =
- dynamic_cast<luci::LuciNodeMixin<luci::LuciNodeTrait::FusedActFunc> *>(pred_node);
+ dynamic_cast<luci::CircleNodeMixin<luci::CircleNodeTrait::FusedActFunc> *>(pred_node);
if (node_with_fused_act == nullptr)
return false;
+ // TODO remove this work-around
+ // This will skip fuse for concat as luci-interpreter doesn't support this yet
+ if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr)
+ return false;
+
auto fused_act = node_with_fused_act->fusedActivationFunction();
luci::FusedActFunc target_func = luci::FusedActFunc::UNDEFINED;
@@ -76,6 +83,7 @@ bool fuse_activation_function(luci::CircleNode *node)
return false;
node_with_fused_act->fusedActivationFunction(target_func);
+ luci::add_origin(pred_node, luci::get_origin(node));
loco::replace(node).with(pred_node);
node->drop();