diff options
Diffstat (limited to 'compiler/luci/pass/src/FuseActivationFunctionPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/FuseActivationFunctionPass.cpp | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp index 844541d2d..66e341518 100644 --- a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp +++ b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp @@ -17,7 +17,9 @@ #include "luci/Pass/FuseActivationFunctionPass.h" #include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeMixins.h> #include <luci/IR/CircleOpcode.h> +#include <luci/Profile/CircleNodeOrigin.h> namespace luci { @@ -32,10 +34,15 @@ bool fuse_activation_function(luci::CircleNode *node) return false; auto node_with_fused_act = - dynamic_cast<luci::LuciNodeMixin<luci::LuciNodeTrait::FusedActFunc> *>(pred_node); + dynamic_cast<luci::CircleNodeMixin<luci::CircleNodeTrait::FusedActFunc> *>(pred_node); if (node_with_fused_act == nullptr) return false; + // TODO remove this work-around + // This will skip fuse for concat as luci-interpreter doesn't support this yet + if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr) + return false; + auto fused_act = node_with_fused_act->fusedActivationFunction(); luci::FusedActFunc target_func = luci::FusedActFunc::UNDEFINED; @@ -76,6 +83,7 @@ bool fuse_activation_function(luci::CircleNode *node) return false; node_with_fused_act->fusedActivationFunction(target_func); + luci::add_origin(pred_node, luci::get_origin(node)); loco::replace(node).with(pred_node); node->drop(); |