diff options
Diffstat (limited to 'compiler/luci/pass/src/FuseAddWithTConvPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/FuseAddWithTConvPass.cpp | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp index 2bca57014..852bc8b63 100644 --- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp +++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp @@ -37,10 +37,10 @@ namespace * \ | * [CircleTransposeConv] [CircleAdd] * | - * ([CircleRelu6]) + * ([CircleRelu/Relu6]) * | * - * Note: CircleRelu6 is inserted if Add activation is ReLU6 + * Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6 */ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) { @@ -65,7 +65,8 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) if (add->dtype() != loco::DataType::FLOAT32) return false; if (add->fusedActivationFunction() != luci::FusedActFunc::NONE && - add->fusedActivationFunction() != luci::FusedActFunc::RELU6) + add->fusedActivationFunction() != luci::FusedActFunc::RELU6 && + add->fusedActivationFunction() != luci::FusedActFunc::RELU) return false; // get addition @@ -102,6 +103,19 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv) // remove add node replace(add).with(relu); } + else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU) + { + auto name = addition->name(); + assert(name.length() > 0); + // separate relu op from add op + auto relu = add->graph()->nodes()->create<luci::CircleRelu>(); + relu->features(tconv); + relu->name(name + "/Relu"); + luci::add_origin(relu, luci::get_origin(add)); + + // remove add node + replace(add).with(relu); + } else { replace(add).with(tconv); |