summaryrefslogtreecommitdiff
path: root/runtimes/neurun/core/include/graph/Graph.h
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/neurun/core/include/graph/Graph.h')
-rw-r--r--runtimes/neurun/core/include/graph/Graph.h204
1 files changed, 204 insertions, 0 deletions
diff --git a/runtimes/neurun/core/include/graph/Graph.h b/runtimes/neurun/core/include/graph/Graph.h
new file mode 100644
index 000000000..b3e6d54ff
--- /dev/null
+++ b/runtimes/neurun/core/include/graph/Graph.h
@@ -0,0 +1,204 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NEURUN_GRAPH_GRAPH_H__
+#define __NEURUN_GRAPH_GRAPH_H__
+
+#include <functional>
+
+#include "model/Operation.h"
+#include "model/Model.h"
+#include "graph/LowerInfoMap.h"
+#include "model/Subgraph.h"
+#include "model/Subgraphs.h"
+
+namespace neurun
+{
+namespace graph
+{
+namespace operand
+{
+class LowerInfo;
+} // namespace operand
+} // namespace graph
+} // namespace neurun
+
+namespace neurun
+{
+namespace compiler
+{
+class Linear;
+} // namespace linear
+} // namespace neurun
+
+namespace neurun
+{
+namespace compiler
+{
+class BackendResolver;
+} // namespace compiler
+} // namespace neurun
+
+namespace neurun
+{
+namespace backend
+{
+namespace custom
+{
+class KernelRegistry;
+} // namespace neurun
+} // namespace backend
+} // namespace neurun
+
+namespace neurun
+{
+namespace graph
+{
+
+class Graph
+{
+private:
+ enum class Phase
+ {
+ BUILDING,
+ MODEL
+ };
+
+public:
+ template <bool is_const> class Iterator
+ {
+ public:
+ using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type;
+ using IndexRef = const model::OperationIndex &;
+ using NodeRef =
+ typename std::conditional<is_const, const model::Operation &, model::Operation &>::type;
+ using IterFn = std::function<void(IndexRef, NodeRef)>;
+
+ public:
+ virtual ~Iterator() = default;
+ virtual void iterate(GraphRef graph, const IterFn &fn) const = 0;
+ };
+
+ template <bool is_const = false> class DefaultIterator final : public Iterator<is_const>
+ {
+ public:
+ using GraphRef = typename Iterator<is_const>::GraphRef;
+ using IndexRef = typename Iterator<is_const>::IndexRef;
+ using NodeRef = typename Iterator<is_const>::NodeRef;
+ using IterFn = typename Iterator<is_const>::IterFn;
+
+ public:
+ void iterate(GraphRef graph, const IterFn &fn) const;
+ };
+ using DefaultConstIterator = DefaultIterator<true>;
+
+ template <bool is_const = false> class PostDfsIterator final : public Iterator<is_const>
+ {
+ public:
+ using GraphRef = typename Iterator<is_const>::GraphRef;
+ using IndexRef = typename Iterator<is_const>::IndexRef;
+ using NodeRef = typename Iterator<is_const>::NodeRef;
+ using IterFn = typename Iterator<is_const>::IterFn;
+
+ public:
+ void iterate(GraphRef graph, const IterFn &fn) const;
+ };
+ using PostDfsConstIterator = PostDfsIterator<true>;
+
+public:
+ Graph(void) = delete;
+ Graph(std::unique_ptr<model::Model> &&model);
+ ~Graph(void);
+
+ // Graph Building
+public:
+ model::OperandIndex addOperand(const model::Shape &shape, const model::TypeInfo &type);
+ model::OperationIndex addOperation(std::unique_ptr<model::Operation> &&node);
+ void setOperandValue(const model::OperandIndex &ind, std::unique_ptr<model::Data> &&data);
+ void addInput(const model::OperandIndex &ind);
+ void addOutput(const model::OperandIndex &ind);
+ void finishBuilding(void);
+ void lower(void);
+ void removeOperand(const model::OperandIndex &ind) { _model->operands.remove(ind); }
+ std::unique_ptr<compiler::Linear> linearize(void);
+ bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; }
+ std::shared_ptr<const model::Model> shareModel() { return _model; }
+ std::unique_ptr<graph::LowerInfoMap> releaseLowerInfo() { return std::move(_lower_info_map); }
+ std::unique_ptr<model::Subgraphs> releaseSubgraphs() { return std::move(_subgraphs); }
+
+private:
+ void initializeUseDef();
+
+ // Custom operations support
+public:
+ void bindKernelRegistry(const std::shared_ptr<backend::custom::KernelRegistry> &registry)
+ {
+ _kernel_registry = registry;
+ }
+
+ const std::shared_ptr<backend::custom::KernelRegistry> &getKernelRegistry() const
+ {
+ return _kernel_registry;
+ }
+
+private:
+ std::shared_ptr<backend::custom::KernelRegistry> _kernel_registry;
+
+ // Accessors
+public:
+ const model::OperandIndexSequence &getInputs() const { return _model->inputs; }
+ model::OperandIndexSequence &getInputs() { return _model->inputs; }
+ const model::OperandIndexSequence &getOutputs() const { return _model->outputs; }
+ model::OperandIndexSequence &getOutputs() { return _model->outputs; }
+ const model::Operands &operands() const { return _model->operands; }
+ model::Operands &operands() { return _model->operands; } // TODO Remove this non-const accessor
+ const model::Operations &operations() const { return _model->operations; }
+ model::Operations &operations() { return _model->operations; }
+ const compiler::BackendResolver *backend_resolver() const { return _backend_resolver.get(); }
+
+private:
+ Phase _phase{Phase::BUILDING};
+ std::shared_ptr<model::Model> _model;
+
+ // For LOWERED phase
+public:
+ const operation::LowerInfo *getLowerInfo(const model::SubgraphIndex &subg_index) const;
+ void setLowerInfo(const model::SubgraphIndex &subg_index,
+ std::unique_ptr<operation::LowerInfo> &&lower_info);
+ const operand::LowerInfo *getLowerInfo(const model::OperandIndex &index) const;
+ operand::LowerInfo *getLowerInfo(const model::OperandIndex &index);
+ void setLowerInfo(const model::OperandIndex &index,
+ std::unique_ptr<operand::LowerInfo> &&lower_info);
+ model::Subgraphs &subgraphs()
+ {
+ assert(_subgraphs);
+ return *_subgraphs;
+ }
+ const model::Subgraphs *subgraphs() const { return _subgraphs.get(); }
+ void setBackendResolver(std::unique_ptr<compiler::BackendResolver> &&br);
+ std::unique_ptr<compiler::BackendResolver> releaseBackendResolver();
+
+private:
+ std::unique_ptr<compiler::BackendResolver> _backend_resolver;
+ std::unique_ptr<LowerInfoMap> _lower_info_map;
+ // Pass(for Perm) can accept only graph so that Graph has Subgraphs as a member
+ std::unique_ptr<model::Subgraphs> _subgraphs;
+};
+
+} // namespace graph
+} // namespace neurun
+
+#endif // __NEURUN_GRAPH_GRAPH_H__