summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp')
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp25
1 files changed, 25 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
index e80623499..bb8e292d4 100644
--- a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
@@ -271,6 +271,31 @@ TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
}
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type3)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {0, 3, 2, 1}, {0, 2, 3, 1});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(3));
+}
+
/**
* @brief Test case that first transpose output become input of operations more than one.
*/