summaryrefslogtreecommitdiff
path: root/runtime/onert/core/src/compiler/OperationValidator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/core/src/compiler/OperationValidator.cc')
-rw-r--r--runtime/onert/core/src/compiler/OperationValidator.cc171
1 files changed, 78 insertions, 93 deletions
diff --git a/runtime/onert/core/src/compiler/OperationValidator.cc b/runtime/onert/core/src/compiler/OperationValidator.cc
index 5c545aedd..44496318f 100644
--- a/runtime/onert/core/src/compiler/OperationValidator.cc
+++ b/runtime/onert/core/src/compiler/OperationValidator.cc
@@ -41,6 +41,21 @@ OperationValidator::OperationValidator(const ir::Graph &graph)
{
}
+void OperationValidator::checkUnaryOp(const ir::Operation &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(0)};
+
+ // Check if I/O types match
+ OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
+
+ if (_ctx.at(output_index).info().isDynamic())
+ return;
+
+ // Check if I/O shapes match
+ OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+}
+
void OperationValidator::operator()()
{
// There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
@@ -53,16 +68,7 @@ void OperationValidator::operator()()
[&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
}
-void OperationValidator::visit(const ir::operation::Abs &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
-
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Abs &node) { checkUnaryOp(node); }
void OperationValidator::visit(const ir::operation::AvgPool2D &node)
{
@@ -292,17 +298,7 @@ void OperationValidator::visit(const ir::operation::RNN &node)
num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
}
-void OperationValidator::visit(const ir::operation::Round &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::Round::Input::INPUT)};
-
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Round &node) { checkUnaryOp(node); }
void OperationValidator::visit(const ir::operation::SpaceToBatchND &node)
{
@@ -393,17 +389,7 @@ void OperationValidator::visit(const ir::operation::EmbeddingLookup &node)
}
}
-void OperationValidator::visit(const ir::operation::Exp &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::Exp::Input::INPUT)};
-
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Exp &node) { checkUnaryOp(node); }
void OperationValidator::visit(const ir::operation::ExpandDims &node)
{
@@ -419,17 +405,7 @@ void OperationValidator::visit(const ir::operation::ExpandDims &node)
OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
}
-void OperationValidator::visit(const ir::operation::Floor &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::Floor::Input::INPUT)};
-
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Floor &node) { checkUnaryOp(node); }
void OperationValidator::visit(const ir::operation::HashtableLookup &node)
{
@@ -789,6 +765,25 @@ void OperationValidator::visit(const ir::operation::LSTM &node)
}
}
+void OperationValidator::visit(const ir::operation::L2Normalization &node)
+{
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (_ctx.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
+
+ auto ifm_shape = _ctx.at(ifm_index).shape();
+ auto ofm_shape = _ctx.at(ofm_index).shape();
+
+ OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
+
+ for (auto i = 0; i < ifm_shape.rank(); i++)
+ {
+ OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
+ }
+}
+
void OperationValidator::visit(const ir::operation::Unpack &node)
{
const auto num{node.param().num};
@@ -904,45 +899,39 @@ void OperationValidator::visit(const ir::operation::Split &node)
OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
}
-void OperationValidator::visit(const ir::operation::Cos &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
+void OperationValidator::visit(const ir::operation::Cos &node) { checkUnaryOp(node); }
- const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
-
-void OperationValidator::visit(const ir::operation::Sin &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
+void OperationValidator::visit(const ir::operation::Sin &node) { checkUnaryOp(node); }
- const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::RSQRT &node) { checkUnaryOp(node); }
-void OperationValidator::visit(const ir::operation::RSQRT &node)
+void OperationValidator::visit(const ir::operation::Shape &node)
{
const auto output_index{node.getOutputs().at(0)};
if (_ctx.at(output_index).info().isDynamic())
return;
const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
+ UNUSED_RELEASE(input_index);
+ OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
}
-void OperationValidator::visit(const ir::operation::Shape &node)
+void OperationValidator::visit(const ir::operation::ResizeBilinear &node)
{
const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
+
if (_ctx.at(output_index).info().isDynamic())
+ {
return;
+ }
+ OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
+ OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
- const auto input_index{node.getInputs().at(0)};
- UNUSED_RELEASE(input_index);
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
+ auto align_corners = node.param().align_corners;
+ auto half_pixel_centers = node.param().half_pixel_centers;
+
+ OP_REQUIRES(!align_corners || !half_pixel_centers);
}
void OperationValidator::visit(const ir::operation::Reverse &node)
@@ -972,35 +961,11 @@ void OperationValidator::visit(const ir::operation::While &node)
// TODO Add to validate with subgraphs
}
-void OperationValidator::visit(const ir::operation::Neg &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
+void OperationValidator::visit(const ir::operation::Neg &node) { checkUnaryOp(node); }
- const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::Log &node) { checkUnaryOp(node); }
-void OperationValidator::visit(const ir::operation::Log &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
-
-void OperationValidator::visit(const ir::operation::LogicalNot &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
+void OperationValidator::visit(const ir::operation::LogicalNot &node) { checkUnaryOp(node); }
void OperationValidator::visit(const ir::operation::SquaredDifference &node)
{
@@ -1118,5 +1083,25 @@ void OperationValidator::visit(const ir::operation::LogSoftmax &node)
OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
}
+
+void OperationValidator::visit(const ir::operation::Quantize &node)
+{
+ VERBOSE(Quantize) << "Configure Quantize operation" << std::endl;
+
+ OP_REQUIRES(node.getInputs().size() == 1);
+ OP_REQUIRES(node.getOutputs().size() == 1);
+
+ const auto input_index{node.getInputs().at(0)};
+ const auto output_index{node.getOutputs().at(0)};
+
+ OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32);
+
+ if (_ctx.at(output_index).info().isDynamic())
+ return;
+
+ OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
+
+ OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
+}
} // namespace compiler
} // namespace onert