summaryrefslogtreecommitdiff
path: root/onert-micro/luci-interpreter/src/kernels/Softmax.cpp
diff options
context:
space:
mode:
authorHyeongseok Oh <hseok82.oh@samsung.com>2023-09-08 10:51:25 +0000
committerHyeongseok Oh <hseok82.oh@samsung.com>2023-09-08 10:51:25 +0000
commiteed258505ee1ad0f72d9e0a8a3934f2e9e7b5e79 (patch)
tree1aa860656489469003375a0f67edb1d729f7dc6b /onert-micro/luci-interpreter/src/kernels/Softmax.cpp
parent3a0ad354832744d138b361ffcfd21f33494beb6b (diff)
downloadnnfw-eed258505ee1ad0f72d9e0a8a3934f2e9e7b5e79.tar.gz
nnfw-eed258505ee1ad0f72d9e0a8a3934f2e9e7b5e79.tar.bz2
nnfw-eed258505ee1ad0f72d9e0a8a3934f2e9e7b5e79.zip
Diffstat (limited to 'onert-micro/luci-interpreter/src/kernels/Softmax.cpp')
-rw-r--r--onert-micro/luci-interpreter/src/kernels/Softmax.cpp88
1 files changed, 19 insertions, 69 deletions
diff --git a/onert-micro/luci-interpreter/src/kernels/Softmax.cpp b/onert-micro/luci-interpreter/src/kernels/Softmax.cpp
index cfc1dfc6f..4647cc94e 100644
--- a/onert-micro/luci-interpreter/src/kernels/Softmax.cpp
+++ b/onert-micro/luci-interpreter/src/kernels/Softmax.cpp
@@ -1,6 +1,5 @@
/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,8 +16,8 @@
#include "Builders.h"
#include "kernels/Utils.h"
+#include "SISOKernel.h"
-#include <tensorflow/lite/kernels/internal/reference/softmax.h>
#include "PALSoftmax.h"
namespace luci_interpreter
@@ -34,97 +33,48 @@ void evalFloat(const circle::Tensor *input, const circle::Tensor *output,
const auto *input_data = runtime_graph->getDataByTensor(input);
auto *output_data = runtime_graph->getDataByTensor(output);
- tflite::SoftmaxParams op_params{};
- op_params.beta = options->beta();
-
- tflite::reference_ops::Softmax(
- op_params, kernels::getTensorShape(input), kernels::getTensorData<float>(input_data),
- kernels::getTensorShape(output), kernels::getTensorData<float>(output_data));
+ luci_interpreter_pal::Softmax(options->beta(), kernels::getTensorShape(input),
+ kernels::getTensorData<float>(input_data),
+ kernels::getTensorData<float>(output_data));
}
#endif // DIS_FLOAT
-#ifndef DIS_QUANT
-template <typename T>
-void evalQuantized(const circle::Tensor *input, const circle::Tensor *output,
- const circle::SoftmaxOptions *options, BaseRuntimeGraph *runtime_graph)
-{
- // TODO: Enable it
- assert(false && "Not impl yet");
-
- const auto *input_data = runtime_graph->getDataByTensor(input);
- auto *output_data = runtime_graph->getDataByTensor(output);
-
- tflite::SoftmaxParams op_params{};
-
- luci_interpreter_pal::InitializeParams(&op_params, Tensor::scale(input), options->beta());
- luci_interpreter_pal::Softmax(
- op_params, kernels::getTensorShape(input), kernels::getTensorData<T>(input_data),
- kernels::getTensorShape(output), kernels::getTensorData<T>(output_data));
-}
-#endif
-
} // namespace
void configure_kernel_CircleSoftmax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- const auto input_index = cur_op->inputs()->operator[](0);
- const auto output_index = cur_op->outputs()->operator[](0);
-
- assert(input_index != -1);
- assert(output_index != -1);
-
- const auto input = runtime_graph->getCircleTensorByIndex(input_index);
- auto output = runtime_graph->getCircleTensorByIndex(output_index);
+ kernels::SISOKernel kernel(cur_op, runtime_graph);
- assert(input != nullptr);
- assert(output != nullptr);
-
- LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == Tensor::element_type(output));
- LUCI_INTERPRETER_CHECK(Tensor::num_dims(input) >= 1);
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input()) ==
+ Tensor::element_type(kernel.output()));
+ LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input()) >= 1);
#ifndef DIS_QUANT
- if (Tensor::element_type(input) == DataType::U8 || Tensor::element_type(input) == DataType::S8)
+ if (Tensor::element_type(kernel.input()) == DataType::U8 ||
+ Tensor::element_type(kernel.input()) == DataType::S8)
{
- LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::S8 ||
- Tensor::zero_point(output) == 0);
- LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::U8 ||
- Tensor::zero_point(output) == std::numeric_limits<int8_t>::min());
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input()) == DataType::S8 ||
+ Tensor::zero_point(kernel.output()) == 0);
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input()) == DataType::U8 ||
+ Tensor::zero_point(kernel.output()) ==
+ std::numeric_limits<int8_t>::min());
}
#endif
}
-void execute_kernel_CircleSoftmax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
- bool)
+void execute_kernel_CircleSoftmax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- const auto input_index = cur_op->inputs()->operator[](0);
- const auto output_index = cur_op->outputs()->operator[](0);
-
- assert(input_index != -1);
- assert(output_index != -1);
-
- const auto input = runtime_graph->getCircleTensorByIndex(input_index);
- auto output = runtime_graph->getCircleTensorByIndex(output_index);
-
- assert(input != nullptr);
- assert(output != nullptr);
+ kernels::SISOKernel kernel(cur_op, runtime_graph);
const auto *options = cur_op->builtin_options_as_SoftmaxOptions();
- switch (Tensor::element_type(input))
+ switch (Tensor::element_type(kernel.input()))
{
#ifndef DIS_FLOAT
case DataType::FLOAT32:
- evalFloat(input, output, options, runtime_graph);
+ evalFloat(kernel.input(), kernel.output(), options, runtime_graph);
break;
#endif // DIS_FLOAT
-#ifndef DIS_QUANT
- case DataType::S8:
- evalQuantized<int8_t>(input, output, options, runtime_graph);
- break;
- case DataType::U8:
- evalQuantized<uint8_t>(input, output, options, runtime_graph);
- break;
-#endif // DIS_QUANT
default:
assert(false && "Unsupported type.");
}