diff options
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleGatherNd.cpp')
-rw-r--r-- | compiler/luci/import/src/Nodes/CircleGatherNd.cpp | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp index 981adbf63..d336878ad 100644 --- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp +++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp @@ -27,19 +27,15 @@ namespace luci bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const { - const auto &inputs = args.op.inputs; - const auto &outputs = args.op.outputs; - - if (inputs.size() != 2) + if (!GraphBuilder::validate(args, 2)) return false; - if (outputs.size() != 1) - return false; - - auto &indices_tensor = args.reader.tensors()[inputs.at(1)]; + const auto &inputs = args.op.inputs; + auto indices_tensor = args.reader.tensors()[inputs.at(1)]; + assert(indices_tensor != nullptr); - if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 || - indices_tensor->type == circle::TensorType::TensorType_INT64)) + if (!(indices_tensor->type() == circle::TensorType::TensorType_INT32 || + indices_tensor->type() == circle::TensorType::TensorType_INT64)) { return false; } |