diff options
Diffstat (limited to 'runtimes/neurun/core/include/graph/Graph.h')
-rw-r--r-- | runtimes/neurun/core/include/graph/Graph.h | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/runtimes/neurun/core/include/graph/Graph.h b/runtimes/neurun/core/include/graph/Graph.h new file mode 100644 index 000000000..b3e6d54ff --- /dev/null +++ b/runtimes/neurun/core/include/graph/Graph.h @@ -0,0 +1,204 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NEURUN_GRAPH_GRAPH_H__ +#define __NEURUN_GRAPH_GRAPH_H__ + +#include <functional> + +#include "model/Operation.h" +#include "model/Model.h" +#include "graph/LowerInfoMap.h" +#include "model/Subgraph.h" +#include "model/Subgraphs.h" + +namespace neurun +{ +namespace graph +{ +namespace operand +{ +class LowerInfo; +} // namespace operand +} // namespace graph +} // namespace neurun + +namespace neurun +{ +namespace compiler +{ +class Linear; +} // namespace linear +} // namespace neurun + +namespace neurun +{ +namespace compiler +{ +class BackendResolver; +} // namespace compiler +} // namespace neurun + +namespace neurun +{ +namespace backend +{ +namespace custom +{ +class KernelRegistry; +} // namespace neurun +} // namespace backend +} // namespace neurun + +namespace neurun +{ +namespace graph +{ + +class Graph +{ +private: + enum class Phase + { + BUILDING, + MODEL + }; + +public: + template <bool is_const> class Iterator + { + public: + using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type; + using IndexRef = const model::OperationIndex &; + using NodeRef = + typename std::conditional<is_const, const model::Operation &, model::Operation &>::type; + using IterFn = std::function<void(IndexRef, NodeRef)>; + + public: + virtual ~Iterator() = default; + virtual void iterate(GraphRef graph, const IterFn &fn) const = 0; + }; + + template <bool is_const = false> class DefaultIterator final : public Iterator<is_const> + { + public: + using GraphRef = typename Iterator<is_const>::GraphRef; + using IndexRef = typename Iterator<is_const>::IndexRef; + using NodeRef = typename Iterator<is_const>::NodeRef; + using IterFn = typename Iterator<is_const>::IterFn; + + public: + void iterate(GraphRef graph, const IterFn &fn) const; + }; + using DefaultConstIterator = DefaultIterator<true>; + + template <bool is_const = false> class PostDfsIterator final : public Iterator<is_const> + { + public: + using GraphRef = typename Iterator<is_const>::GraphRef; + using IndexRef = typename Iterator<is_const>::IndexRef; + using NodeRef = typename Iterator<is_const>::NodeRef; + using IterFn = typename Iterator<is_const>::IterFn; + + public: + void iterate(GraphRef graph, const IterFn &fn) const; + }; + using PostDfsConstIterator = PostDfsIterator<true>; + +public: + Graph(void) = delete; + Graph(std::unique_ptr<model::Model> &&model); + ~Graph(void); + + // Graph Building +public: + model::OperandIndex addOperand(const model::Shape &shape, const model::TypeInfo &type); + model::OperationIndex addOperation(std::unique_ptr<model::Operation> &&node); + void setOperandValue(const model::OperandIndex &ind, std::unique_ptr<model::Data> &&data); + void addInput(const model::OperandIndex &ind); + void addOutput(const model::OperandIndex &ind); + void finishBuilding(void); + void lower(void); + void removeOperand(const model::OperandIndex &ind) { _model->operands.remove(ind); } + std::unique_ptr<compiler::Linear> linearize(void); + bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; } + std::shared_ptr<const model::Model> shareModel() { return _model; } + std::unique_ptr<graph::LowerInfoMap> releaseLowerInfo() { return std::move(_lower_info_map); } + std::unique_ptr<model::Subgraphs> releaseSubgraphs() { return std::move(_subgraphs); } + +private: + void initializeUseDef(); + + // Custom operations support +public: + void bindKernelRegistry(const std::shared_ptr<backend::custom::KernelRegistry> ®istry) + { + _kernel_registry = registry; + } + + const std::shared_ptr<backend::custom::KernelRegistry> &getKernelRegistry() const + { + return _kernel_registry; + } + +private: + std::shared_ptr<backend::custom::KernelRegistry> _kernel_registry; + + // Accessors +public: + const model::OperandIndexSequence &getInputs() const { return _model->inputs; } + model::OperandIndexSequence &getInputs() { return _model->inputs; } + const model::OperandIndexSequence &getOutputs() const { return _model->outputs; } + model::OperandIndexSequence &getOutputs() { return _model->outputs; } + const model::Operands &operands() const { return _model->operands; } + model::Operands &operands() { return _model->operands; } // TODO Remove this non-const accessor + const model::Operations &operations() const { return _model->operations; } + model::Operations &operations() { return _model->operations; } + const compiler::BackendResolver *backend_resolver() const { return _backend_resolver.get(); } + +private: + Phase _phase{Phase::BUILDING}; + std::shared_ptr<model::Model> _model; + + // For LOWERED phase +public: + const operation::LowerInfo *getLowerInfo(const model::SubgraphIndex &subg_index) const; + void setLowerInfo(const model::SubgraphIndex &subg_index, + std::unique_ptr<operation::LowerInfo> &&lower_info); + const operand::LowerInfo *getLowerInfo(const model::OperandIndex &index) const; + operand::LowerInfo *getLowerInfo(const model::OperandIndex &index); + void setLowerInfo(const model::OperandIndex &index, + std::unique_ptr<operand::LowerInfo> &&lower_info); + model::Subgraphs &subgraphs() + { + assert(_subgraphs); + return *_subgraphs; + } + const model::Subgraphs *subgraphs() const { return _subgraphs.get(); } + void setBackendResolver(std::unique_ptr<compiler::BackendResolver> &&br); + std::unique_ptr<compiler::BackendResolver> releaseBackendResolver(); + +private: + std::unique_ptr<compiler::BackendResolver> _backend_resolver; + std::unique_ptr<LowerInfoMap> _lower_info_map; + // Pass(for Perm) can accept only graph so that Graph has Subgraphs as a member + std::unique_ptr<model::Subgraphs> _subgraphs; +}; + +} // namespace graph +} // namespace neurun + +#endif // __NEURUN_GRAPH_GRAPH_H__ |