diff options
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp')
-rw-r--r-- | compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp index cf0313149..64870c057 100644 --- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp +++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp @@ -25,19 +25,16 @@ namespace luci bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) + if (!GraphBuilder::validate(args, 2)) return false; - if (outputs.size() != 1) - return false; - - const auto &tensors = args.reader.tensors(); - const auto &tensor = tensors.at(inputs.at(0)); + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; + const auto tensors = args.reader.tensors(); + const auto tensor = tensors.at(inputs.at(0)); - if (tensors[outputs[0]]->type != tensor->type) + assert(tensors[outputs[0]] != nullptr && tensor != nullptr); + if (tensors[outputs[0]]->type() != tensor->type()) return false; return true; |