summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp')
-rw-r--r--compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp211
1 files changed, 211 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
new file mode 100644
index 000000000..373513270
--- /dev/null
+++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+#include <vector>
+
+namespace
+{
+
+using namespace luci::test;
+
+class ReshapeNegGraphlet
+{
+public:
+ ReshapeNegGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ std::vector<uint32_t> shape_out_v = shape_out;
+
+ _reshape_shape = g->nodes()->create<luci::CircleConst>();
+ _reshape = g->nodes()->create<luci::CircleReshape>();
+ _neg = g->nodes()->create<luci::CircleNeg>();
+
+ _reshape_shape->dtype(loco::DataType::S32);
+ _reshape_shape->rank(1);
+ _reshape_shape->dim(0).set(shape_out_v.size());
+ _reshape_shape->shape_status(luci::ShapeStatus::VALID);
+ // values
+ const auto size = shape_out_v.size();
+ _reshape_shape->size<loco::DataType::S32>(size);
+ for (uint32_t i = 0; i < size; i++)
+ _reshape_shape->at<loco::DataType::S32>(i) = shape_out_v[i];
+
+ _reshape_shape->name("reshape_shape");
+ _reshape->name("reshape");
+ _neg->name("neg");
+ }
+
+protected:
+ luci::CircleReshape *_reshape = nullptr;
+ luci::CircleNeg *_neg = nullptr;
+ luci::CircleConst *_reshape_shape = nullptr;
+};
+
+// TODO Reduce duplicate code with ReshapeNegGraphlet
+class ReshapeLogisticGraphlet
+{
+public:
+ ReshapeLogisticGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ std::vector<uint32_t> shape_out_v = shape_out;
+
+ _reshape_shape = g->nodes()->create<luci::CircleConst>();
+ _reshape = g->nodes()->create<luci::CircleReshape>();
+ _logistic = g->nodes()->create<luci::CircleLogistic>();
+
+ _reshape_shape->dtype(loco::DataType::S32);
+ _reshape_shape->rank(1);
+ _reshape_shape->dim(0).set(shape_out_v.size());
+ _reshape_shape->shape_status(luci::ShapeStatus::VALID);
+ // values
+ const auto size = shape_out_v.size();
+ _reshape_shape->size<loco::DataType::S32>(size);
+ for (uint32_t i = 0; i < size; i++)
+ _reshape_shape->at<loco::DataType::S32>(i) = shape_out_v[i];
+
+ _reshape_shape->name("reshape_shape");
+ _reshape->name("reshape");
+ _logistic->name("logistic");
+ }
+
+protected:
+ luci::CircleReshape *_reshape = nullptr;
+ luci::CircleLogistic *_logistic = nullptr;
+ luci::CircleConst *_reshape_shape = nullptr;
+};
+
+class ForwardReshapeToNegGraph : public TestIOGraph, public ReshapeNegGraphlet
+{
+public:
+ ForwardReshapeToNegGraph() = default;
+
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ ReshapeNegGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _reshape->tensor(input());
+ _reshape->shape(_reshape_shape);
+ _neg->x(_reshape);
+
+ output()->from(_neg);
+ }
+};
+
+class ForwardReshapeToLogisticGraph : public TestIOGraph, public ReshapeLogisticGraphlet
+{
+public:
+ ForwardReshapeToLogisticGraph() = default;
+
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ ReshapeLogisticGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _reshape->tensor(input());
+ _reshape->shape(_reshape_shape);
+ _logistic->x(_reshape);
+
+ output()->from(_logistic);
+ }
+};
+
+class ForwardReshapeToNegGraphTest : public ::testing::Test
+{
+public:
+ ForwardReshapeToNegGraphTest() = default;
+
+ void run_pass(void)
+ {
+ while (_pass.run(_graph.g()))
+ ;
+ }
+
+protected:
+ ForwardReshapeToNegGraph _graph;
+ luci::ForwardReshapeToUnaryOpPass _pass;
+};
+
+class ForwardReshapeToLogisticGraphTest : public ::testing::Test
+{
+public:
+ ForwardReshapeToLogisticGraphTest() = default;
+
+ void run_pass(void)
+ {
+ while (_pass.run(_graph.g()))
+ ;
+ }
+
+protected:
+ ForwardReshapeToLogisticGraph _graph;
+ luci::ForwardReshapeToUnaryOpPass _pass;
+};
+
+} // namespace
+
+TEST(ForwardReshapeToUnaryOpPassTest, name)
+{
+ luci::ForwardReshapeToUnaryOpPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(ForwardReshapeToNegGraphTest, simple_forward)
+{
+ _graph.init({2, 2, 2}, {2, 4});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto neg = dynamic_cast<luci::CircleNeg *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, neg);
+ neg = dynamic_cast<luci::CircleNeg *>(reshape->tensor());
+ ASSERT_NE(nullptr, neg);
+}
+
+TEST_F(ForwardReshapeToLogisticGraphTest, forward)
+{
+ _graph.init({2, 2, 2}, {2, 4});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto log = dynamic_cast<luci::CircleLogistic *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, log);
+ log = dynamic_cast<luci::CircleLogistic *>(reshape->tensor());
+ ASSERT_NE(nullptr, log);
+}