summaryrefslogtreecommitdiff
path: root/runtime/neurun/core/include/ir/Graph.h
blob: 5105c3a42a40ee8ce29a456f0475aa12ef3100f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NEURUN_IR_GRAPH_H__
#define __NEURUN_IR_GRAPH_H__

#include <cassert>
#include <functional>
#include <memory>
#include <utility>

#include "ir/Operands.h"
#include "ir/Operations.h"
#include "ir/LowerInfoMap.h"
#include "ir/OpSequence.h"
#include "ir/Subgraphs.h"

// Forward declarations: Graph below only refers to these types through (smart) pointers,
// so their full definitions are not required by this header.
namespace neurun
{
namespace compiler
{
class BackendResolver;
} // namespace compiler

namespace backend
{
namespace custom
{
class IKernelBuilder;
} // namespace custom
} // namespace backend
} // namespace neurun

namespace neurun
{
namespace ir
{

class Graph
{
private:
  enum class Phase
  {
    BUILDING,
    MODEL
  };

public:
  Graph(void);
  ~Graph(void);

  // Graph Building
public:
  OperandIndex addOperand(const Shape &shape, const TypeInfo &type);
  OperationIndex addOperation(std::unique_ptr<Operation> &&node);
  void setOperandValue(const OperandIndex &ind, std::unique_ptr<Data> &&data);
  void addInput(const OperandIndex &ind);
  void addOutput(const OperandIndex &ind);
  void finishBuilding(void);
  void lower(void);
  void removeOperand(const OperandIndex &ind) { _operands.remove(ind); }
  bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; }

private:
  void initializeUseDef();

  // Custom operations support
public:
  void
  bindKernelBuilder(const std::shared_ptr<neurun::backend::custom::IKernelBuilder> &kernel_builder)
  {
    _kernel_builder = kernel_builder;
  }

  const std::shared_ptr<backend::custom::IKernelBuilder> &getKernelBuilder() const
  {
    return _kernel_builder;
  }

private:
  std::shared_ptr<backend::custom::IKernelBuilder> _kernel_builder;

  // Accessors
public:
  const OperandIndexSequence &getInputs() const { return _inputs; }
  OperandIndexSequence &getInputs() { return _inputs; }
  const OperandIndexSequence &getOutputs() const { return _outputs; }
  OperandIndexSequence &getOutputs() { return _outputs; }
  const Operands &operands() const { return _operands; }
  Operands &operands() { return _operands; } // TODO Remove this non-const accessor
  const Operations &operations() const { return _operations; }
  Operations &operations() { return _operations; }
  const compiler::BackendResolver *backend_resolver() const { return _backend_resolver.get(); }

private:
  Phase _phase{Phase::BUILDING};
  Operations _operations;
  Operands _operands;
  OperandIndexSequence _inputs;
  OperandIndexSequence _outputs;

  // For LOWERED phase
public:
  const LowerInfoMap *getLowerInfo() const { return _lower_info_map.get(); }
  const operation::LowerInfo *getLowerInfo(const SubgraphIndex &subg_index) const;
  void setLowerInfo(const SubgraphIndex &subg_index,
                    std::unique_ptr<operation::LowerInfo> &&lower_info);
  void removeLowerInfo(const SubgraphIndex &subg_index);
  const operand::LowerInfo *getLowerInfo(const OperandIndex &index) const;
  operand::LowerInfo *getLowerInfo(const OperandIndex &index);
  void setLowerInfo(const OperandIndex &index, std::unique_ptr<operand::LowerInfo> &&lower_info);
  void removeLowerInfo(const OperandIndex &index);
  Subgraphs &subgraphs()
  {
    assert(_op_seqs);
    return *_op_seqs;
  }
  const Subgraphs *subgraphs() const { return _op_seqs.get(); }
  void setBackendResolver(std::unique_ptr<compiler::BackendResolver> &&br);

private:
  void makeSubgraphs(OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info);
  void
  manipulateLowerInfo(OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info);
  void dumpLowerInfo();
  bool mergeable(const SubgraphIndex &subg_index, const OperationIndex &node_index, Layout layout);
  SubgraphIndex appendFreshSingleOpSubgraph(const OperationIndex &node_index, const Operation &node,
                                            Layout layout);

private:
  std::unique_ptr<compiler::BackendResolver> _backend_resolver;
  std::unique_ptr<LowerInfoMap> _lower_info_map;
  // Pass(for Perm) can accept only graph so that Graph has Subgraphs as a member
  std::unique_ptr<Subgraphs> _op_seqs;
};

} // namespace ir
} // namespace neurun

#endif // __NEURUN_IR_GRAPH_H__