diff options
Diffstat (limited to 'runtime/onert/core/src/compiler/OperationValidator.cc')
-rw-r--r-- | runtime/onert/core/src/compiler/OperationValidator.cc | 1053 |
1 files changed, 0 insertions, 1053 deletions
diff --git a/runtime/onert/core/src/compiler/OperationValidator.cc b/runtime/onert/core/src/compiler/OperationValidator.cc deleted file mode 100644 index f7f659e3e..000000000 --- a/runtime/onert/core/src/compiler/OperationValidator.cc +++ /dev/null @@ -1,1053 +0,0 @@ -/* - * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "OperationValidator.h" - -#include <typeinfo> - -#include "ir/Graph.h" -#include "ir/operation/LowerInfo.h" - -#include "util/logging.h" -#include "util/Utils.h" - -#define OP_REQUIRES(EXP) \ - do \ - { \ - if (!(EXP)) \ - throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \ - } while (0) - -namespace onert -{ -namespace compiler -{ - -OperationValidator::OperationValidator(const ir::Graph &graph) - : _graph{graph}, _ctx{graph.operands()}, _current_op_seq_layout{ir::Layout::UNKNOWN} -{ -} - -void OperationValidator::checkUnaryOp(const ir::Operation &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(0)}; - - // Check if I/O types match - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - // Check if I/O shapes match - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::operator()() -{ - // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when - // creating Compiler - assert(_graph.subgraphs() == nullptr); - - _current_op_seq_layout = _graph.layout(); - - _graph.operations().iterate( - [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); }); -} - -void OperationValidator::visit(const ir::operation::BatchMatMul &node) -{ - const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS)); - const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS)); - const auto out_index{node.getOutputs().at(0)}; - - // Constant lhs and rhs is not implemented yet - OP_REQUIRES(!_ctx.at(lhs_index).isConstant() && !_ctx.at(rhs_index).isConstant()); - - if (_ctx.at(out_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4); - OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4); - OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2); - OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2); -} - -void OperationValidator::visit(const ir::operation::BatchToSpaceND &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)}; - const auto block_size_index{ - node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)}; - - const auto frontend_layout = _current_op_seq_layout; - const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - - // All requirement as per NNAPI specification. - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1); - - OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2); - - OP_REQUIRES(_ctx.at(block_size_index).isConstant()); - - OP_REQUIRES(input_shape.C == output_shape.C); -} - -void OperationValidator::visit(const ir::operation::Comparison &node) -{ - const auto output_index{node.getOutputs().at(0)}; - // This validator does not check shape. So checking isDynamic() is skipped. - - const auto lhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT0)}; - const auto rhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT1)}; - - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::BOOL8); -} - -void OperationValidator::visit(const ir::operation::Softmax &node) -{ - VERBOSE(Softmax) << "Configure SOFTMAX operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::InstanceNorm &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)}; - const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)}; - const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)}; - - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape()); - OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1); - OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1); -} - -void OperationValidator::visit(const ir::operation::Pool2D &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)}; - - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); -} - -void OperationValidator::visit(const ir::operation::Permute &node) -{ - VERBOSE(Permute) << "Configure Permute operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::Reduce &node) -{ - VERBOSE(Permute) << "Configure " + node.name() + " operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)}; - const auto input_shape = _ctx.at(input_index).shape(); - const auto output_shape = _ctx.at(output_index).shape(); - - OP_REQUIRES(input_shape.rank() <= 4); - OP_REQUIRES(output_shape.rank() <= input_shape.rank()); - - // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only - // supports cases reducing height and width or reducing depth. - // TODO We have to support all cases of dimensions up to 4. - // For correct permuting, we have to set output's shape to be equal in dimension position of the - // input. But the positions of the same dimensions in the input and output may be set differently. - // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original - // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to - // extend it in 4 dimensions, it should be {1,1,3,5}. - // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of - // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the - // next operation is not desired. - if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank()) - { - if (output_shape.rank() == 2) - { - // Reducing HW - OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) && - input_shape.dim(3) == output_shape.dim(1)); - } - else if (output_shape.rank() == 3) - { - // Reducing C or - // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1) - OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) && - input_shape.dim(1) == output_shape.dim(1) && - input_shape.dim(2) == output_shape.dim(2)) || - (input_shape.dim(0) == output_shape.dim(0) && - (input_shape.dim(1) == output_shape.dim(1) || - input_shape.dim(2) == output_shape.dim(1)) && - input_shape.dim(3) == 1 && output_shape.dim(2) == 1)); - } - } -} - -void OperationValidator::visit(const ir::operation::Transpose &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)}; - const auto &perm{node.param().perm}; - - const auto &output_shape = _ctx.at(output_index).shape(); - const auto &input_shape = _ctx.at(input_index).shape(); - - OP_REQUIRES(input_shape.rank() == static_cast<int>(perm.size())); - OP_REQUIRES(input_shape.rank() == output_shape.rank()); -} - -void OperationValidator::visit(const ir::operation::RNN &node) -{ - // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn - // TODO Support dynamic rnn - const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto hidden_state_out_index{ - node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)}; - - const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)}; - const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)}; - const auto recurrent_weights_index{ - node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)}; - const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)}; - const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)}; - - const auto batch_size = _ctx.at(output_index).shape().dim(0); - const auto num_units = _ctx.at(output_index).shape().dim(1); - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 && - _ctx.at(hidden_state_out_index).shape().rank() == 2 && - _ctx.at(input_index).shape().rank() == 2 && - _ctx.at(weights_index).shape().rank() == 2 && - _ctx.at(recurrent_weights_index).shape().rank() == 2 && - _ctx.at(hidden_state_in_index).shape().rank() == 2); - OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1); - - OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) && - batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) && - batch_size == _ctx.at(hidden_state_out_index).shape().dim(0)); - OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1)); - - OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_weights_index).shape().dim(0) && - num_units == _ctx.at(bias_index).shape().dim(0)); - OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) && - num_units == _ctx.at(recurrent_weights_index).shape().dim(1) && - num_units == _ctx.at(hidden_state_in_index).shape().dim(1) && - num_units == _ctx.at(hidden_state_out_index).shape().dim(1)); -} - -void OperationValidator::visit(const ir::operation::SpaceToBatchND &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)}; - const auto block_size_index{ - node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)}; - const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)}; - - const auto frontend_layout = _current_op_seq_layout; - const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - - // All requirement as per NNAPI specification. - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1); - OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2); - - OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2); - OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2); - OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2); - - OP_REQUIRES(_ctx.at(block_size_index).isConstant()); - OP_REQUIRES(_ctx.at(paddings_index).isConstant()); - - OP_REQUIRES(input_shape.C == output_shape.C); -} - -void OperationValidator::visit(const ir::operation::SpaceToDepth &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)}; - - const auto frontend_layout = _current_op_seq_layout; - const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - const auto block_size = node.param().block_size; - - // All assertions as per NNAPI specification. - OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES((block_size >= 1) && (input_shape.H % block_size == 0) && - (input_shape.W % block_size == 0)); - OP_REQUIRES(input_shape.N == output_shape.N); - OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C); -} - -void OperationValidator::visit(const ir::operation::ElementwiseActivation &node) -{ - checkUnaryOp(node); -} - -void OperationValidator::visit(const ir::operation::ElementwiseBinary &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto lhs_index{node.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS)}; - const auto rhs_index{node.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS)}; - - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type()); -} - -void OperationValidator::visit(const ir::operation::ElementwiseUnary &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)}; - - OP_REQUIRES(node.getInputs().size() == 1); - OP_REQUIRES(node.getOutputs().size() == 1); - - // Check if I/O types match - if (node.param().op_type == ir::operation::ElementwiseUnary::Type::DEQUANTIZE) - { - OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::FLOAT32); - } - else if (node.param().op_type == ir::operation::ElementwiseUnary::Type::QUANTIZE) - { - OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM); - } - else if (node.param().op_type != ir::operation::ElementwiseUnary::Type::CAST) - { - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - } - - if (_ctx.at(output_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::visit(const ir::operation::EmbeddingLookup &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)}; - const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)}; - - const auto &output_obj = _ctx.at(output_index); - const auto &lookups_obj = _ctx.at(lookups_index); - const auto &values_obj = _ctx.at(values_index); - - // Verify operand here, not at SimpleEmbeddingLookup::configure() to avoid acl's modifying - // TensorShape sometimes(Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729) - { - OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto &output_shape = output_obj.shape(); - const auto &lookups_shape = lookups_obj.shape(); - const auto &values_shape = values_obj.shape(); - - OP_REQUIRES(lookups_shape.rank() == 1); - OP_REQUIRES(values_shape.rank() >= 2); - - // output should be a n-D tensor with the same rank and shape as the values tensor, except for - // the first dimension which has the same size as lookups' only dimension. - OP_REQUIRES(output_shape.rank() == values_shape.rank()); - OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0)); - for (int n = 1; n < output_shape.rank(); ++n) - { - OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n)); - } - } -} - -void OperationValidator::visit(const ir::operation::ExpandDims &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::ExpandDims::Input::INPUT)}; - const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)}; - - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32); - - if (_ctx.at(axis_index).info().isDynamic()) - return; - OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1); -} - -void OperationValidator::visit(const ir::operation::HashtableLookup &node) -{ - const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)}; - const auto hits_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::HITS)}; - - const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)}; - const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)}; - const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)}; - - const auto &output_obj = _ctx.at(output_index); - const auto &hits_obj = _ctx.at(hits_index); - - const auto &lookups_obj = _ctx.at(lookups_index); - const auto &keys_obj = _ctx.at(keys_index); - const auto &values_obj = _ctx.at(values_index); - - OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32); - OP_REQUIRES(keys_obj.typeInfo().type() == ir::DataType::INT32); - OP_REQUIRES(hits_obj.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto &output_shape = output_obj.shape(); - const auto &lookups_shape = lookups_obj.shape(); - const auto &keys_shape = keys_obj.shape(); - const auto &values_shape = values_obj.shape(); - - OP_REQUIRES(values_shape.rank() == output_shape.rank()); - OP_REQUIRES(lookups_shape.rank() == 1); - OP_REQUIRES(keys_shape.rank() == 1); - OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0)); - OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0)); -} - -void OperationValidator::visit(const ir::operation::TransposeConv &node) -{ - // param check - OP_REQUIRES((node.param().padding.type == ir::PaddingType::SAME) || - (node.param().padding.type == ir::PaddingType::VALID)); - - // shape check - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)}; - const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)}; - - // Only 4D tensors are supported - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank()); - OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank()); - - const auto frontend_layout = _current_op_seq_layout; - const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout); - const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout); - // The kernel has only IHWO layout on frontend - // So ker_shape is treated here below - // I -> N - // H -> H - // W -> W - // O -> C - const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC); - - OP_REQUIRES(ifm_shape.N == ofm_shape.N); - OP_REQUIRES(ifm_shape.C == ker_shape.C); - OP_REQUIRES(ker_shape.N == ofm_shape.C); -} - -void OperationValidator::visit(const ir::operation::Gather &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)}; - const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)}; - - const auto ifm_shape = _ctx.at(ifm_index).shape(); - const auto indices_shape = _ctx.at(indices_index).shape(); - const auto ofm_shape = _ctx.at(ofm_index).shape(); - - OP_REQUIRES(ifm_shape.rank() <= 4); - OP_REQUIRES(indices_shape.rank() <= 3); - OP_REQUIRES(ofm_shape.rank() <= 4); -} - -void OperationValidator::visit(const ir::operation::DepthToSpace &node) -{ - // param check - int32_t block_size = node.param().block_size; - - OP_REQUIRES(block_size > 0); - - // shape check - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)}; - - const auto frontend_layout = _current_op_seq_layout; - const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout); - const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout); - - OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4); - - { - OP_REQUIRES(output_shape.N == input_shape.N); - OP_REQUIRES(output_shape.H == input_shape.H * block_size); - OP_REQUIRES(output_shape.W == input_shape.W * block_size); - OP_REQUIRES(input_shape.C % (block_size * block_size) == 0); - OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size)); - } -} - -void OperationValidator::visit(const ir::operation::Pack &node) -{ - // param check - const auto num{node.param().num}; - const auto axis{node.param().axis}; - OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size())); - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - // shape check - const auto &output_shape = _ctx.at(output_index).shape(); - const auto output_rank = static_cast<int32_t>(output_shape.rank()); - - const auto input1_index{node.getInputs().at(0)}; - const auto input_shape = _ctx.at(input1_index).shape(); - - OP_REQUIRES(axis >= -output_rank && axis < output_rank); - for (const auto &index : node.getInputs()) - { - OP_REQUIRES(input_shape == _ctx.at(index).shape()); - } -} - -void OperationValidator::visit(const ir::operation::LSTM &node) -{ - // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn - // TODO Support dynamic rnn - const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto scratch_buffer_index{ - node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; - const auto output_state_out_index{ - node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; - const auto cell_state_out_index{ - node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; - - const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)}; - const auto input_to_input_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; - const auto input_to_forget_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)}; - const auto input_to_cell_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)}; - const auto input_to_output_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)}; - const auto recurrent_to_input_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; - const auto recurrent_to_forget_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)}; - const auto recurrent_to_cell_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)}; - const auto recurrent_to_output_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)}; - const auto cell_to_input_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; - const auto cell_to_forget_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; - const auto cell_to_output_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; - const auto input_gate_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; - const auto forget_gate_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)}; - const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)}; - const auto output_gate_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)}; - const auto projection_weights_index{ - node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; - const auto projection_bias_index{ - node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; - const auto output_state_in_index{ - node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)}; - const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)}; - - OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2 && - _ctx.at(output_state_out_index).shape().rank() == 2 && - _ctx.at(cell_state_out_index).shape().rank() == 2 && - _ctx.at(output_index).shape().rank() == 2 && - _ctx.at(input_index).shape().rank() == 2 && - _ctx.at(input_to_input_weights_index).shape().rank() == 2 && - _ctx.at(input_to_forget_weights_index).shape().rank() == 2 && - _ctx.at(input_to_cell_weights_index).shape().rank() == 2 && - _ctx.at(input_to_output_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 && - _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 && - _ctx.at(projection_weights_index).shape().rank() == 2 && - _ctx.at(output_state_in_index).shape().rank() == 2 && - _ctx.at(cell_state_in_index).shape().rank() == 2); - - OP_REQUIRES(_ctx.at(cell_to_input_weights_index).shape().rank() == 1 && - _ctx.at(cell_to_forget_weights_index).shape().rank() == 1 && - _ctx.at(cell_to_output_weights_index).shape().rank() == 1 && - _ctx.at(input_gate_bias_index).shape().rank() == 1 && - _ctx.at(forget_gate_bias_index).shape().rank() == 1 && - _ctx.at(cell_bias_index).shape().rank() == 1 && - _ctx.at(output_gate_bias_index).shape().rank() == 1 && - _ctx.at(projection_bias_index).shape().rank() == 1); - - // CIFG assertion - OP_REQUIRES((_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 && - _ctx.at(input_to_input_weights_index).shape().dim(1) == 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0 && - _ctx.at(input_gate_bias_index).shape().dim(0) == 0 && - _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) || - (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(input_to_input_weights_index).shape().dim(1) != 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0 && - _ctx.at(input_gate_bias_index).shape().dim(0) != 0)); - - // Peephole assertion - OP_REQUIRES((_ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0 && - _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0) || - (_ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0 && - _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0)); - - bool has_input_to_input_weights = _ctx.at(input_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(input_to_input_weights_index).shape().dim(1) != 0; - bool has_recurrent_to_input_weights = - _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && - _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0; - bool has_input_gate_bias = _ctx.at(input_gate_bias_index).shape().dim(0) != 0; - bool has_cell_to_input_weights = _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0; - bool has_cell_to_forget_weights = _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0; - bool has_cell_to_output_weights = _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0; - bool has_projection_weights = _ctx.at(projection_weights_index).shape().dim(0) != 0 && - _ctx.at(projection_weights_index).shape().dim(1) != 0; - bool has_projection_bias = _ctx.at(projection_bias_index).shape().dim(0); - - // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG). - // true: no CIFG - // false: CIFG - bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights; - - // NOTE The cell_to_input_weights do not exist in regular CIFG although peephole. - // true: peephole - // false: no peephole - bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights; - - // NOTE The projection weights may have data but the projection bias may not. - bool has_projection_param = has_projection_weights; - - const auto batch_size = _ctx.at(input_index).shape().dim(0); - OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) && - batch_size == _ctx.at(cell_state_in_index).shape().dim(0) && - batch_size == _ctx.at(scratch_buffer_index).shape().dim(0) && - batch_size == _ctx.at(output_state_out_index).shape().dim(0) && - batch_size == _ctx.at(cell_state_out_index).shape().dim(0) && - batch_size == _ctx.at(output_index).shape().dim(0)); - - const auto input_size = _ctx.at(input_index).shape().dim(1); - OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) && - input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) && - input_size == _ctx.at(input_to_output_weights_index).shape().dim(1)); - - const auto num_units = _ctx.at(cell_state_out_index).shape().dim(1); - OP_REQUIRES(num_units == _ctx.at(input_to_forget_weights_index).shape().dim(0) && - num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) && - num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) && - num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) && - num_units == _ctx.at(cell_bias_index).shape().dim(0) && - num_units == _ctx.at(output_gate_bias_index).shape().dim(0) && - num_units == _ctx.at(cell_state_in_index).shape().dim(1) && - (((num_units * 3) == _ctx.at(scratch_buffer_index).shape().dim(1)) || - ((num_units * 4) == _ctx.at(scratch_buffer_index).shape().dim(1)))); - - const auto output_size = _ctx.at(output_index).shape().dim(1); - OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) && - output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) && - output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) && - output_size == _ctx.at(output_state_in_index).shape().dim(1) && - output_size == _ctx.at(output_state_out_index).shape().dim(1)); - - if (has_cifg_param) - { - OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1)); - OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) && - num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) && - (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) || - _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* non-peephole */) && - num_units == _ctx.at(input_gate_bias_index).shape().dim(0)); - OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1)); - OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights && - has_input_gate_bias); - if (has_cell_to_input_weights) - { - // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole. - OP_REQUIRES(has_peephole_param); - } - OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4); - } - else - { - OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3); - } - - if (has_peephole_param) - { - OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) && - num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) && - (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) || - _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */)); - } - - if (has_projection_param) - { - OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1)); - OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0)); - if (has_projection_bias) - { - OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0)); - } - } -} - -void OperationValidator::visit(const ir::operation::L2Normalization &node) -{ - const auto ofm_index{node.getOutputs().at(0)}; - if (_ctx.at(ofm_index).info().isDynamic()) - return; - - const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)}; - - auto ifm_shape = _ctx.at(ifm_index).shape(); - auto ofm_shape = _ctx.at(ofm_index).shape(); - - OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank()); - - for (auto i = 0; i < ifm_shape.rank(); i++) - { - OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i)); - } -} - -void OperationValidator::visit(const ir::operation::Unpack &node) -{ - const auto num{node.param().num}; - OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size())); - const auto axis{node.param().axis}; - - const auto output_index{node.getInputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)}; - - const auto &input_shape = _ctx.at(input_index).shape(); - const auto input_rank = static_cast<int32_t>(input_shape.rank()); - - OP_REQUIRES(axis >= -input_rank && axis < input_rank); -} - -void OperationValidator::visit(const ir::operation::Pad &node) -{ - const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)}; - OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32); - - const auto output_index{node.getInputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)}; - - const auto &pad_shape = _ctx.at(pad_index).shape(); - const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank()); - - OP_REQUIRES(pad_shape.rank() == 2); - OP_REQUIRES(pad_shape.dim(0) == input_rank); - OP_REQUIRES(pad_shape.dim(1) == 2); - OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::Select &node) -{ - const auto output_index{node.getOutputs().at(0)}; - // This validator does not check shape. So checking isDynamic() is skipped. - - const auto condition_index{node.getInputs().at(ir::operation::Select::Input::CONDITION)}; - const auto input_true_index{node.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)}; - const auto input_false_index{node.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)}; - UNUSED_RELEASE(output_index); - UNUSED_RELEASE(input_true_index); - UNUSED_RELEASE(input_false_index); - - OP_REQUIRES(_ctx.at(condition_index).typeInfo().type() == ir::DataType::BOOL8); -} - -void OperationValidator::visit(const ir::operation::StridedSlice &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)}; - const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)}; - const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)}; - const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)}; - - UNUSED_RELEASE(starts_index); - UNUSED_RELEASE(ends_index); - UNUSED_RELEASE(strides_index); - - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4); -} - -void OperationValidator::visit(const ir::operation::Split &node) -{ - const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)}; - - if (_ctx.at(input_index).info().isDynamic()) - return; - - const auto num_splits = node.param().num_splits; - const auto input_rank = _ctx.at(input_index).shape().rank(); - const auto axis = node.param().axis < 0 ? node.param().axis + input_rank : node.param().axis; - - OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF); - OP_REQUIRES(axis >= 0 && axis < input_rank); - OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits)); - - OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0); -} - -void OperationValidator::visit(const ir::operation::Shape &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - UNUSED_RELEASE(input_index); - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1); -} - -void OperationValidator::visit(const ir::operation::ResizeBilinear &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)}; - - if (_ctx.at(output_index).info().isDynamic()) - { - return; - } - OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4); - OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4); - - auto align_corners = node.param().align_corners; - auto half_pixel_centers = node.param().half_pixel_centers; - - OP_REQUIRES(!align_corners || !half_pixel_centers); -} - -void OperationValidator::visit(const ir::operation::Reverse &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)}; - const auto axis_index{node.getInputs().at(ir::operation::Reverse::Input::AXIS)}; - - OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32); - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type()); - - if (_ctx.at(output_index).info().isDynamic()) - return; - OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape()); -} - -void OperationValidator::visit(const ir::operation::If &) -{ - // TODO Add to validate with subgraphs -} - -void OperationValidator::visit(const ir::operation::While &node) -{ - // This validator does not check shape. So checking isDynamic() is skipped. - - OP_REQUIRES(node.getInputs().size() == node.getOutputs().size()); - // TODO Add to validate with subgraphs -} - -void OperationValidator::visit(const ir::operation::SquaredDifference &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)}; - const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)}; - - // Check for Type equivalence - OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(lhs_index).typeInfo().type()); - OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type()); - - // Check for dimension constraints - if (_ctx.at(output_index).info().isDynamic()) - return; - - auto output_shape = _ctx.at(output_index).shape(); - auto lhs_shape = _ctx.at(lhs_index).shape(); - auto rhs_shape = _ctx.at(rhs_index).shape(); - // Check for output rank - OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank())); - auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank()); - - for (int idx = 1; idx <= min_rank; idx++) - { - int l_idx = lhs_shape.rank() - idx; - int r_idx = rhs_shape.rank() - idx; - int out_idx = output_shape.rank() - idx; - - OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0)); - - auto l_dims = lhs_shape.dim(l_idx); - auto r_dims = rhs_shape.dim(r_idx); - auto out_dims = output_shape.dim(out_idx); - - OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) || - ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims))); - } - auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape; - for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++) - { - int out_idx = output_shape.rank() - idx; - int tmp_idx = tmp_shape.rank() - idx; - - OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) && - (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx))); - } -} -void OperationValidator::visit(const ir::operation::Tile &node) -{ - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - const auto multiple_index{node.getInputs().at(1)}; - - OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1); - OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank()); - OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank()); -} - -void OperationValidator::visit(const ir::operation::Range &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)}; - const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)}; - const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)}; - - // Check for dimension constraints - if (_ctx.at(output_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0); - OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0); - OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0); -} - -void OperationValidator::visit(const ir::operation::MatrixBandPart &node) -{ - const auto output_index{node.getOutputs().at(0)}; - const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)}; - const auto num_lower_index{ - node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)}; - const auto num_upper_index{ - node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)}; - - // Check for dimension constraints - if (_ctx.at(output_index).info().isDynamic()) - return; - - OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2); // input must be more than 2 dim matrix - OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_lower must be scalar - OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_upper must be scalar -} - -void OperationValidator::visit(const ir::operation::LogSoftmax &node) -{ - VERBOSE(LogSoftmax) << "Configure LOGSOFTMAX operation" << std::endl; - - const auto output_index{node.getOutputs().at(0)}; - if (_ctx.at(output_index).info().isDynamic()) - return; - - const auto input_index{node.getInputs().at(0)}; - - OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank()); -} - -} // namespace compiler -} // namespace onert |