diff options
Diffstat (limited to 'runtimes/neurun/src/graph/Graph.h')
-rw-r--r-- | runtimes/neurun/src/graph/Graph.h | 79 |
1 files changed, 51 insertions, 28 deletions
diff --git a/runtimes/neurun/src/graph/Graph.h b/runtimes/neurun/src/graph/Graph.h index dd1489a93..afcfdce12 100644 --- a/runtimes/neurun/src/graph/Graph.h +++ b/runtimes/neurun/src/graph/Graph.h @@ -19,10 +19,8 @@ #include <functional> -#include "graph/operation/Node.h" -#include "graph/operation/Set.h" -#include "graph/operand/IndexSet.h" -#include "graph/operand/Set.h" +#include "model/operation/Node.h" +#include "graph/Model.h" namespace neurun { @@ -34,6 +32,14 @@ class Linear; namespace neurun { +namespace compiler +{ +class BackendResolver; +} // namespace compiler +} // namespace neurun + +namespace neurun +{ namespace graph { @@ -43,9 +49,7 @@ private: enum class Phase { BUILDING, - MODEL, - LOWERED, - LINEARIZED // Everything is moved to Linear object so this Graph object is no longer effective + MODEL }; public: @@ -53,9 +57,10 @@ public: { public: using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type; - using NodeRef = - typename std::conditional<is_const, const operation::Node &, operation::Node &>::type; - using IterFn = std::function<void(NodeRef)>; + using IndexRef = const model::operation::Index &; + using NodeRef = typename std::conditional<is_const, const model::operation::Node &, + model::operation::Node &>::type; + using IterFn = std::function<void(IndexRef, NodeRef)>; public: virtual ~Iterator() = default; @@ -66,6 +71,7 @@ public: { public: using GraphRef = typename Iterator<is_const>::GraphRef; + using IndexRef = typename Iterator<is_const>::IndexRef; using NodeRef = typename Iterator<is_const>::NodeRef; using IterFn = typename Iterator<is_const>::IterFn; @@ -78,6 +84,7 @@ public: { public: using GraphRef = typename Iterator<is_const>::GraphRef; + using IndexRef = typename Iterator<is_const>::IndexRef; using NodeRef = typename Iterator<is_const>::NodeRef; using IterFn = typename Iterator<is_const>::IterFn; @@ -87,20 +94,21 @@ public: using PostDfsConstIterator = PostDfsIterator<true>; public: - Graph(void) = default; + Graph(void); + ~Graph(void); // Graph Building public: - operand::Index addOperand(const operand::Shape &shape, const operand::TypeInfo &type); - operation::Index addOperation(std::unique_ptr<operation::Node> &&node); - operation::Index insertOperation(const operand::Index &prev_operand_index, - const operation::Index &next_operation_index, - std::unique_ptr<operation::Node> &&node); - void setOperandValue(const operand::Index &ind, std::unique_ptr<operand::Data> &&data); - void addInput(const operand::Index &ind); - void addOutput(const operand::Index &ind); + model::operand::Index addOperand(const model::operand::Shape &shape, + const model::operand::TypeInfo &type); + model::operation::Index addOperation(std::unique_ptr<model::operation::Node> &&node); + void setOperandValue(const model::operand::Index &ind, + std::unique_ptr<model::operand::Data> &&data); + void addInput(const model::operand::Index &ind); + void addOutput(const model::operand::Index &ind); void finishBuilding(void); void lower(void); + void removeOperand(const model::operand::Index &ind) { _model->operands.remove(ind); } std::unique_ptr<linear::Linear> linearize(void); bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; } @@ -109,18 +117,33 @@ private: // Accessors public: - const operand::IndexSet &getInputs() const { return _inputs; } - const operand::IndexSet &getOutputs() const { return _outputs; } - const operand::Set &operands() const { return _operands; } - operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor - const operation::Set &operations() const { return _operations; } + const model::operand::IndexSet &getInputs() const { return _model->inputs; } + model::operand::IndexSet &getInputs() { return _model->inputs; } + const model::operand::IndexSet &getOutputs() const { return _model->outputs; } + model::operand::IndexSet &getOutputs() { return _model->outputs; } + const model::operand::Set &operands() const { return _model->operands; } + model::operand::Set &operands() + { + return _model->operands; + } // TODO Remove this non-const accessor + const model::operation::Set &operations() const { return _model->operations; } + model::operation::Set &operations() { return _model->operations; } + const compiler::BackendResolver *backend_resolver() const { return _backend_resolver.get(); } private: Phase _phase{Phase::BUILDING}; - operation::Set _operations; - operand::Set _operands; - operand::IndexSet _inputs; - operand::IndexSet _outputs; + std::unique_ptr<Model> _model{new Model}; + + // For LOWERED phase +public: + const operation::LowerInfo *getLowerInfo(const model::operation::Index &index) const; + void setLowerInfo(const model::operation::Index &index, + std::unique_ptr<operation::LowerInfo> &&lower_info); + +private: + std::unique_ptr<compiler::BackendResolver> _backend_resolver; + std::unordered_map<model::operation::Index, std::unique_ptr<operation::LowerInfo>> + _operation_lower_info; }; } // namespace graph |