summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>2019-09-17 15:12:11 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-17 15:12:11 +0900
commit4731b226438d62ea61edfbe2e532d3cfc225e66f (patch)
treee79811ffc84b4c5c6666e804d9befe0858f81b2c
parent53b524f84c156e6aaf4cba7d6504817074b93bd2 (diff)
downloadnnfw-4731b226438d62ea61edfbe2e532d3cfc225e66f.tar.gz
nnfw-4731b226438d62ea61edfbe2e532d3cfc225e66f.tar.bz2
nnfw-4731b226438d62ea61edfbe2e532d3cfc225e66f.zip
[loco] Extend CanonicalShapeInferenceRule with v2 API (#7468)
This commit extends CanonicalShapeinferenceRule with v2 API implementation. Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
-rw-r--r--compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h2
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp59
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp55
3 files changed, 113 insertions, 3 deletions
diff --git a/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h b/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h
index 3ef6fee71..cd3bed405 100644
--- a/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h
+++ b/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h
@@ -27,8 +27,10 @@ namespace loco
*/
struct CanonicalShapeInferenceRule final : public ShapeInferenceRule
{
+ bool support(const API &ver) const final;
bool recognize(const Dialect *) const final;
bool infer(const Node *, NodeShape &) const final;
+ void infer(const Context *, const Node *, Sink *) const final;
};
} // namespace loco
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
index ef087d5f4..591b02450 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
@@ -520,12 +520,49 @@ struct Context final : public loco::ShapeInferenceRule::Context
loco::NodeShape get(const loco::Node *node) const final { return loco::shape_get(node); }
};
+class Sink final : public loco::ShapeInferenceRule::Sink
+{
+public:
+ enum Status
+ {
+ Unknown,
+ Okay,
+ Fail,
+ };
+
+public:
+ const Status &status(void) const { return _status; }
+ const loco::NodeShape &shape(void) const { return _shape; }
+
+public:
+ void okay(const loco::NodeShape &shape) final
+ {
+ _status = Okay;
+ _shape = shape;
+ }
+
+ void fail(void) final
+ {
+ // Notify failrue
+ _status = Fail;
+ }
+
+private:
+ Status _status = Unknown;
+ loco::NodeShape _shape;
+};
+
} // namespace compat
} // namespace
namespace loco
{
+bool CanonicalShapeInferenceRule::support(const API &api) const
+{
+ return api == API::V1 or api == API::V2;
+}
+
bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const
{
return CanonicalDialect::get() == d;
@@ -534,14 +571,30 @@ bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const
bool CanonicalShapeInferenceRule::infer(const Node *node, NodeShape &shape) const
{
::compat::Context ctx;
+ ::compat::Sink sink;
+
+ infer(&ctx, node, &sink);
+
+ assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail);
+
+ if (sink.status() == ::compat::Sink::Fail)
+ {
+ return false;
+ }
+ shape = sink.shape();
+ return true;
+}
+
+void CanonicalShapeInferenceRule::infer(const Context *ctx, const Node *node, Sink *sink) const
+{
assert(node->dialect() == loco::CanonicalDialect::get());
assert(dynamic_cast<const loco::CanonicalNode *>(node) != nullptr);
- ForwardShapeInferenceAlgorithm alg{&ctx};
- shape = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg);
+ ForwardShapeInferenceAlgorithm alg{ctx};
+ auto shape = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg);
- return true;
+ sink->okay(shape);
}
} // namespace loco
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
index d8022f5ca..62fa6786a 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
@@ -281,3 +281,58 @@ TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast)
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4);
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
}
+
+namespace
+{
+
+struct MockContext final : public loco::ShapeInferenceRule::Context
+{
+ bool known(const loco::Node *node) const final { return _content.find(node) != _content.end(); }
+ loco::NodeShape get(const loco::Node *node) const final { return _content.at(node); }
+
+ std::map<const loco::Node *, loco::NodeShape> _content;
+};
+
+struct MockSink final : public loco::ShapeInferenceRule::Sink
+{
+ void okay(const loco::NodeShape &res) final { shape = res; }
+ void fail(void) final { return; }
+
+ loco::NodeShape shape;
+};
+
+} // namespace
+
+TEST(CanonicalShapeInferenceRuleTest, infer_v2)
+{
+ auto g = loco::make_graph();
+
+ // Create an incomplete graph
+ auto relu_1 = g->nodes()->create<loco::ReLU>();
+ auto relu_2 = g->nodes()->create<loco::ReLU>();
+
+ relu_2->input(relu_1);
+
+ // Set up Context
+ MockContext ctx;
+
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+ tensor_shape.dim(0) = 4;
+ tensor_shape.dim(1) = 5;
+
+ ctx._content[relu_1] = tensor_shape;
+
+ // Create a Sink
+ MockSink sink;
+
+ loco::CanonicalShapeInferenceRule rule;
+
+ rule.infer(&ctx, relu_2, &sink);
+
+ ASSERT_EQ(sink.shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(1), 5);
+}