summaryrefslogtreecommitdiff
path: root/runtimes/neurun/src/graph/Graph.h
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/neurun/src/graph/Graph.h')
-rw-r--r--runtimes/neurun/src/graph/Graph.h79
1 files changed, 51 insertions, 28 deletions
diff --git a/runtimes/neurun/src/graph/Graph.h b/runtimes/neurun/src/graph/Graph.h
index dd1489a93..afcfdce12 100644
--- a/runtimes/neurun/src/graph/Graph.h
+++ b/runtimes/neurun/src/graph/Graph.h
@@ -19,10 +19,8 @@
#include <functional>
-#include "graph/operation/Node.h"
-#include "graph/operation/Set.h"
-#include "graph/operand/IndexSet.h"
-#include "graph/operand/Set.h"
+#include "model/operation/Node.h"
+#include "graph/Model.h"
namespace neurun
{
@@ -34,6 +32,14 @@ class Linear;
namespace neurun
{
+namespace compiler
+{
+class BackendResolver;
+} // namespace compiler
+} // namespace neurun
+
+namespace neurun
+{
namespace graph
{
@@ -43,9 +49,7 @@ private:
enum class Phase
{
BUILDING,
- MODEL,
- LOWERED,
- LINEARIZED // Everything is moved to Linear object so this Graph object is no longer effective
+ MODEL
};
public:
@@ -53,9 +57,10 @@ public:
{
public:
using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type;
- using NodeRef =
- typename std::conditional<is_const, const operation::Node &, operation::Node &>::type;
- using IterFn = std::function<void(NodeRef)>;
+ using IndexRef = const model::operation::Index &;
+ using NodeRef = typename std::conditional<is_const, const model::operation::Node &,
+ model::operation::Node &>::type;
+ using IterFn = std::function<void(IndexRef, NodeRef)>;
public:
virtual ~Iterator() = default;
@@ -66,6 +71,7 @@ public:
{
public:
using GraphRef = typename Iterator<is_const>::GraphRef;
+ using IndexRef = typename Iterator<is_const>::IndexRef;
using NodeRef = typename Iterator<is_const>::NodeRef;
using IterFn = typename Iterator<is_const>::IterFn;
@@ -78,6 +84,7 @@ public:
{
public:
using GraphRef = typename Iterator<is_const>::GraphRef;
+ using IndexRef = typename Iterator<is_const>::IndexRef;
using NodeRef = typename Iterator<is_const>::NodeRef;
using IterFn = typename Iterator<is_const>::IterFn;
@@ -87,20 +94,21 @@ public:
using PostDfsConstIterator = PostDfsIterator<true>;
public:
- Graph(void) = default;
+ Graph(void);
+ ~Graph(void);
// Graph Building
public:
- operand::Index addOperand(const operand::Shape &shape, const operand::TypeInfo &type);
- operation::Index addOperation(std::unique_ptr<operation::Node> &&node);
- operation::Index insertOperation(const operand::Index &prev_operand_index,
- const operation::Index &next_operation_index,
- std::unique_ptr<operation::Node> &&node);
- void setOperandValue(const operand::Index &ind, std::unique_ptr<operand::Data> &&data);
- void addInput(const operand::Index &ind);
- void addOutput(const operand::Index &ind);
+ model::operand::Index addOperand(const model::operand::Shape &shape,
+ const model::operand::TypeInfo &type);
+ model::operation::Index addOperation(std::unique_ptr<model::operation::Node> &&node);
+ void setOperandValue(const model::operand::Index &ind,
+ std::unique_ptr<model::operand::Data> &&data);
+ void addInput(const model::operand::Index &ind);
+ void addOutput(const model::operand::Index &ind);
void finishBuilding(void);
void lower(void);
+ void removeOperand(const model::operand::Index &ind) { _model->operands.remove(ind); }
std::unique_ptr<linear::Linear> linearize(void);
bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; }
@@ -109,18 +117,33 @@ private:
// Accessors
public:
- const operand::IndexSet &getInputs() const { return _inputs; }
- const operand::IndexSet &getOutputs() const { return _outputs; }
- const operand::Set &operands() const { return _operands; }
- operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor
- const operation::Set &operations() const { return _operations; }
+ const model::operand::IndexSet &getInputs() const { return _model->inputs; }
+ model::operand::IndexSet &getInputs() { return _model->inputs; }
+ const model::operand::IndexSet &getOutputs() const { return _model->outputs; }
+ model::operand::IndexSet &getOutputs() { return _model->outputs; }
+ const model::operand::Set &operands() const { return _model->operands; }
+ model::operand::Set &operands()
+ {
+ return _model->operands;
+ } // TODO Remove this non-const accessor
+ const model::operation::Set &operations() const { return _model->operations; }
+ model::operation::Set &operations() { return _model->operations; }
+ const compiler::BackendResolver *backend_resolver() const { return _backend_resolver.get(); }
private:
Phase _phase{Phase::BUILDING};
- operation::Set _operations;
- operand::Set _operands;
- operand::IndexSet _inputs;
- operand::IndexSet _outputs;
+ std::unique_ptr<Model> _model{new Model};
+
+ // For LOWERED phase
+public:
+ const operation::LowerInfo *getLowerInfo(const model::operation::Index &index) const;
+ void setLowerInfo(const model::operation::Index &index,
+ std::unique_ptr<operation::LowerInfo> &&lower_info);
+
+private:
+ std::unique_ptr<compiler::BackendResolver> _backend_resolver;
+ std::unordered_map<model::operation::Index, std::unique_ptr<operation::LowerInfo>>
+ _operation_lower_info;
};
} // namespace graph