diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2022-04-15 19:15:11 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2022-04-15 19:15:11 +0900 |
commit | 3ad689f0803519e343c36d5700646e86059df961 (patch) | |
tree | 862346c401a5577518fa7f042532aa931b53aa0e /compiler/luci/partition/src | |
parent | ac6e4dd7b480e83b586ef533d7b29a8a97eb48fe (diff) | |
download | nnfw-3ad689f0803519e343c36d5700646e86059df961.tar.gz nnfw-3ad689f0803519e343c36d5700646e86059df961.tar.bz2 nnfw-3ad689f0803519e343c36d5700646e86059df961.zip |
Imported Upstream version 1.20.0upstream/1.20.0submit/tizen/20220415.103159
Diffstat (limited to 'compiler/luci/partition/src')
-rw-r--r-- | compiler/luci/partition/src/ConnectNode.h | 2 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleSVDF.cpp | 47 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp | 106 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleVariable.cpp | 27 | ||||
-rw-r--r-- | compiler/luci/partition/src/PartitionIRDump.cpp | 11 | ||||
-rw-r--r-- | compiler/luci/partition/src/PartitionMerge.cpp | 50 | ||||
-rw-r--r-- | compiler/luci/partition/src/PartitionPGroups.cpp | 240 |
7 files changed, 347 insertions, 136 deletions
diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h index ebbff7a6a..e60567c69 100644 --- a/compiler/luci/partition/src/ConnectNode.h +++ b/compiler/luci/partition/src/ConnectNode.h @@ -161,6 +161,7 @@ public: void visit(const luci::CircleSquaredDifference *) final; void visit(const luci::CircleSqueeze *) final; void visit(const luci::CircleStridedSlice *) final; + void visit(const luci::CircleSVDF *) final; void visit(const luci::CircleSub *) final; void visit(const luci::CircleSum *) final; void visit(const luci::CircleTanh *) final; @@ -197,6 +198,7 @@ public: void visit(const luci::CircleTopKV2Out *) final; void visit(const luci::CircleUniqueOut *) final; void visit(const luci::CircleUnpackOut *) final; + void visit(const luci::CircleVariable *) final; void visit(const luci::CircleWhileOut *) final; public: diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp new file mode 100644 index 000000000..f661a794c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSVDF *node) +{ + auto *cloned = loco::must_cast<luci::CircleSVDF *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *weight_feature = loco::must_cast<luci::CircleNode *>(node->weight_feature()); + luci::CircleNode *weight_time = loco::must_cast<luci::CircleNode *>(node->weight_time()); + luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias()); + luci::CircleNode *input_activation_state = + loco::must_cast<luci::CircleNode *>(node->input_activation_state()); + + cloned->input(cn->find_clone(input)); + cloned->weight_feature(cn->find_clone(weight_feature)); + cloned->weight_time(cn->find_clone(weight_time)); + cloned->bias(cn->find_clone(bias)); + cloned->input_activation_state(cn->find_clone(input_activation_state)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSVDF *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp new file mode 100644 index 000000000..5fae5206e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSVDF> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + NodeGraphletT<luci::CircleSVDF>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->weight_feature(input(1)); + node()->weight_time(input(2)); + node()->bias(input(3)); + node()->input_activation_state(input(4)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SVDF) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(5, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); + ASSERT_EQ(cth.inputs(4), clone->arg(4)); +} + +TEST(ConnectNodeTest, connect_SVDF_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleVariable.cpp b/compiler/luci/partition/src/Nodes/CircleVariable.cpp new file mode 100644 index 000000000..f7f6f21fd --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleVariable.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleVariable *) +{ + // Nothing to do +} + +} // namespace luci diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp index 4f2c26800..0fabfc416 100644 --- a/compiler/luci/partition/src/PartitionIRDump.cpp +++ b/compiler/luci/partition/src/PartitionIRDump.cpp @@ -32,18 +32,18 @@ void dump(std::ostream &os, const PNode *pnode) void dump(std::ostream &os, const PGroup *pgroup) { os << "--- PGroup: " << pgroup->group << std::endl; - os << "Input(s): "; + os << "Input(s): [ "; for (auto &node_in : pgroup->inputs) os << node_in->name() << " "; - os << std::endl; + os << "]" << std::endl; for (auto &pnode : pgroup->pnodes) { dump(os, pnode.get()); } - os << "Output(s): "; + os << "Output(s): [ "; for (auto &node_out : pgroup->outputs) os << node_out->name() << " "; - os << std::endl; + os << "]" << std::endl; } void dump(std::ostream &os, const PGroups *pgroups) @@ -57,7 +57,8 @@ void dump(std::ostream &os, const PGroups *pgroups) { auto node = it->first; auto group = it->second; - os << " Node: " << node << "(" << node->name() << "): " << group << std::endl; + os << " Node: " << node << "(" << luci::opcode_name(node) << "," << node->name() + << "): " << group << std::endl; } } diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp index c517bf93f..4c3971bd8 100644 --- a/compiler/luci/partition/src/PartitionMerge.cpp +++ b/compiler/luci/partition/src/PartitionMerge.cpp @@ -58,9 +58,6 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) // we need to clone this CircleConst for each graph of the group. if (dynamic_cast<const luci::CircleConst *>(input) != nullptr) continue; - // Skip also for OutputExclude - if (dynamic_cast<const luci::CircleOutputExclude *>(input) != nullptr) - continue; auto input_group = pgroups->group_of(input); // NOTE: all the nodes should be registered and return should be valid group. @@ -87,7 +84,7 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) input_pgroup = pgroup_input; else { - if (input_pgroup != pgroup_input) + if (input_pgroup->group != pgroup_input->group) return false; } } @@ -96,6 +93,48 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) } /** + * @brief return true if there is only one output and is fed to same group of nodes + * @note pgroups is used to find group of pgroup + * ex) + * /-- pgroup_user_1 (grp_1) + * --- pgroup + * \-- pgroup_user_2 (grp_2) + * + * return false if grp_1 != grp_2 + */ +bool is_output_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) +{ + assert(pgroups != nullptr); + assert(pgroup != nullptr); + + std::string group; + for (auto &output : pgroup->outputs) + { + // get output_group + auto output_group = pgroups->group_of(output); + assert(not output_group.empty()); + if (output_group.empty()) + output_group = pgroups->default_group; + + // find all PGroup that uses output + for (auto &pgroup_user : pgroups->pgroups) + { + for (auto &user_inputs : pgroup_user->inputs) + { + if (output == user_inputs) + { + // OK, these are connected, check group is same + if (pgroup_user->group != output_group) + return false; + } + } + } + } + + return true; +} + +/** * @brief merge pgroup into pgroup_i * @note output of pgroup_i should be input of pgroup */ @@ -191,6 +230,9 @@ std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups) // skip if there are multiple inputs but inputs differ in group if (!is_input_same(pgroup.get(), d_pgroups.get())) continue; + // skip if pgroup has different group for other users of pgroup_i + if (!is_output_same(pgroup_i.get(), d_pgroups.get())) + continue; // TODO add more condition may be needed merge_into(pgroup.get(), pgroup_i.get()); diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp index 0080873e6..eaeacf9c4 100644 --- a/compiler/luci/partition/src/PartitionPGroups.cpp +++ b/compiler/luci/partition/src/PartitionPGroups.cpp @@ -46,6 +46,9 @@ public: bool visit(const luci::CircleUniqueOut *) final { return true; } bool visit(const luci::CircleUnpackOut *) final { return true; } bool visit(const luci::CircleWhileOut *) final { return true; } + // For inputs not used + bool visit(const luci::CircleOutputExclude *) final { return true; } + bool visit(const luci::CircleVariable *) final { return true; } // TODO add all virtual nodes // default is false @@ -69,59 +72,80 @@ bool check_allocate_partition(const luci::CircleNode *node) return true; } -class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &> +} // namespace + +namespace { -public: - FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups) - : _partition(partition), _pgroups(pgroups) - { - // NOTHING TODO - } -private: - const std::string &groupof(const luci::CircleNode *input) const +std::string group_from_partition(const luci::CircleNode *node, + const luci::PartitionTable &partition) +{ + LOGGER(l); + + auto group = partition.default_group; + + std::string opcodename; // opcodename or opname + + switch (partition.comply) { - auto group = _pgroups->node2group[input]; - assert(not group.empty()); - if (group.empty()) - return _partition.default_group; - return _pgroups->node2group[input]; + case luci::PartitionTable::COMPLY::OPCODE: + { + opcodename = luci::opcode_name(node); + assert(!opcodename.empty()); + + auto it = partition.byopcodes.find(opcodename); + if (it != partition.byopcodes.end()) + group = it->second; + break; + } + case luci::PartitionTable::COMPLY::OPNAME: + { + opcodename = node->name(); + assert(!opcodename.empty()); + + auto it = partition.byopnames.find(opcodename); + if (it != partition.byopnames.end()) + group = it->second; + break; + } + + default: + throw std::runtime_error("Unsupported partition.comply"); } + INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group + << std::endl; + + return group; +} + +class IsVirtualInputNode final : public luci::CircleNodeVisitor<bool> +{ public: -#define IMPLEMENT(CLASS) \ - const std::string &visit(const luci::CLASS *node) final \ - { \ - auto input = loco::must_cast<luci::CircleNode *>(node->input()); \ - return groupof(input); \ - } + // TODO check CircleOutputDummy + bool visit(const luci::CircleOutputExclude *) final { return true; } + bool visit(const luci::CircleVariable *) final { return true; } - IMPLEMENT(CircleCustomOut); - IMPLEMENT(CircleIfOut); - IMPLEMENT(CircleNonMaxSuppressionV4Out); - IMPLEMENT(CircleNonMaxSuppressionV5Out); - IMPLEMENT(CircleSplitOut); - IMPLEMENT(CircleSplitVOut); - IMPLEMENT(CircleTopKV2Out); - IMPLEMENT(CircleUniqueOut); - IMPLEMENT(CircleUnpackOut); - IMPLEMENT(CircleWhileOut); - -#undef IMPLEMENT - - // return empty for nothing to do - const std::string &visit(const luci::CircleNode *) final { return _empty_str; } - -private: - const luci::PartitionTable &_partition; - luci::PGroups *_pgroups = nullptr; - std::string _empty_str; + // default is false + bool visit(const luci::CircleNode *) final { return false; } }; -} // namespace - -namespace +class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool> { +public: + bool visit(const luci::CircleCustom *) final { return true; } + bool visit(const luci::CircleIf *) final { return true; } + bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; } + bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; } + bool visit(const luci::CircleSplit *) final { return true; } + bool visit(const luci::CircleSplitV *) final { return true; } + bool visit(const luci::CircleTopKV2 *) final { return true; } + bool visit(const luci::CircleUnique *) final { return true; } + bool visit(const luci::CircleUnpack *) final { return true; } + bool visit(const luci::CircleWhile *) final { return true; } + // default is false + bool visit(const luci::CircleNode *) final { return false; } +}; void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx) { @@ -136,17 +160,56 @@ void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &g pgroup->pnodes.push_back(std::move(pnode)); + IsVirtualInputNode queryvi; // Set input of PGroup for (uint32_t in = 0; in < node->arity(); ++in) { auto input = loco::must_cast<luci::CircleNode *>(node->arg(in)); - // this input maybe CircleInput in source graph - // --> not confident this is safe - pgroup->inputs.push_back(input); + if (input->accept(&queryvi)) + { + auto pnode = std::make_unique<luci::PNode>(); + pnode->node = input; + pnode->group = group; + pnode->pgroup = pgroup.get(); + + pgroup->pnodes.push_back(std::move(pnode)); + + pgroups->node2group[input] = group; + } + else + { + // this input maybe CircleInput in source graph + // --> not confident this is safe + pgroup->inputs.push_back(input); + } + } + + IsMultiOutputNode query; + if (node->accept(&query)) + { + // Include CircleXXXOut virtual nodes in this group + auto succs = loco::succs(node); + for (auto &succ_node : succs) + { + auto nodeout = loco::must_cast<luci::CircleNode *>(succ_node); + + auto pnode = std::make_unique<luci::PNode>(); + pnode->node = nodeout; + pnode->group = group; + pnode->pgroup = pgroup.get(); + + pgroup->pnodes.push_back(std::move(pnode)); + + pgroups->node2group[nodeout] = group; + + pgroup->outputs.push_back(nodeout); + } + } + else + { + // Set output of PGroup: node itself + pgroup->outputs.push_back(node); } - // Set output of PGroup: node itself or multiple virtual outputs - // TODO support multiple virtual outputs - pgroup->outputs.push_back(node); pgroups->node2group[node] = group; pgroups->id2pgroup[pgroup->id] = pgroup.get(); @@ -182,70 +245,9 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, // check if node is normal node that we are interested if (check_allocate_partition(node)) { - auto group = partition.default_group; - - std::string opcodename; // opcodename or opname - - switch (partition.comply) - { - case luci::PartitionTable::COMPLY::OPCODE: - { - opcodename = luci::opcode_name(node); - assert(!opcodename.empty()); - - auto it = partition.byopcodes.find(opcodename); - if (it != partition.byopcodes.end()) - group = it->second; - break; - } - case luci::PartitionTable::COMPLY::OPNAME: - { - opcodename = node->name(); - assert(!opcodename.empty()); - - auto it = partition.byopnames.find(opcodename); - if (it != partition.byopnames.end()) - group = it->second; - break; - } - - default: - throw std::runtime_error("Unsupported partition.comply"); - } - - INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group - << std::endl; + auto group = group_from_partition(node, partition); append(node, pgroups.get(), group, idx); -#if 0 - auto pgroup = std::make_unique<luci::PGroup>(); - pgroup->group = group; - pgroup->id = idx + 1; - - auto pnode = std::make_unique<luci::PNode>(); - pnode->node = node; - pnode->group = group; - pnode->pgroup = pgroup.get(); - - pgroup->pnodes.push_back(std::move(pnode)); - - // Set input of PGroup - for (uint32_t in = 0; in < node->arity(); ++in) - { - auto input = loco::must_cast<luci::CircleNode *>(node->arg(in)); - // this input maybe CircleInput in source graph - // --> not confident this is safe - pgroup->inputs.push_back(input); - } - // Set output of PGroup: node itself or multiple virtual outputs - // TODO support multiple virtual outputs - pgroup->outputs.push_back(node); - - pgroups->node2group[node] = group; - pgroups->id2pgroup[pgroup->id] = pgroup.get(); - - pgroups->pgroups.push_back(std::move(pgroup)); -#endif } else { @@ -255,22 +257,6 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, } } - // handle for virtual nodes like multiple outputs - // these nodes should follow group of the input - for (uint32_t idx = 0; idx < nodes->size(); ++idx) - { - auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx)); - - // for virtual nodes like CircleUnpackOut should follow it's input (owner) - // or just set to default - FindGroupToFollow query(partition, pgroups.get()); - const auto &group = node->accept(&query); - if (not group.empty()) - { - append(node, pgroups.get(), group, idx); - } - } - return std::move(pgroups); } |