summaryrefslogtreecommitdiff
path: root/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleGatherNd.cpp')
-rw-r--r--compiler/luci/import/src/Nodes/CircleGatherNd.cpp16
1 files changed, 6 insertions, 10 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
index 981adbf63..d336878ad 100644
--- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
@@ -27,19 +27,15 @@ namespace luci
bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
- if (outputs.size() != 1)
- return false;
-
- auto &indices_tensor = args.reader.tensors()[inputs.at(1)];
+ const auto &inputs = args.op.inputs;
+ auto indices_tensor = args.reader.tensors()[inputs.at(1)];
+ assert(indices_tensor != nullptr);
- if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 ||
- indices_tensor->type == circle::TensorType::TensorType_INT64))
+ if (!(indices_tensor->type() == circle::TensorType::TensorType_INT32 ||
+ indices_tensor->type() == circle::TensorType::TensorType_INT64))
{
return false;
}