/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "luci/Pass/SubstitutePackToReshapePass.h"

#include <luci/IR/CircleNodes.h>

#include <gtest/gtest.h>

#include <cassert>
#include <cstdint>
#include <initializer_list>

namespace
{

/**
 * @brief Populate @p g with the graph  CircleInput -> CirclePack(1 value) -> CircleOutput
 *
 * @param g      graph to populate; must not be null
 * @param shape  static shape given to the single CircleInput (status VALID)
 * @param axis   axis forwarded verbatim to the Pack node (may be negative)
 */
void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape,
                                       int32_t axis)
{
  assert(g);

  // Input Create.
  auto input = g->nodes()->create<luci::CircleInput>();
  auto graph_input = g->inputs()->create();
  input->index(graph_input->index());
  input->shape_status(luci::ShapeStatus::VALID);
  input->rank(shape.size());
  input->shape(shape);
  input->name("input");

  // Pack Node create: a Pack with exactly one value, which is the case the pass targets.
  auto pack = g->nodes()->create<luci::CirclePack>(1);
  pack->values(0, input);
  pack->axis(axis);
  pack->name("pack");

  // Output Connect.
  auto output = g->nodes()->create<luci::CircleOutput>();
  output->from(pack);
  auto graph_output = g->outputs()->create();
  output->index(graph_output->index());
  output->name("output");
}

/// Reshape / Pack nodes found among the nodes reachable from the graph outputs.
struct FoundNodes
{
  luci::CircleReshape *reshape = nullptr;
  luci::CirclePack *pack = nullptr;
};

/**
 * @brief Run SubstitutePackToReshapePass on @p g until it reports no change, then
 *        record the (last seen) CircleReshape and CirclePack among the active nodes.
 *
 * Returns the found nodes instead of asserting, because gtest fatal ASSERT_*
 * macros may only be used in functions returning void.
 */
FoundNodes run_pass_and_collect(loco::Graph *g)
{
  luci::SubstitutePackToReshapePass pass;
  // Iterate to a fixpoint: run() returns true while it still changes the graph.
  while (pass.run(g))
    ;

  FoundNodes found;
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
      found.reshape = reshape;
    else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
      found.pack = pack;
  }
  return found;
}

} // namespace

TEST(SubstitutePackToReshapePassTest, name)
{
  luci::SubstitutePackToReshapePass pass;
  auto const name = pass.name();
  ASSERT_NE(nullptr, name);
}

TEST(SubstitutePackToReshapePass, simple_case)
{
  auto graph = loco::make_graph();
  create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, 0);

  auto found = run_pass_and_collect(graph.get());
  ASSERT_NE(nullptr, found.reshape);
  ASSERT_EQ(nullptr, found.pack);

  // Packing one [1,2,3,4] tensor at axis 0 yields [1,1,2,3,4].
  auto new_shape = loco::must_cast<luci::CircleConst *>(found.reshape->shape());
  // Guard the element reads below: a wrong dtype/size must fail, not read out of range.
  ASSERT_EQ(loco::DataType::S32, new_shape->dtype());
  ASSERT_EQ(5u, new_shape->size<loco::DataType::S32>());
  ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
  ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(1));
  ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(2));
  ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(3));
  ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(4));
}

TEST(SubstitutePackToReshapePass, simple_case_neg_axis)
{
  auto graph = loco::make_graph();
  create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, -1);

  auto found = run_pass_and_collect(graph.get());
  ASSERT_NE(nullptr, found.reshape);
  ASSERT_EQ(nullptr, found.pack);

  // Packing one [1,2,3,4] tensor at axis -1 (i.e. last) yields [1,2,3,4,1].
  auto new_shape = loco::must_cast<luci::CircleConst *>(found.reshape->shape());
  // Guard the element reads below: a wrong dtype/size must fail, not read out of range.
  ASSERT_EQ(loco::DataType::S32, new_shape->dtype());
  ASSERT_EQ(5u, new_shape->size<loco::DataType::S32>());
  ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
  ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(1));
  ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(2));
  ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(3));
  ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(4));
}