diff options
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp')
-rw-r--r-- | compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp | 22 |
1 files changed, 8 insertions, 14 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp index 0ac63b017..ac4ce62f5 100644 --- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp +++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp @@ -25,27 +25,21 @@ namespace luci bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) - { + if (!GraphBuilder::validate(args, 2)) return false; - } - if (outputs.size() != 1) - { - return false; - } - - const auto &tensors = args.reader.tensors(); + const auto &inputs = args.op.inputs; + const auto &outputs = args.op.outputs; + const auto tensors = args.reader.tensors(); - if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type) + assert(tensors[inputs.at(0)] != nullptr && tensors[inputs.at(1)] != nullptr); + if (tensors[inputs.at(0)]->type() != tensors[inputs.at(1)]->type()) { return false; } - return tensors[outputs[0]]->type == circle::TensorType::TensorType_BOOL; + assert(tensors[outputs[0]] != nullptr); + return tensors[outputs[0]]->type() == circle::TensorType::TensorType_BOOL; } CircleNode *CircleGreaterEqualGraphBuilder::build_node(const circle::OperatorT &, |