diff options
Diffstat (limited to 'runtime/onert/core/src/compiler/OperationValidator.cc')
-rw-r--r-- | runtime/onert/core/src/compiler/OperationValidator.cc | 171 |
1 files changed, 78 insertions, 93 deletions
diff --git a/runtime/onert/core/src/compiler/OperationValidator.cc b/runtime/onert/core/src/compiler/OperationValidator.cc index 5c545aedd..44496318f 100644 --- a/runtime/onert/core/src/compiler/OperationValidator.cc +++ b/runtime/onert/core/src/compiler/OperationValidator.cc @@ -41,6 +41,21 @@ OperationValidator::OperationValidator(const ir::Graph &graph) { } +void OperationValidator::checkUnaryOp(const ir::Operation &node) +{ + const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(0)}; + + // Check if I/O types match + OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); + + if (_ctx.at(output_index).info().isDynamic()) + return; + + // Check if I/O shapes match + OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); +} + void OperationValidator::operator()() { // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when @@ -53,16 +68,7 @@ void OperationValidator::operator()() [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); }); } -void OperationValidator::visit(const ir::operation::Abs &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} +void OperationValidator::visit(const ir::operation::Abs &node) { checkUnaryOp(node); } void OperationValidator::visit(const ir::operation::AvgPool2D &node) { @@ -292,17 +298,7 @@ void OperationValidator::visit(const ir::operation::RNN &node) num_units == _ctx.at(hidden_state_out_index).shape().dim(1)); } -void OperationValidator::visit(const ir::operation::Round &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::Round::Input::INPUT)}; - - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} +void OperationValidator::visit(const ir::operation::Round &node) { checkUnaryOp(node); } void OperationValidator::visit(const ir::operation::SpaceToBatchND &node) { @@ -393,17 +389,7 @@ void OperationValidator::visit(const ir::operation::EmbeddingLookup &node) } } -void OperationValidator::visit(const ir::operation::Exp &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::Exp::Input::INPUT)}; - - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} +void OperationValidator::visit(const ir::operation::Exp &node) { checkUnaryOp(node); } void OperationValidator::visit(const ir::operation::ExpandDims &node) { @@ -419,17 +405,7 @@ void OperationValidator::visit(const ir::operation::ExpandDims &node) OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1); } -void OperationValidator::visit(const ir::operation::Floor &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::Floor::Input::INPUT)}; - - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} +void OperationValidator::visit(const ir::operation::Floor &node) { checkUnaryOp(node); } void OperationValidator::visit(const ir::operation::HashtableLookup &node) { @@ -789,6 +765,25 @@ void OperationValidator::visit(const ir::operation::LSTM &node) } } +void OperationValidator::visit(const ir::operation::L2Normalization &node) +{ + const auto ofm_index{node.getOutputs().at(0)}; + if (_ctx.at(ofm_index).info().isDynamic()) + return; + + const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)}; + + auto ifm_shape = _ctx.at(ifm_index).shape(); + auto ofm_shape = _ctx.at(ofm_index).shape(); + + OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank()); + + for (auto i = 0; i < ifm_shape.rank(); i++) + { + OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i)); + } +} + void OperationValidator::visit(const ir::operation::Unpack &node) { const auto num{node.param().num}; @@ -904,45 +899,39 @@ void OperationValidator::visit(const ir::operation::Split &node) OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0); } -void OperationValidator::visit(const ir::operation::Cos &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; +void OperationValidator::visit(const ir::operation::Cos &node) { checkUnaryOp(node); } - const auto input_index{node.getInputs().at(0)}; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::visit(const ir::operation::Sin &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; +void OperationValidator::visit(const ir::operation::Sin &node) { checkUnaryOp(node); } - const auto input_index{node.getInputs().at(0)}; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} +void OperationValidator::visit(const ir::operation::RSQRT &node) { checkUnaryOp(node); } -void OperationValidator::visit(const ir::operation::RSQRT &node) +void OperationValidator::visit(const ir::operation::Shape &node) { const auto output_index{node.getOutputs().at(0)}; if (_ctx.at(output_index).info().isDynamic()) return; const auto input_index{node.getInputs().at(0)}; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); + UNUSED_RELEASE(input_index); + OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1); } -void OperationValidator::visit(const ir::operation::Shape &node) +void OperationValidator::visit(const ir::operation::ResizeBilinear &node) { const auto output_index{node.getOutputs().at(0)}; + const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)}; + if (_ctx.at(output_index).info().isDynamic()) + { return; + } + OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4); + OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4); - const auto input_index{node.getInputs().at(0)}; - UNUSED_RELEASE(input_index); - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1); + auto align_corners = node.param().align_corners; + auto half_pixel_centers = node.param().half_pixel_centers; + + OP_REQUIRES(!align_corners || !half_pixel_centers); } void OperationValidator::visit(const ir::operation::Reverse &node) @@ -972,35 +961,11 @@ void OperationValidator::visit(const ir::operation::While &node) // TODO Add to validate with subgraphs } -void OperationValidator::visit(const ir::operation::Neg &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; +void OperationValidator::visit(const ir::operation::Neg &node) { checkUnaryOp(node); } - const auto input_index{node.getInputs().at(0)}; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} +void OperationValidator::visit(const ir::operation::Log &node) { checkUnaryOp(node); } -void OperationValidator::visit(const ir::operation::Log &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::visit(const ir::operation::LogicalNot &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} +void OperationValidator::visit(const ir::operation::LogicalNot &node) { checkUnaryOp(node); } void OperationValidator::visit(const ir::operation::SquaredDifference &node) { @@ -1118,5 +1083,25 @@ void OperationValidator::visit(const ir::operation::LogSoftmax &node) OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); } + +void OperationValidator::visit(const ir::operation::Quantize &node) +{ + VERBOSE(Quantize) << "Configure Quantize operation" << std::endl; + + OP_REQUIRES(node.getInputs().size() == 1); + OP_REQUIRES(node.getOutputs().size() == 1); + + const auto input_index{node.getInputs().at(0)}; + const auto output_index{node.getOutputs().at(0)}; + + OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32); + + if (_ctx.at(output_index).info().isDynamic()) + return; + + OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM); + + OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); +} } // namespace compiler } // namespace onert |