summaryrefslogtreecommitdiff
path: root/runtime/libs/tflite/port/1.13.1/src/nnapi_delegate_ex_AddOpsAndParams_lambda.inc
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/libs/tflite/port/1.13.1/src/nnapi_delegate_ex_AddOpsAndParams_lambda.inc')
-rw-r--r--runtime/libs/tflite/port/1.13.1/src/nnapi_delegate_ex_AddOpsAndParams_lambda.inc32
1 files changed, 21 insertions, 11 deletions
diff --git a/runtime/libs/tflite/port/1.13.1/src/nnapi_delegate_ex_AddOpsAndParams_lambda.inc b/runtime/libs/tflite/port/1.13.1/src/nnapi_delegate_ex_AddOpsAndParams_lambda.inc
index 5b718029b..ee758105f 100644
--- a/runtime/libs/tflite/port/1.13.1/src/nnapi_delegate_ex_AddOpsAndParams_lambda.inc
+++ b/runtime/libs/tflite/port/1.13.1/src/nnapi_delegate_ex_AddOpsAndParams_lambda.inc
@@ -135,7 +135,7 @@
assert(count == 1);
};
- auto add_reducer_v12_params = [&add_scalar_bool8](void* data) {
+ auto add_reducer_params = [&add_scalar_bool8](void* data) {
auto builtin = reinterpret_cast<TfLiteReducerParams*>(data);
if (builtin == nullptr)
{
@@ -147,14 +147,24 @@
}
};
- auto add_reducer_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteReducerParams*>(data);
- if (builtin == nullptr)
- {
- add_scalar_int32(0);
- }
- else
- {
- add_scalar_int32(builtin->keep_dims);
- }
+ auto add_one_hot_tensor_inputs_as_scalar = [subgraph, &node, &augmented_inputs,
+ &add_scalar_float32]() {
+ assert(augmented_inputs.size() == 4);
+ const auto on_value_idx = node.inputs->data[2];
+ const auto off_value_idx = node.inputs->data[3];
+ const auto on_value_tensor = subgraph->tensor(on_value_idx);
+ const auto off_value_tensor = subgraph->tensor(off_value_idx);
+ assert(on_value_tensor->type == kTfLiteFloat32);
+ assert(off_value_tensor->type == kTfLiteFloat32);
+ const auto on_value = *on_value_tensor->data.f;
+ const auto off_value = *off_value_tensor->data.f;
+ augmented_inputs.pop_back();
+ augmented_inputs.pop_back();
+ add_scalar_float32(on_value);
+ add_scalar_float32(off_value);
+ };
+
+ auto add_one_hot_params = [&add_scalar_int32](void* data) {
+ const auto* builtin = reinterpret_cast<TfLiteOneHotParams*>(data);
+ add_scalar_int32(builtin->axis);
};