summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/FusePReluPass.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/FusePReluPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/FusePReluPass.test.cpp187
1 files changed, 187 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/FusePReluPass.test.cpp b/compiler/luci/pass/src/FusePReluPass.test.cpp
new file mode 100644
index 000000000..209fe3911
--- /dev/null
+++ b/compiler/luci/pass/src/FusePReluPass.test.cpp
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FusePReluPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class PReluGraphlet
+{
+public:
+ PReluGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _abs = g->nodes()->create<luci::CircleAbs>();
+ _sub = g->nodes()->create<luci::CircleSub>();
+ _mul_alpha = g->nodes()->create<luci::CircleMul>();
+ _mul_half = g->nodes()->create<luci::CircleMul>();
+ _relu = g->nodes()->create<luci::CircleRelu>();
+ _add = g->nodes()->create<luci::CircleAdd>();
+ _const_alpha = g->nodes()->create<luci::CircleConst>();
+ _const_half = g->nodes()->create<luci::CircleConst>();
+
+ _sub->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul_alpha->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul_half->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ _abs->name("abs");
+ _sub->name("sub");
+ _mul_alpha->name("mul_alpha");
+ _mul_half->name("mul_half");
+ _relu->name("relu");
+ _add->name("add");
+ _const_alpha->name("const_alpha");
+ _const_half->name("const_half");
+
+ _const_alpha->dtype(loco::DataType::FLOAT32);
+ _const_alpha->size<loco::DataType::FLOAT32>(1);
+ _const_alpha->shape({1});
+ _const_alpha->at<loco::DataType::FLOAT32>(0) = 0.1;
+ _const_alpha->shape_status(luci::ShapeStatus::VALID);
+
+ _const_half->dtype(loco::DataType::FLOAT32);
+ _const_half->size<loco::DataType::FLOAT32>(1);
+ _const_half->shape({1});
+ _const_half->at<loco::DataType::FLOAT32>(0) = 0.5;
+ _const_half->shape_status(luci::ShapeStatus::VALID);
+ }
+
+ void invalid_half() { _const_half->at<loco::DataType::FLOAT32>(0) = 0.1; }
+ void invalid_act() { _add->fusedActivationFunction(luci::FusedActFunc::RELU); }
+
+protected:
+ luci::CircleAbs *_abs = nullptr;
+ luci::CircleSub *_sub = nullptr;
+ luci::CircleMul *_mul_alpha = nullptr;
+ luci::CircleMul *_mul_half = nullptr;
+ luci::CircleRelu *_relu = nullptr;
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_const_alpha = nullptr;
+ luci::CircleConst *_const_half = nullptr;
+};
+
+class FusePReluTestGraph : public TestIOGraph, public PReluGraphlet
+{
+public:
+ FusePReluTestGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ PReluGraphlet::init(g());
+
+ _relu->features(input());
+ _abs->x(input());
+ _sub->x(input());
+ _sub->y(_abs);
+ _mul_alpha->x(_sub);
+ _mul_alpha->y(_const_alpha);
+ _mul_half->x(_mul_alpha);
+ _mul_half->y(_const_half);
+ _add->x(_relu);
+ _add->y(_mul_half);
+
+ output()->from(_add);
+ }
+};
+
+class FusePReluTestNegGraph : public TestIOGraph, public PReluGraphlet
+{
+public:
+ FusePReluTestNegGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ PReluGraphlet::init(g());
+
+ _relu->features(input());
+ _abs->x(input());
+ // NOTE x and y are incorrect
+ _sub->x(_abs);
+ _sub->y(input());
+ _mul_alpha->x(_sub);
+ _mul_alpha->y(_const_alpha);
+ _mul_half->x(_mul_alpha);
+ _mul_half->y(_const_half);
+ _add->x(_relu);
+ _add->y(_mul_half);
+
+ output()->from(_add);
+ }
+};
+
+} // namespace
+
+TEST(FusePReluPassTest, name)
+{
+ luci::FusePReluPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(FusePReluPassTest, fuse)
+{
+ FusePReluTestGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(FusePReluPassTest, fuse_invalid_half_NEG)
+{
+ FusePReluTestNegGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+ g.invalid_half();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FusePReluPassTest, fuse_invalid_act_NEG)
+{
+ FusePReluTestNegGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+ g.invalid_act();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FusePReluPassTest, fuse_NEG)
+{
+ FusePReluTestNegGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}