diff options
Diffstat (limited to 'runtime/onert/core/src/ir/Graph.cc')
-rw-r--r-- | runtime/onert/core/src/ir/Graph.cc | 174 |
1 files changed, 122 insertions, 52 deletions
diff --git a/runtime/onert/core/src/ir/Graph.cc b/runtime/onert/core/src/ir/Graph.cc index fe8b1b443..306572c99 100644 --- a/runtime/onert/core/src/ir/Graph.cc +++ b/runtime/onert/core/src/ir/Graph.cc @@ -16,18 +16,10 @@ #include "ir/Graph.h" -#include <algorithm> -#include <bitset> -#include <sstream> - -#include "util/logging.h" +#include "OperationValidator.h" #include "verifier/Verifier.h" -#include "ir/operation/LowerInfo.h" -#include "ir/operand/LowerInfo.h" -#include "ir/operand/PermuteFactor.h" -#include "ir/OperandIndexMap.h" -#include "ir/GraphIterator.h" -#include "backend/IConfig.h" + +#include "util/Set.h" namespace onert { @@ -36,6 +28,8 @@ namespace ir Graph::Graph() = default; +Graph::Graph(const Graph &) = default; + Graph::~Graph(void) = default; OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type) @@ -43,22 +37,91 @@ OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type) return _operands.emplace(shape, type); } -OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&node) +OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand) +{ + return _operands.push(std::move(operand), index); +} + +bool Graph::checkOperandsForOperation(const IOperation &operation) { - assert(isBuildingPhase()); - return _operations.push(std::move(node)); + auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + for (auto &&input : inputs) + if (!operands().exist(input)) + return false; + for (auto &&input : outputs) + if (!operands().exist(input)) + return false; + return true; +} + +void Graph::linkOperandToOperation(OperationIndex index, const IOperation &operation) +{ + auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + + for (auto &&input : inputs) + operands().at(input).insertUse(index); + for (auto &&output : outputs) + operands().at(output).setDef(index); +} + +OperationIndex Graph::addOperation(std::unique_ptr<IOperation> &&operation) +{ + const IOperation &op_ref = *operation; + if (!checkOperandsForOperation(op_ref)) + return OperationIndex{}; + auto ind = _operations.push(std::move(operation)); + if (ind.valid()) + linkOperandToOperation(ind, op_ref); + return ind; +} + +OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation) +{ + const IOperation &op_ref = *operation; + if (!checkOperandsForOperation(op_ref)) + return OperationIndex{}; + auto ind_gen = _operations.push(std::move(operation), index); + if (ind_gen.valid()) + { + assert(ind_gen == index); + linkOperandToOperation(index, op_ref); + } + return index; +} + +OperationIndex Graph::replaceOperation(OperationIndex index, + std::unique_ptr<IOperation> &&operation) +{ + const IOperation &op_ref = *operation; + if (!checkOperandsForOperation(op_ref) || !_operations.exist(index)) + return OperationIndex{}; + + // Check the new operation has the same inputs/outputs as the existing operation + const auto &old_op = _operations.at(index); + if (!(old_op.getInputs() == op_ref.getInputs() && old_op.getOutputs() == op_ref.getOutputs())) + { + return OperationIndex{}; + } + + return _operations.set(index, std::move(operation)); } void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data) { - assert(isBuildingPhase()); assert(_operands.exist(ind)); _operands.at(ind).data(std::move(data)); } +void Graph::changeShape(const OperandIndex &ind, const ir::Shape &new_shape) +{ + assert(_operands.exist(ind)); + _operands.at(ind).info().shape(new_shape); +} + void Graph::addInput(const OperandIndex &ind, const std::string &name) { - assert(isBuildingPhase()); if (!name.empty()) _name_to_input.emplace(name, IOIndex{_inputs.size()}); _inputs.append(ind); @@ -66,7 +129,6 @@ void Graph::addInput(const OperandIndex &ind, const std::string &name) void Graph::addOutput(const OperandIndex &ind, const std::string &name) { - assert(isBuildingPhase()); if (!name.empty()) _name_to_output.emplace(name, IOIndex{_outputs.size()}); _outputs.append(ind); @@ -84,62 +146,70 @@ IOIndex Graph::getOutputIndex(const std::string &name) const return (itr == _name_to_output.end()) ? IOIndex{} : itr->second; } -void Graph::finishBuilding(void) +void Graph::verify(void) const { - assert(isBuildingPhase()); - _phase = Phase::MODEL; - - initializeUseDef(); - sweepGarbageOperands(); - // Call graph verifications for the MODEL phase { - assert(verifier::DAGChecker().verify(*this)); - assert(verifier::EdgeConsistencyChecker().verify(*this)); + // Except for edge consistency, the user might have been given a bad model + // so here it throws an execption rather than assertion. + if (!verifier::InputOutputChecker().verify(*this)) + throw std::runtime_error{"One of model input and output operands does not exist."}; + if (!verifier::DAGChecker().verify(*this)) + throw std::runtime_error{"The graph is cyclic."}; + assert(verifier::EdgeChecker().verify(*this)); } + + // Check shape independent operation feature + // - Operand type + // - Shape independent parameter + OperationValidator{*this}(); } void Graph::initializeUseDef() { - operations().iterate([&](const OperationIndex &index, const Operation &node) -> void { - auto outputs = node.getOutputs(); - for (auto output : outputs) + operations().iterate([&](const OperationIndex &index, const IOperation &node) -> void { + const auto &outputs = node.getOutputs(); + for (auto &&output : outputs | ir::Remove::UNDEFINED) { operands().at(output).setDef(index); } - for (auto input : node.getInputs() | ir::Remove::UNDEFINED) + for (auto &&input : node.getInputs() | ir::Remove::UNDEFINED) { operands().at(input).insertUse(index); } }); } -void Graph::sweepGarbageOperands() +std::vector<ir::OperationIndex> Graph::topolSortOperations() const { - // Remove operands that are not used by any operations, except Graph inputs/outputs - ir::OperandIndexMap<bool> visited; - - operations().iterate([&](const OperationIndex &, const Operation &node) { - for (auto ind : node.getInputs() + node.getOutputs()) - { - visited[ind] = true; - } - }); - - // Graph's inputs/outputs are always reachable - for (auto ind : getInputs() + getOutputs()) - { - visited[ind] = true; - } - - operands().iterate([&](const OperandIndex &ind, const Operand &) { - if (!visited[ind]) + std::vector<ir::OperationIndex> ret; + util::Set<ir::OperationIndex> unvisited; + operations().iterate( + [&](const ir::OperationIndex &index, const ir::IOperation &) { unvisited.add(index); }); + + std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs = + [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void { + if (!unvisited.contains(index)) + return; + unvisited.remove(index); + + for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) { - VERBOSE(Graph::sweepGarbageOperands) << "Sweep garbage operand " << ind.value() << std::endl; - operands().remove(ind); + const auto &operand = operands().at(output); + for (const auto &use : operand.getUses()) + { + dfs(use, operations().at(use)); + } } - }); + ret.push_back(index); + }; + operations().iterate(dfs); + + assert(unvisited.empty()); // All of the nodes must have been visited + // Reversing Postorder DFS result to make it sorted in topoligical order + std::reverse(ret.begin(), ret.end()); + return ret; } } // namespace ir |