summaryrefslogtreecommitdiff
path: root/compiler/loco/src/Service/ShapeInference.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/loco/src/Service/ShapeInference.test.cpp')
-rw-r--r--compiler/loco/src/Service/ShapeInference.test.cpp87
1 files changed, 87 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/ShapeInference.test.cpp b/compiler/loco/src/Service/ShapeInference.test.cpp
new file mode 100644
index 000000000..e10b98844
--- /dev/null
+++ b/compiler/loco/src/Service/ShapeInference.test.cpp
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/ShapeInference.h"
+#include "GraphTestcase.h"
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+// This test validates whether framework works as expected.
+TEST(ShapeInferenceTest, framework)
+{
+ // Mock-up Shape Inference Rule
+ struct SampleShapeInferenceRule final : public loco::ShapeInferenceRule
+ {
+ public:
+ SampleShapeInferenceRule(std::vector<const loco::Node *> *nodes) : _nodes{nodes}
+ {
+ // DO NOTHING
+ }
+
+ public:
+ // Accept all the dialects
+ bool recognize(const loco::Dialect *) const final { return true; }
+
+ bool infer(const loco::Node *node, loco::NodeShape &shape) const final
+ {
+ // Record the order of inference
+ _nodes->emplace_back(node);
+
+ if (_nodes->size() != 1)
+ {
+ return false;
+ }
+
+ // Set the first node as Tensor<1>
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(1);
+ tensor_shape.dim(0) = 4;
+
+ shape.set(tensor_shape);
+
+ return true;
+ }
+
+ private:
+ std::vector<const loco::Node *> *_nodes;
+ };
+
+ GraphTestcase<GraphCode::Identity> testcase;
+
+ std::vector<const loco::Node *> nodes;
+
+ SampleShapeInferenceRule rule{&nodes};
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Framework SHOULD visit all the nodes
+ ASSERT_EQ(nodes.size(), 2);
+ // Framework SHOULD visit "pull" before "push"
+ ASSERT_EQ(nodes.at(0), testcase.pull_node);
+ ASSERT_EQ(nodes.at(1), testcase.push_node);
+
+ // Framework SHOULD make an annotation if "rule" returns TRUE
+ ASSERT_TRUE(loco::shape_known(testcase.pull_node));
+ ASSERT_EQ(loco::shape_get(testcase.pull_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.pull_node).as<loco::TensorShape>().rank(), 1);
+ ASSERT_EQ(loco::shape_get(testcase.pull_node).as<loco::TensorShape>().dim(0), 4);
+
+ // Framework SHOULD NOT make any annotation if "rule" returns FALSE
+ ASSERT_FALSE(loco::shape_known(testcase.push_node));
+}