summaryrefslogtreecommitdiff
path: root/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp')
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp55
1 files changed, 55 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
index d8022f5ca..62fa6786a 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
@@ -281,3 +281,58 @@ TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast)
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4);
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
}
+
+namespace
+{
+
+struct MockContext final : public loco::ShapeInferenceRule::Context
+{
+ bool known(const loco::Node *node) const final { return _content.find(node) != _content.end(); }
+ loco::NodeShape get(const loco::Node *node) const final { return _content.at(node); }
+
+ std::map<const loco::Node *, loco::NodeShape> _content;
+};
+
+struct MockSink final : public loco::ShapeInferenceRule::Sink
+{
+ void okay(const loco::NodeShape &res) final { shape = res; }
+ void fail(void) final { return; }
+
+ loco::NodeShape shape;
+};
+
+} // namespace
+
+TEST(CanonicalShapeInferenceRuleTest, infer_v2)
+{
+ auto g = loco::make_graph();
+
+ // Create an incomplete graph
+ auto relu_1 = g->nodes()->create<loco::ReLU>();
+ auto relu_2 = g->nodes()->create<loco::ReLU>();
+
+ relu_2->input(relu_1);
+
+ // Set up Context
+ MockContext ctx;
+
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+ tensor_shape.dim(0) = 4;
+ tensor_shape.dim(1) = 5;
+
+ ctx._content[relu_1] = tensor_shape;
+
+ // Create a Sink
+ MockSink sink;
+
+ loco::CanonicalShapeInferenceRule rule;
+
+ rule.infer(&ctx, relu_2, &sink);
+
+ ASSERT_EQ(sink.shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(1), 5);
+}