diff options
Diffstat (limited to 'compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp')
-rw-r--r-- | compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index d8022f5ca..62fa6786a 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -281,3 +281,58 @@ TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast) ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4); ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2); } + +namespace +{ + +struct MockContext final : public loco::ShapeInferenceRule::Context +{ + bool known(const loco::Node *node) const final { return _content.find(node) != _content.end(); } + loco::NodeShape get(const loco::Node *node) const final { return _content.at(node); } + + std::map<const loco::Node *, loco::NodeShape> _content; +}; + +struct MockSink final : public loco::ShapeInferenceRule::Sink +{ + void okay(const loco::NodeShape &res) final { shape = res; } + void fail(void) final { return; } + + loco::NodeShape shape; +}; + +} // namespace + +TEST(CanonicalShapeInferenceRuleTest, infer_v2) +{ + auto g = loco::make_graph(); + + // Create an incomplete graph + auto relu_1 = g->nodes()->create<loco::ReLU>(); + auto relu_2 = g->nodes()->create<loco::ReLU>(); + + relu_2->input(relu_1); + + // Set up Context + MockContext ctx; + + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + tensor_shape.dim(0) = 4; + tensor_shape.dim(1) = 5; + + ctx._content[relu_1] = tensor_shape; + + // Create a Sink + MockSink sink; + + loco::CanonicalShapeInferenceRule rule; + + rule.infer(&ctx, relu_2, &sink); + + ASSERT_EQ(sink.shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(1), 5); +} |