summaryrefslogtreecommitdiff
path: root/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp')
-rw-r--r--compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp134
1 files changed, 134 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp
new file mode 100644
index 000000000..ffa9ee5ca
--- /dev/null
+++ b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/MultiDialectShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include <loco/IR/Dialect.h>
+#include <loco/IR/CanonicalDialect.h>
+
+#include <gtest/gtest.h>
+
+#include <cassert>
+#include <vector>
+
+// mockup for MultiDialectShapeInferenceRule
+// Each class is dedicated for handling shape { D1, D2 } and D1, D2 are declared as a template
+namespace
+{
+
+template <uint32_t D1, uint32_t D2> class TestDialect final : public loco::Dialect
+{
+public:
+ static Dialect *get(void)
+ {
+ static TestDialect<D1, D2> d;
+ return &d;
+ }
+};
+
+template <uint32_t D1, uint32_t D2>
+struct TestOpNode final : public loco::FixedArity<1>::Mixin<loco::Node>,
+ public loco::NodeMixin<loco::NodeTrait::TensorShape>
+{
+ void input(Node *node) { at(0)->node(node); }
+ const loco::Dialect *dialect(void) const final { return TestDialect<D1, D2>::get(); }
+ uint32_t opnum(void) const final { return static_cast<uint32_t>(D1); /* not used */ }
+};
+
+template <uint32_t D1, uint32_t D2>
+struct TestShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+public:
+ bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<D1, D2>::get()); }
+
+ bool infer(const loco::Node *node, loco::NodeShape &node_shape) const final
+ {
+ assert(recognize(node->dialect()));
+ auto test_node = dynamic_cast<const TestOpNode<D1, D2> *>(node);
+ assert(test_node != nullptr);
+
+ loco::TensorShape ts;
+ {
+ ts.rank(2);
+ ts.dim(0) = D1;
+ ts.dim(1) = D2; // making shape : { D1, D2 }
+ }
+
+ node_shape.set(ts);
+
+ return true;
+ }
+};
+
+} // namespace
+
+TEST(MultiDialectShapeInferenceRuleTest, test1)
+{
+ // Create a simple network : Pull ------- t23<2,3> ------------ t45<4,5> ---------- Push
+ // TensorShape({2, 3}) TensorShape({4, 5})
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+ auto t23_node = g->nodes()->create<TestOpNode<2, 3>>();
+ auto t45_node = g->nodes()->create<TestOpNode<4, 5>>();
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ t23_node->input(pull_node);
+ t45_node->input(t23_node);
+ push_node->from(t45_node);
+
+ auto graph_input = g->inputs()->create();
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+
+ auto graph_output = g->outputs()->create();
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // initially they don't have shape info
+ ASSERT_FALSE(loco::shape_known(t23_node));
+ ASSERT_FALSE(loco::shape_known(t45_node));
+
+ // Run Type Inference
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ TestShapeInferenceRule<2, 3> t23_rule;
+ TestShapeInferenceRule<4, 5> t45_rule;
+
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(TestDialect<2, 3>::get(), &t23_rule)
+ .bind(TestDialect<4, 5>::get(), &t45_rule);
+
+ loco::apply(&rules).to(g.get());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(t23_node));
+ auto t23_shape = loco::shape_get(t23_node);
+ ASSERT_EQ(t23_shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(t23_shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(0), 2);
+ ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(1), 3);
+
+ ASSERT_TRUE(loco::shape_known(t45_node));
+ auto t45_shape = loco::shape_get(t45_node);
+ ASSERT_EQ(t45_shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(t45_shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(1), 5);
+}