summaryrefslogtreecommitdiff
path: root/compiler/loco/include/loco/Service/TypeInference.h
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/loco/include/loco/Service/TypeInference.h')
-rw-r--r--compiler/loco/include/loco/Service/TypeInference.h114
1 files changed, 114 insertions, 0 deletions
diff --git a/compiler/loco/include/loco/Service/TypeInference.h b/compiler/loco/include/loco/Service/TypeInference.h
new file mode 100644
index 000000000..c2ce1a4c7
--- /dev/null
+++ b/compiler/loco/include/loco/Service/TypeInference.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCO_SERVICE_TYPE_INFERENCE_H__
+#define __LOCO_SERVICE_TYPE_INFERENCE_H__
+
+#include "loco/IR/DataType.h"
+
+#include "loco/IR/Node.h"
+#include "loco/IR/Dialect.h"
+#include "loco/IR/Graph.h"
+
+#include <map>
+
+/**
+ * @file This file implements dialect-agnostic type inference framework.
+ *
+ * HOW TO USE:
+ *
+ * loco::Graph *g = ...;
+ * loco::TypeInferenceRule *rule = ...;
+ * loco::apply(rule).to(g);
+ *
+ */
+namespace loco
+{
+
+struct TypeInferenceRule
+{
+ virtual ~TypeInferenceRule() = default;
+
+ /// @brief Return true if this rule recognizes a given dialect
+ virtual bool recognize(const Dialect *) const = 0;
+
+ /**
+ * Framework guarantees the followings:
+ *
+ * 1. Framework tries to infer the data type of each node only after the data type of all of
+ * its valid (= non-nullptr) argument nodes is inferred.
+ * 2. The result of preceding "infer" is accessible through below dtype_get method.
+ * - This holds only when preceding "infer" returns true.
+ */
+ virtual bool infer(const Node *, DataType &) const = 0;
+};
+
+/**
+ * @brief Type Inference Rule for Canonical Dialect
+ */
+struct CanonicalTypeInferenceRule final : public TypeInferenceRule
+{
+ bool recognize(const Dialect *) const final;
+ bool infer(const Node *, DataType &) const final;
+};
+
+/**
+ * @brief Type Inference Rule for multiple dialects
+ */
+class MultiDialectTypeInferenceRule final : public TypeInferenceRule
+{
+public:
+ bool recognize(const Dialect *) const final;
+ bool infer(const Node *, DataType &) const final;
+
+ /// @brief Bind a specific rule to a Dialect
+ MultiDialectTypeInferenceRule &bind(const Dialect *d, const TypeInferenceRule *rule);
+
+private:
+ std::map<const Dialect *, const TypeInferenceRule *> _rules;
+};
+
+class TypeInferenceSession
+{
+public:
+ TypeInferenceSession(const TypeInferenceRule *rule) : _rule{rule}
+ {
+ // DO NOTHING
+ }
+
+public:
+ bool to(Graph *g) const;
+
+private:
+ const TypeInferenceRule *_rule;
+};
+
+inline TypeInferenceSession apply(TypeInferenceRule *r) { return TypeInferenceSession{r}; }
+
+struct TypeInference
+{
+ static bool known(const Node *);
+ static DataType get(const Node *);
+ static void erase(Node *);
+};
+
+inline bool dtype_known(const Node *node) { return TypeInference::known(node); }
+inline DataType dtype_get(const Node *node) { return TypeInference::get(node); }
+inline void dtype_erase(Node *node) { TypeInference::erase(node); }
+
+} // namespace loco
+
+#endif // __LOCO_SERVICE_TYPE_INFERENCE_H__