summaryrefslogtreecommitdiff
path: root/compiler/luci/partition/src
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2022-04-15 19:15:11 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2022-04-15 19:15:11 +0900
commit3ad689f0803519e343c36d5700646e86059df961 (patch)
tree862346c401a5577518fa7f042532aa931b53aa0e /compiler/luci/partition/src
parentac6e4dd7b480e83b586ef533d7b29a8a97eb48fe (diff)
downloadnnfw-3ad689f0803519e343c36d5700646e86059df961.tar.gz
nnfw-3ad689f0803519e343c36d5700646e86059df961.tar.bz2
nnfw-3ad689f0803519e343c36d5700646e86059df961.zip
Imported Upstream version 1.20.0upstream/1.20.0submit/tizen/20220415.103159
Diffstat (limited to 'compiler/luci/partition/src')
-rw-r--r--compiler/luci/partition/src/ConnectNode.h2
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSVDF.cpp47
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp106
-rw-r--r--compiler/luci/partition/src/Nodes/CircleVariable.cpp27
-rw-r--r--compiler/luci/partition/src/PartitionIRDump.cpp11
-rw-r--r--compiler/luci/partition/src/PartitionMerge.cpp50
-rw-r--r--compiler/luci/partition/src/PartitionPGroups.cpp240
7 files changed, 347 insertions, 136 deletions
diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h
index ebbff7a6a..e60567c69 100644
--- a/compiler/luci/partition/src/ConnectNode.h
+++ b/compiler/luci/partition/src/ConnectNode.h
@@ -161,6 +161,7 @@ public:
void visit(const luci::CircleSquaredDifference *) final;
void visit(const luci::CircleSqueeze *) final;
void visit(const luci::CircleStridedSlice *) final;
+ void visit(const luci::CircleSVDF *) final;
void visit(const luci::CircleSub *) final;
void visit(const luci::CircleSum *) final;
void visit(const luci::CircleTanh *) final;
@@ -197,6 +198,7 @@ public:
void visit(const luci::CircleTopKV2Out *) final;
void visit(const luci::CircleUniqueOut *) final;
void visit(const luci::CircleUnpackOut *) final;
+ void visit(const luci::CircleVariable *) final;
void visit(const luci::CircleWhileOut *) final;
public:
diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp
new file mode 100644
index 000000000..f661a794c
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSVDF.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSVDF *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSVDF *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *weight_feature = loco::must_cast<luci::CircleNode *>(node->weight_feature());
+ luci::CircleNode *weight_time = loco::must_cast<luci::CircleNode *>(node->weight_time());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+ luci::CircleNode *input_activation_state =
+ loco::must_cast<luci::CircleNode *>(node->input_activation_state());
+
+ cloned->input(cn->find_clone(input));
+ cloned->weight_feature(cn->find_clone(weight_feature));
+ cloned->weight_time(cn->find_clone(weight_time));
+ cloned->bias(cn->find_clone(bias));
+ cloned->input_activation_state(cn->find_clone(input_activation_state));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSVDF *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp
new file mode 100644
index 000000000..5fae5206e
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSVDF.test.cpp
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSVDF>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ NodeGraphletT<luci::CircleSVDF>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->weight_feature(input(1));
+ node()->weight_time(input(2));
+ node()->bias(input(3));
+ node()->input_activation_state(input(4));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SVDF)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(5, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+ ASSERT_EQ(cth.inputs(4), clone->arg(4));
+}
+
+TEST(ConnectNodeTest, connect_SVDF_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSVDF *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
diff --git a/compiler/luci/partition/src/Nodes/CircleVariable.cpp b/compiler/luci/partition/src/Nodes/CircleVariable.cpp
new file mode 100644
index 000000000..f7f6f21fd
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleVariable.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleVariable *)
+{
+ // Nothing to do
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp
index 4f2c26800..0fabfc416 100644
--- a/compiler/luci/partition/src/PartitionIRDump.cpp
+++ b/compiler/luci/partition/src/PartitionIRDump.cpp
@@ -32,18 +32,18 @@ void dump(std::ostream &os, const PNode *pnode)
void dump(std::ostream &os, const PGroup *pgroup)
{
os << "--- PGroup: " << pgroup->group << std::endl;
- os << "Input(s): ";
+ os << "Input(s): [ ";
for (auto &node_in : pgroup->inputs)
os << node_in->name() << " ";
- os << std::endl;
+ os << "]" << std::endl;
for (auto &pnode : pgroup->pnodes)
{
dump(os, pnode.get());
}
- os << "Output(s): ";
+ os << "Output(s): [ ";
for (auto &node_out : pgroup->outputs)
os << node_out->name() << " ";
- os << std::endl;
+ os << "]" << std::endl;
}
void dump(std::ostream &os, const PGroups *pgroups)
@@ -57,7 +57,8 @@ void dump(std::ostream &os, const PGroups *pgroups)
{
auto node = it->first;
auto group = it->second;
- os << " Node: " << node << "(" << node->name() << "): " << group << std::endl;
+ os << " Node: " << node << "(" << luci::opcode_name(node) << "," << node->name()
+ << "): " << group << std::endl;
}
}
diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp
index c517bf93f..4c3971bd8 100644
--- a/compiler/luci/partition/src/PartitionMerge.cpp
+++ b/compiler/luci/partition/src/PartitionMerge.cpp
@@ -58,9 +58,6 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
// we need to clone this CircleConst for each graph of the group.
if (dynamic_cast<const luci::CircleConst *>(input) != nullptr)
continue;
- // Skip also for OutputExclude
- if (dynamic_cast<const luci::CircleOutputExclude *>(input) != nullptr)
- continue;
auto input_group = pgroups->group_of(input);
// NOTE: all the nodes should be registered and return should be valid group.
@@ -87,7 +84,7 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
input_pgroup = pgroup_input;
else
{
- if (input_pgroup != pgroup_input)
+ if (input_pgroup->group != pgroup_input->group)
return false;
}
}
@@ -96,6 +93,48 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
}
/**
+ * @brief return true if there is only one output and is fed to same group of nodes
+ * @note pgroups is used to find group of pgroup
+ * ex)
+ * /-- pgroup_user_1 (grp_1)
+ * --- pgroup
+ * \-- pgroup_user_2 (grp_2)
+ *
+ * return false if grp_1 != grp_2
+ */
+bool is_output_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
+{
+ assert(pgroups != nullptr);
+ assert(pgroup != nullptr);
+
+ std::string group;
+ for (auto &output : pgroup->outputs)
+ {
+ // get output_group
+ auto output_group = pgroups->group_of(output);
+ assert(not output_group.empty());
+ if (output_group.empty())
+ output_group = pgroups->default_group;
+
+ // find all PGroup that uses output
+ for (auto &pgroup_user : pgroups->pgroups)
+ {
+ for (auto &user_inputs : pgroup_user->inputs)
+ {
+ if (output == user_inputs)
+ {
+ // OK, these are connected, check group is same
+ if (pgroup_user->group != output_group)
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+/**
* @brief merge pgroup into pgroup_i
* @note output of pgroup_i should be input of pgroup
*/
@@ -191,6 +230,9 @@ std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups)
// skip if there are multiple inputs but inputs differ in group
if (!is_input_same(pgroup.get(), d_pgroups.get()))
continue;
+ // skip if pgroup has different group for other users of pgroup_i
+ if (!is_output_same(pgroup_i.get(), d_pgroups.get()))
+ continue;
// TODO add more condition may be needed
merge_into(pgroup.get(), pgroup_i.get());
diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp
index 0080873e6..eaeacf9c4 100644
--- a/compiler/luci/partition/src/PartitionPGroups.cpp
+++ b/compiler/luci/partition/src/PartitionPGroups.cpp
@@ -46,6 +46,9 @@ public:
bool visit(const luci::CircleUniqueOut *) final { return true; }
bool visit(const luci::CircleUnpackOut *) final { return true; }
bool visit(const luci::CircleWhileOut *) final { return true; }
+ // For inputs not used
+ bool visit(const luci::CircleOutputExclude *) final { return true; }
+ bool visit(const luci::CircleVariable *) final { return true; }
// TODO add all virtual nodes
// default is false
@@ -69,59 +72,80 @@ bool check_allocate_partition(const luci::CircleNode *node)
return true;
}
-class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &>
+} // namespace
+
+namespace
{
-public:
- FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups)
- : _partition(partition), _pgroups(pgroups)
- {
- // NOTHING TODO
- }
-private:
- const std::string &groupof(const luci::CircleNode *input) const
+std::string group_from_partition(const luci::CircleNode *node,
+ const luci::PartitionTable &partition)
+{
+ LOGGER(l);
+
+ auto group = partition.default_group;
+
+ std::string opcodename; // opcodename or opname
+
+ switch (partition.comply)
{
- auto group = _pgroups->node2group[input];
- assert(not group.empty());
- if (group.empty())
- return _partition.default_group;
- return _pgroups->node2group[input];
+ case luci::PartitionTable::COMPLY::OPCODE:
+ {
+ opcodename = luci::opcode_name(node);
+ assert(!opcodename.empty());
+
+ auto it = partition.byopcodes.find(opcodename);
+ if (it != partition.byopcodes.end())
+ group = it->second;
+ break;
+ }
+ case luci::PartitionTable::COMPLY::OPNAME:
+ {
+ opcodename = node->name();
+ assert(!opcodename.empty());
+
+ auto it = partition.byopnames.find(opcodename);
+ if (it != partition.byopnames.end())
+ group = it->second;
+ break;
+ }
+
+ default:
+ throw std::runtime_error("Unsupported partition.comply");
}
+ INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
+ << std::endl;
+
+ return group;
+}
+
+class IsVirtualInputNode final : public luci::CircleNodeVisitor<bool>
+{
public:
-#define IMPLEMENT(CLASS) \
- const std::string &visit(const luci::CLASS *node) final \
- { \
- auto input = loco::must_cast<luci::CircleNode *>(node->input()); \
- return groupof(input); \
- }
+ // TODO check CircleOutputDummy
+ bool visit(const luci::CircleOutputExclude *) final { return true; }
+ bool visit(const luci::CircleVariable *) final { return true; }
- IMPLEMENT(CircleCustomOut);
- IMPLEMENT(CircleIfOut);
- IMPLEMENT(CircleNonMaxSuppressionV4Out);
- IMPLEMENT(CircleNonMaxSuppressionV5Out);
- IMPLEMENT(CircleSplitOut);
- IMPLEMENT(CircleSplitVOut);
- IMPLEMENT(CircleTopKV2Out);
- IMPLEMENT(CircleUniqueOut);
- IMPLEMENT(CircleUnpackOut);
- IMPLEMENT(CircleWhileOut);
-
-#undef IMPLEMENT
-
- // return empty for nothing to do
- const std::string &visit(const luci::CircleNode *) final { return _empty_str; }
-
-private:
- const luci::PartitionTable &_partition;
- luci::PGroups *_pgroups = nullptr;
- std::string _empty_str;
+ // default is false
+ bool visit(const luci::CircleNode *) final { return false; }
};
-} // namespace
-
-namespace
+class IsMultiOutputNode final : public luci::CircleNodeVisitor<bool>
{
+public:
+ bool visit(const luci::CircleCustom *) final { return true; }
+ bool visit(const luci::CircleIf *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV4 *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV5 *) final { return true; }
+ bool visit(const luci::CircleSplit *) final { return true; }
+ bool visit(const luci::CircleSplitV *) final { return true; }
+ bool visit(const luci::CircleTopKV2 *) final { return true; }
+ bool visit(const luci::CircleUnique *) final { return true; }
+ bool visit(const luci::CircleUnpack *) final { return true; }
+ bool visit(const luci::CircleWhile *) final { return true; }
+ // default is false
+ bool visit(const luci::CircleNode *) final { return false; }
+};
void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx)
{
@@ -136,17 +160,56 @@ void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &g
pgroup->pnodes.push_back(std::move(pnode));
+ IsVirtualInputNode queryvi;
// Set input of PGroup
for (uint32_t in = 0; in < node->arity(); ++in)
{
auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
- // this input maybe CircleInput in source graph
- // --> not confident this is safe
- pgroup->inputs.push_back(input);
+ if (input->accept(&queryvi))
+ {
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = input;
+ pnode->group = group;
+ pnode->pgroup = pgroup.get();
+
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ pgroups->node2group[input] = group;
+ }
+ else
+ {
+ // this input maybe CircleInput in source graph
+ // --> not confident this is safe
+ pgroup->inputs.push_back(input);
+ }
+ }
+
+ IsMultiOutputNode query;
+ if (node->accept(&query))
+ {
+ // Include CircleXXXOut virtual nodes in this group
+ auto succs = loco::succs(node);
+ for (auto &succ_node : succs)
+ {
+ auto nodeout = loco::must_cast<luci::CircleNode *>(succ_node);
+
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = nodeout;
+ pnode->group = group;
+ pnode->pgroup = pgroup.get();
+
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ pgroups->node2group[nodeout] = group;
+
+ pgroup->outputs.push_back(nodeout);
+ }
+ }
+ else
+ {
+ // Set output of PGroup: node itself
+ pgroup->outputs.push_back(node);
}
- // Set output of PGroup: node itself or multiple virtual outputs
- // TODO support multiple virtual outputs
- pgroup->outputs.push_back(node);
pgroups->node2group[node] = group;
pgroups->id2pgroup[pgroup->id] = pgroup.get();
@@ -182,70 +245,9 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
// check if node is normal node that we are interested
if (check_allocate_partition(node))
{
- auto group = partition.default_group;
-
- std::string opcodename; // opcodename or opname
-
- switch (partition.comply)
- {
- case luci::PartitionTable::COMPLY::OPCODE:
- {
- opcodename = luci::opcode_name(node);
- assert(!opcodename.empty());
-
- auto it = partition.byopcodes.find(opcodename);
- if (it != partition.byopcodes.end())
- group = it->second;
- break;
- }
- case luci::PartitionTable::COMPLY::OPNAME:
- {
- opcodename = node->name();
- assert(!opcodename.empty());
-
- auto it = partition.byopnames.find(opcodename);
- if (it != partition.byopnames.end())
- group = it->second;
- break;
- }
-
- default:
- throw std::runtime_error("Unsupported partition.comply");
- }
-
- INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
- << std::endl;
+ auto group = group_from_partition(node, partition);
append(node, pgroups.get(), group, idx);
-#if 0
- auto pgroup = std::make_unique<luci::PGroup>();
- pgroup->group = group;
- pgroup->id = idx + 1;
-
- auto pnode = std::make_unique<luci::PNode>();
- pnode->node = node;
- pnode->group = group;
- pnode->pgroup = pgroup.get();
-
- pgroup->pnodes.push_back(std::move(pnode));
-
- // Set input of PGroup
- for (uint32_t in = 0; in < node->arity(); ++in)
- {
- auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
- // this input maybe CircleInput in source graph
- // --> not confident this is safe
- pgroup->inputs.push_back(input);
- }
- // Set output of PGroup: node itself or multiple virtual outputs
- // TODO support multiple virtual outputs
- pgroup->outputs.push_back(node);
-
- pgroups->node2group[node] = group;
- pgroups->id2pgroup[pgroup->id] = pgroup.get();
-
- pgroups->pgroups.push_back(std::move(pgroup));
-#endif
}
else
{
@@ -255,22 +257,6 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
}
}
- // handle for virtual nodes like multiple outputs
- // these nodes should follow group of the input
- for (uint32_t idx = 0; idx < nodes->size(); ++idx)
- {
- auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
-
- // for virtual nodes like CircleUnpackOut should follow it's input (owner)
- // or just set to default
- FindGroupToFollow query(partition, pgroups.get());
- const auto &group = node->accept(&query);
- if (not group.empty())
- {
- append(node, pgroups.get(), group, idx);
- }
- }
-
return std::move(pgroups);
}