diff options
Diffstat (limited to 'compiler/logo/src')
12 files changed, 357 insertions, 371 deletions
diff --git a/compiler/logo/src/Passes/ConstantFoldingPass.cpp b/compiler/logo/src/Passes/ConstantFoldingPass.cpp deleted file mode 100644 index e038e7140..000000000 --- a/compiler/logo/src/Passes/ConstantFoldingPass.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <logo/ConstantFoldingPass.h> - -#include <loco.h> -#include <loco/IR/CanonicalDialect.h> - -#include <stdex/Memory.h> - -#include <locomotiv/Session.h> - -#include <cassert> -#include <stdexcept> - -namespace -{ - -uint64_t num_elements(const loco::NodeMixin<loco::NodeTrait::TensorShape> &shape) -{ - if (shape.rank() == 0) - { - return 0; - } - - uint64_t res = 1; - - for (uint32_t axis = 0; axis < shape.rank(); ++axis) - { - assert(shape.dim(axis).known()); - res *= shape.dim(axis).value(); - } - - return res; -} - -/// @brief For some op, constant folding should not be performed. This returns true if node is such -/// op. -bool skip(const loco::Node *node) -{ - static std::set<uint32_t> skip_op = { - // TODO Current implementation works for 'Tensor' domain only. Support other domains such as - // `Feature`, `Filter`, `Bias`, etc. - static_cast<uint32_t>(loco::CanonicalOpcode::FilterEncode), - static_cast<uint32_t>(loco::CanonicalOpcode::FeatureEncode), - static_cast<uint32_t>(loco::CanonicalOpcode::BiasEncode), - static_cast<uint32_t>(loco::CanonicalOpcode::DepthwiseFilterEncode), - - // We don't perform constant folding for Push - static_cast<uint32_t>(loco::CanonicalOpcode::Push), - - // TensorBroadcast is a good hint for optimization - // TODO Let this option be controlled by driver using logo - static_cast<uint32_t>(loco::CanonicalOpcode::TensorBroadcast), - }; - - if (node->dialect() == loco::CanonicalDialect::get()) - { - if (skip_op.find(node->opnum()) != skip_op.end()) - return true; - } - - return false; -} - -/// @brief Checks if a node is a target of constant folding transform -bool foldable(const loco::Node *node) -{ - if (node->dialect() == loco::CanonicalDialect::get()) - { - if (skip(node)) - return false; - - if (node->arity() == 0) // e.g., when a node is e.g, ConstGen or Pull - return false; - - // When all args are ConstGen, let's do Constant Folding Transforms - for (int i = 0; i < node->arity(); i++) - { - if (node->arg(i)->opnum() != static_cast<uint32_t>(loco::CanonicalOpcode::ConstGen)) - return false; - } - - return true; - } - else - { - return false; - } -} - -void fold(loco::Graph *graph, loco::Node *node) -{ - assert(foldable(node)); // sanity check to find a mistake when this function is reused later - - // calcluate foldable node - locomotiv::Session sess(graph, std::vector<loco::Node *>{node}); - sess.infer(); - auto data = sess.get_output(0); - - assert(data != nullptr); - - auto shape = data->shape(); - auto dtype = data->dtype(); - - // build ConstGen - auto new_const = graph->nodes()->create<loco::ConstGen>(); - { - new_const->dtype(dtype); - - new_const->rank(shape->rank()); - for (int d = 0; d < shape->rank(); d++) - new_const->dim(d) = shape->dim(d); - - auto count = num_elements(*new_const); - - if (dtype == loco::DataType::FLOAT32) - { - new_const->size<loco::DataType::FLOAT32>(count); - - auto const_buf = data->as_f32_bufptr()->base(); - for (int x = 0; x < count; x++) - new_const->at<loco::DataType::FLOAT32>(x) = const_buf[x]; - } - else if (dtype == loco::DataType::S32) - { - new_const->size<loco::DataType::S32>(count); - - auto const_buf = data->as_s32_bufptr()->base(); - for (int x = 0; x < count; x++) - new_const->at<loco::DataType::S32>(x) = const_buf[x]; - } - } - - // replace node with new_const - loco::replace(node).with(new_const); -} - -} // namespace - -namespace logo -{ - -bool ConstantFoldingPass::run(loco::Graph *graph) -{ - auto outputs = loco::output_nodes(graph); - - bool changed = false; - for (auto node : loco::postorder_traversal(outputs)) - { - if (foldable(node)) - { - fold(graph, node); - changed = true; - } - } - - return changed; -} - -} // namespace logo diff --git a/compiler/logo/src/Passes/ConstantFoldingPass.test.cpp b/compiler/logo/src/Passes/ConstantFoldingPass.test.cpp deleted file mode 100644 index b9c4942c4..000000000 --- a/compiler/logo/src/Passes/ConstantFoldingPass.test.cpp +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <logo/ConstantFoldingPass.h> - -#include "TestHelper.h" - -#include <loco.h> - -#include <gtest/gtest.h> - -using namespace logo::test; - -namespace -{ - -/* - test case: - ConstGen ---- Relu ---- Push - (-3.14, 3.14) (0, 3.14) - - after constant folding: - ConstGen ------Push - (0, 3.14) -*/ -void create_net_const_relu(loco::Graph *graph) -{ - assert(graph); - - auto const_node = graph->nodes()->create<loco::ConstGen>(); - { - const_node->dtype(loco::DataType::FLOAT32); - const_node->rank(1); - const_node->dim(0) = 2; - const_node->size<loco::DataType::FLOAT32>(2); - const_node->at<loco::DataType::FLOAT32>(0) = -3.14f; - const_node->at<loco::DataType::FLOAT32>(1) = 3.14f; - } - - auto relu_node = graph->nodes()->create<loco::ReLU>(); - { - relu_node->input(const_node); - } - - auto push_node = graph->nodes()->create<loco::Push>(); - { - push_node->from(relu_node); - } - - auto graph_output = graph->outputs()->create(); - { - graph_output->name("output"); - graph_output->dtype(loco::DataType::FLOAT32); - loco::link(graph_output, push_node); - } -} - -} // namespace - -TEST(ConstantFolding, const_relu_to_const) -{ - auto graph = loco::make_graph(); - create_net_const_relu(graph.get()); - - logo::ConstantFoldingPass pass; - while (pass.run(graph.get()) == true) - { - ; - } - - auto push = logo::test::find_first_node_by_type<loco::Push>(graph.get()); - auto const_gen = loco::must_cast<loco::ConstGen *>(push->from()); - ASSERT_NE(const_gen, nullptr); - - ASSERT_EQ(const_gen->size<loco::DataType::FLOAT32>(), 2); - ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(0), 0); // result of relu(-3.14) - ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(1), 3.14f); -} - -namespace -{ - -/* - test case: - ConstGen ---- Relu ---+ - (-1, 1) (0, 1) | - ConstGen ---+-- ConcatV2 ----- Push - (2, 3) | (0, 1, 2, 3) - axis(0) ---+ - - after constant folding: - ConstGen ----- Push - (0, 1, 2, 3) -*/ -void create_net_const_relu_concat(loco::Graph *graph) -{ - assert(graph); - - auto const_1_node = graph->nodes()->create<loco::ConstGen>(); - { - const_1_node->dtype(loco::DataType::FLOAT32); - const_1_node->rank(1); - const_1_node->dim(0) = 2; - const_1_node->size<loco::DataType::FLOAT32>(2); - const_1_node->at<loco::DataType::FLOAT32>(0) = -1.0f; - const_1_node->at<loco::DataType::FLOAT32>(1) = 1.0f; - } - - auto relu_node = graph->nodes()->create<loco::ReLU>(); - { - relu_node->input(const_1_node); - } - - auto const_2_node = graph->nodes()->create<loco::ConstGen>(); - { - const_2_node->dtype(loco::DataType::FLOAT32); - const_2_node->rank(1); - const_2_node->dim(0) = 2; - const_2_node->size<loco::DataType::FLOAT32>(2); - const_2_node->at<loco::DataType::FLOAT32>(0) = 2.0f; - const_2_node->at<loco::DataType::FLOAT32>(1) = 3.0f; - } - - auto concat_node = graph->nodes()->create<loco::TensorConcat>(); - { - concat_node->lhs(relu_node); - concat_node->rhs(const_2_node); - concat_node->axis(0); - } - - auto push_node = graph->nodes()->create<loco::Push>(); - { - push_node->from(concat_node); - } - - auto graph_output = graph->outputs()->create(); - { - graph_output->name("output"); - graph_output->dtype(loco::DataType::FLOAT32); - loco::link(graph_output, push_node); - } -} - -} // namespace - -TEST(ConstantFolding, const_relu_to_concat) -{ - auto graph = loco::make_graph(); - create_net_const_relu_concat(graph.get()); - - logo::ConstantFoldingPass pass; - while (pass.run(graph.get()) == true) - { - ; - } - - auto push = logo::test::find_first_node_by_type<loco::Push>(graph.get()); - auto const_gen = loco::must_cast<loco::ConstGen *>(push->from()); - ASSERT_NE(const_gen, nullptr); - - ASSERT_EQ(const_gen->size<loco::DataType::FLOAT32>(), 4); - ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(0), 0); - ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(1), 1); - ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(2), 2); - ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(3), 3); -} diff --git a/compiler/logo/src/Passes/EmptyTestGraph.h b/compiler/logo/src/Passes/EmptyTestGraph.h new file mode 100644 index 000000000..67f2c8a11 --- /dev/null +++ b/compiler/logo/src/Passes/EmptyTestGraph.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOGO_EMPTY_TEST_GRAPH_H__ +#define __LOGO_EMPTY_TEST_GRAPH_H__ + +#include <loco.h> + +namespace logo +{ + +void create_empty_test_net(loco::Graph *graph); + +} // namespace logo + +#endif // __LOGO_EMPTY_TEST_GRAPH_H__ diff --git a/compiler/logo/src/Passes/EmptyTestGraph.test.cpp b/compiler/logo/src/Passes/EmptyTestGraph.test.cpp new file mode 100644 index 000000000..46750b79c --- /dev/null +++ b/compiler/logo/src/Passes/EmptyTestGraph.test.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <loco.h> + +#include <gtest/gtest.h> + +namespace logo +{ + +void create_empty_test_net(loco::Graph *graph) +{ + assert(graph); + + auto const_node = graph->nodes()->create<loco::ConstGen>(); + { + const_node->dtype(loco::DataType::FLOAT32); + const_node->rank(1); + const_node->dim(0) = 1; + const_node->size<loco::DataType::FLOAT32>(1); + const_node->at<loco::DataType::FLOAT32>(0) = 1.0f; + } + + auto push_node = graph->nodes()->create<loco::Push>(); + { + push_node->from(const_node); + } + + auto graph_output = graph->outputs()->create(); + { + graph_output->name("output"); + graph_output->dtype(loco::DataType::FLOAT32); + loco::link(graph_output, push_node); + } +} + +} // namespace logo diff --git a/compiler/logo/src/Passes/RemoveDeadNodePass.test.cpp b/compiler/logo/src/Passes/RemoveDeadNodePass.test.cpp new file mode 100644 index 000000000..c0ecbdaa9 --- /dev/null +++ b/compiler/logo/src/Passes/RemoveDeadNodePass.test.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <logo/RemoveDeadNodePass.h> + +#include "EmptyTestGraph.h" + +#include <gtest/gtest.h> + +TEST(RemoveDeadNodePassTest, name) +{ + logo::RemoveDeadNodePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveDeadNodePassTest, run_NEG) +{ + loco::Graph g; + logo::RemoveDeadNodePass pass; + + logo::create_empty_test_net(&g); + + ASSERT_FALSE(pass.run(&g)); +} diff --git a/compiler/logo/src/Passes/RemoveDeadNodeWithQueryPass.test.cpp b/compiler/logo/src/Passes/RemoveDeadNodeWithQueryPass.test.cpp new file mode 100644 index 000000000..f14bfc30d --- /dev/null +++ b/compiler/logo/src/Passes/RemoveDeadNodeWithQueryPass.test.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <logo/RemoveDeadNodeWithQueryPass.h> + +#include "EmptyTestGraph.h" + +#include <gtest/gtest.h> + +TEST(RemoveDeadNodeWithQueryPassTest, name) +{ + logo::RemoveDeadNodeWithQueryPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveDeadNodeWithQueryPassTest, run_NEG) +{ + loco::Graph g; + logo::RemoveDeadNodeWithQueryPass pass; + + logo::create_empty_test_net(&g); + + ASSERT_FALSE(pass.run(&g)); +} diff --git a/compiler/logo/src/Passes/RemoveForwardNodePass.test.cpp b/compiler/logo/src/Passes/RemoveForwardNodePass.test.cpp new file mode 100644 index 000000000..bb905aec5 --- /dev/null +++ b/compiler/logo/src/Passes/RemoveForwardNodePass.test.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <logo/RemoveForwardNodePass.h> + +#include "EmptyTestGraph.h" + +#include <gtest/gtest.h> + +TEST(RemoveForwardNodePassTest, name) +{ + logo::RemoveForwardNodePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(RemoveForwardNodePassTest, run_NEG) +{ + loco::Graph g; + logo::RemoveForwardNodePass pass; + + logo::create_empty_test_net(&g); + + ASSERT_FALSE(pass.run(&g)); +} diff --git a/compiler/logo/src/Passes/ReorderDecodePass.test.cpp b/compiler/logo/src/Passes/ReorderDecodePass.test.cpp new file mode 100644 index 000000000..f8e158d3a --- /dev/null +++ b/compiler/logo/src/Passes/ReorderDecodePass.test.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <logo/ReorderDecodePass.h> + +#include "EmptyTestGraph.h" + +#include <gtest/gtest.h> + +TEST(ReorderDecodePassTest, TensorBiasAdd_name) +{ + logo::ReorderDecodePass<loco::TensorBiasAdd> pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(ReorderDecodePassTest, ReLU_name) +{ + logo::ReorderDecodePass<loco::ReLU> pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(ReorderDecodePassTest, TensorBiasAdd_run_NEG) +{ + loco::Graph g; + logo::ReorderDecodePass<loco::TensorBiasAdd> pass; + + logo::create_empty_test_net(&g); + + ASSERT_FALSE(pass.run(&g)); +} + +TEST(ReorderDecodePassTest, ReLU_run_NEG) +{ + loco::Graph g; + logo::ReorderDecodePass<loco::ReLU> pass; + + logo::create_empty_test_net(&g); + + ASSERT_FALSE(pass.run(&g)); +} diff --git a/compiler/logo/src/Passes/ResolveDuplicateReshapePass.test.cpp b/compiler/logo/src/Passes/ResolveDuplicateReshapePass.test.cpp new file mode 100644 index 000000000..de2df6fd5 --- /dev/null +++ b/compiler/logo/src/Passes/ResolveDuplicateReshapePass.test.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <logo/ResolveDuplicateReshapePass.h> + +#include "EmptyTestGraph.h" + +#include <gtest/gtest.h> + +TEST(ResolveDuplicateReshapePassTest, name) +{ + logo::ResolveDuplicateReshapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(ResolveDuplicateReshapePassTest, run_NEG) +{ + loco::Graph g; + logo::ResolveDuplicateReshapePass pass; + + logo::create_empty_test_net(&g); + + ASSERT_FALSE(pass.run(&g)); +} diff --git a/compiler/logo/src/Passes/ResolveRedundantReshapePass.test.cpp b/compiler/logo/src/Passes/ResolveRedundantReshapePass.test.cpp new file mode 100644 index 000000000..9a7e95846 --- /dev/null +++ b/compiler/logo/src/Passes/ResolveRedundantReshapePass.test.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <logo/ResolveRedundantReshapePass.h> + +#include "EmptyTestGraph.h" + +#include <gtest/gtest.h> + +TEST(ResolveRedundantReshapePassTest, name) +{ + logo::ResolveRedundantReshapePass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(ResolveRedundantReshapePassTest, run_NEG) +{ + loco::Graph g; + logo::ResolveRedundantReshapePass pass; + + logo::create_empty_test_net(&g); + + ASSERT_FALSE(pass.run(&g)); +} diff --git a/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp b/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp index 0bda85b6f..40ddb133b 100644 --- a/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp +++ b/compiler/logo/src/Passes/SimplifyDomainConversionPass.cpp @@ -20,8 +20,7 @@ #include <loco/IR/CanonicalDialect.h> #include <loco/IR/CanonicalNode.h> -#include <stdex/Memory.h> - +#include <memory> #include <set> #include <vector> #include <cassert> @@ -123,9 +122,6 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) { using namespace loco; - auto encoder = encode_node->encoder(); - assert(encoder != nullptr); - auto decode_node = dynamic_cast<loco::FeatureDecode *>(encode_node->input()); if (decode_node == nullptr) { @@ -133,6 +129,9 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) } assert(decode_node->input() != nullptr); + auto encoder = encode_node->encoder(); + assert(encoder != nullptr); + auto decoder = decode_node->decoder(); assert(decoder != nullptr); @@ -231,8 +230,8 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) perm_vec[to] = from; } - transposeCandidates.insert(stdex::make_unique<TransposeCtx>( - encode_node, decode_node, encode_node->input(), perm_vec)); + transposeCandidates.insert( + std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec)); } } @@ -293,8 +292,8 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) perm_vec[to] = from; } - transposeCandidates.insert(stdex::make_unique<TransposeCtx>( - encode_node, decode_node, encode_node->input(), perm_vec)); + transposeCandidates.insert( + std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec)); } } @@ -303,9 +302,6 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) { using namespace loco; - auto encoder = encode_node->encoder(); - assert(encoder != nullptr); - auto decode_node = dynamic_cast<loco::MatrixDecode *>(encode_node->input()); if (decode_node == nullptr) { @@ -313,6 +309,9 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) } assert(decode_node->input() != nullptr); + auto encoder = encode_node->encoder(); + assert(encoder != nullptr); + auto decoder = decode_node->decoder(); assert(decoder != nullptr); @@ -377,8 +376,8 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) perm_vec[to] = from; } - transposeCandidates.insert(stdex::make_unique<TransposeCtx>( - encode_node, decode_node, encode_node->input(), perm_vec)); + transposeCandidates.insert( + std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec)); } } @@ -397,7 +396,7 @@ bool SimplifyDomainConversionPass::run(loco::Graph *g) TransposeCtx(loco::Node *first, loco::Node *last, loco::Node *input, std::vector<loco::TensorAxis> perm) - : first_node(first), last_node(last), input_node(input), perm_vec(perm) + : first_node(first), last_node(last), input_node(input), perm_vec(perm) { /* empty */ } }; diff --git a/compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp b/compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp index 9a05763b4..75a288089 100644 --- a/compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp +++ b/compiler/logo/src/Passes/SimplifyDomainConversionPass.test.cpp @@ -19,10 +19,26 @@ #include "TestHelper.h" #include <loco.h> -#include <stdex/Memory.h> + +#include <memory> #include <gtest/gtest.h> +TEST(SimplifyDomainConversionPassTest, name) +{ + logo::SimplifyDomainConversionPass pass; + auto const name = pass.name(); + ASSERT_NE(nullptr, name); +} + +TEST(SimplifyDomainConversionPassTest, run_NEG) +{ + loco::Graph g; + logo::SimplifyDomainConversionPass pass; + + ASSERT_FALSE(pass.run(&g)); +} + namespace { @@ -65,7 +81,7 @@ template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *inp { loco::Graph *g = input_for_decode->graph(); - auto decoder = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Filter>>(); + auto decoder = std::make_unique<loco::PermutingDecoder<loco::Domain::Filter>>(); decoder->perm(perm<T>()); @@ -80,7 +96,7 @@ template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *inp { loco::Graph *g = input_for_encode->graph(); - auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(); + auto encoder = std::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(); encoder->perm(perm<T>()); |