diff options
author | 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com> | 2019-09-17 15:12:11 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-17 15:12:11 +0900 |
commit | 4731b226438d62ea61edfbe2e532d3cfc225e66f (patch) | |
tree | e79811ffc84b4c5c6666e804d9befe0858f81b2c | |
parent | 53b524f84c156e6aaf4cba7d6504817074b93bd2 (diff) | |
download | nnfw-4731b226438d62ea61edfbe2e532d3cfc225e66f.tar.gz nnfw-4731b226438d62ea61edfbe2e532d3cfc225e66f.tar.bz2 nnfw-4731b226438d62ea61edfbe2e532d3cfc225e66f.zip |
[loco] Extend CanonicalShapeInferenceRule with v2 API (#7468)
This commit extends CanonicalShapeinferenceRule with v2 API implementation.
Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
3 files changed, 113 insertions, 3 deletions
diff --git a/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h b/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h index 3ef6fee71..cd3bed405 100644 --- a/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h +++ b/compiler/loco/include/loco/Service/CanonicalShapeInferenceRule.h @@ -27,8 +27,10 @@ namespace loco */ struct CanonicalShapeInferenceRule final : public ShapeInferenceRule { + bool support(const API &ver) const final; bool recognize(const Dialect *) const final; bool infer(const Node *, NodeShape &) const final; + void infer(const Context *, const Node *, Sink *) const final; }; } // namespace loco diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index ef087d5f4..591b02450 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -520,12 +520,49 @@ struct Context final : public loco::ShapeInferenceRule::Context loco::NodeShape get(const loco::Node *node) const final { return loco::shape_get(node); } }; +class Sink final : public loco::ShapeInferenceRule::Sink +{ +public: + enum Status + { + Unknown, + Okay, + Fail, + }; + +public: + const Status &status(void) const { return _status; } + const loco::NodeShape &shape(void) const { return _shape; } + +public: + void okay(const loco::NodeShape &shape) final + { + _status = Okay; + _shape = shape; + } + + void fail(void) final + { + // Notify failrue + _status = Fail; + } + +private: + Status _status = Unknown; + loco::NodeShape _shape; +}; + } // namespace compat } // namespace namespace loco { +bool CanonicalShapeInferenceRule::support(const API &api) const +{ + return api == API::V1 or api == API::V2; +} + bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const { return CanonicalDialect::get() == d; @@ -534,14 +571,30 @@ bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const bool CanonicalShapeInferenceRule::infer(const Node *node, NodeShape &shape) const { ::compat::Context ctx; + ::compat::Sink sink; + + infer(&ctx, node, &sink); + + assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail); + + if (sink.status() == ::compat::Sink::Fail) + { + return false; + } + shape = sink.shape(); + return true; +} + +void CanonicalShapeInferenceRule::infer(const Context *ctx, const Node *node, Sink *sink) const +{ assert(node->dialect() == loco::CanonicalDialect::get()); assert(dynamic_cast<const loco::CanonicalNode *>(node) != nullptr); - ForwardShapeInferenceAlgorithm alg{&ctx}; - shape = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg); + ForwardShapeInferenceAlgorithm alg{ctx}; + auto shape = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg); - return true; + sink->okay(shape); } } // namespace loco diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index d8022f5ca..62fa6786a 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -281,3 +281,58 @@ TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast) ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4); ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2); } + +namespace +{ + +struct MockContext final : public loco::ShapeInferenceRule::Context +{ + bool known(const loco::Node *node) const final { return _content.find(node) != _content.end(); } + loco::NodeShape get(const loco::Node *node) const final { return _content.at(node); } + + std::map<const loco::Node *, loco::NodeShape> _content; +}; + +struct MockSink final : public loco::ShapeInferenceRule::Sink +{ + void okay(const loco::NodeShape &res) final { shape = res; } + void fail(void) final { return; } + + loco::NodeShape shape; +}; + +} // namespace + +TEST(CanonicalShapeInferenceRuleTest, infer_v2) +{ + auto g = loco::make_graph(); + + // Create an incomplete graph + auto relu_1 = g->nodes()->create<loco::ReLU>(); + auto relu_2 = g->nodes()->create<loco::ReLU>(); + + relu_2->input(relu_1); + + // Set up Context + MockContext ctx; + + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + tensor_shape.dim(0) = 4; + tensor_shape.dim(1) = 5; + + ctx._content[relu_1] = tensor_shape; + + // Create a Sink + MockSink sink; + + loco::CanonicalShapeInferenceRule rule; + + rule.infer(&ctx, relu_2, &sink); + + ASSERT_EQ(sink.shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(1), 5); +} |