summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/FuseAddWithTConvPass.cpp')
-rw-r--r--compiler/luci/pass/src/FuseAddWithTConvPass.cpp20
1 files changed, 17 insertions, 3 deletions
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
index 2bca57014..852bc8b63 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -37,10 +37,10 @@ namespace
* \ |
* [CircleTransposeConv] [CircleAdd]
* |
- * ([CircleRelu6])
+ * ([CircleRelu/Relu6])
* |
*
- * Note: CircleRelu6 is inserted if Add activation is ReLU6
+ * Note: CircleRelu/Relu6 is inserted if Add activation is ReLU6
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
{
@@ -65,7 +65,8 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
if (add->dtype() != loco::DataType::FLOAT32)
return false;
if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
- add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU)
return false;
// get addition
@@ -102,6 +103,19 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
// remove add node
replace(add).with(relu);
}
+ else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
+ {
+ auto name = addition->name();
+ assert(name.length() > 0);
+ // separate relu op from add op
+ auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
+ relu->features(tconv);
+ relu->name(name + "/Relu");
+ luci::add_origin(relu, luci::get_origin(add));
+
+ // remove add node
+ replace(add).with(relu);
+ }
else
{
replace(add).with(tconv);