summaryrefslogtreecommitdiff
path: root/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp')
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp17
1 files changed, 7 insertions, 10 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
index cf0313149..64870c057 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
@@ -25,19 +25,16 @@ namespace luci
bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
- if (outputs.size() != 1)
- return false;
-
- const auto &tensors = args.reader.tensors();
- const auto &tensor = tensors.at(inputs.at(0));
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
+ const auto tensors = args.reader.tensors();
+ const auto tensor = tensors.at(inputs.at(0));
- if (tensors[outputs[0]]->type != tensor->type)
+ assert(tensors[outputs[0]] != nullptr && tensor != nullptr);
+ if (tensors[outputs[0]]->type() != tensor->type())
return false;
return true;