diff options
Diffstat (limited to 'onert-micro/luci-interpreter/src/kernels/NotEqual.cpp')
-rw-r--r-- | onert-micro/luci-interpreter/src/kernels/NotEqual.cpp | 147 |
1 files changed, 47 insertions, 100 deletions
diff --git a/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp b/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp index 304939ee8..92f646f95 100644 --- a/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp +++ b/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,128 +14,76 @@ * limitations under the License. */ -#include "kernels/NotEqual.h" +#include "Builders.h" #include "kernels/Utils.h" +#include "TISOKernel.h" -#include <tensorflow/lite/kernels/internal/reference/comparisons.h> +#include "PALComparisons.h" namespace luci_interpreter { -namespace kernels +namespace { +// TODO: reduce code duplication with less +template <typename T> +void evalGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output, + BaseRuntimeGraph *runtime_graph) +{ + auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x)); + if (x_data == nullptr) + x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x)); + + assert(x_data != nullptr); + + auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y)); + if (y_data == nullptr) + y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y)); + + assert(y_data != nullptr); + + auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output)); -NotEqual::NotEqual(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {} + luci_interpreter_pal::ComparisonParams op_params; + op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y); -void NotEqual::configure() + const int64_t flat_size = kernels::getTensorShape(x).flatSize(); + luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data, + luci_interpreter_pal::NotEqualFn); +} + +} // namespace + +void configure_kernel_CircleNotEqual(const circle::Operator *cur_op, + BaseRuntimeGraph *runtime_graph) { - LUCI_INTERPRETER_CHECK(x()->element_type() == y()->element_type()); - LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::BOOL); + kernels::TISOKernel kernel(cur_op, runtime_graph); - if (x()->element_type() == DataType::U8) - { - quantizeMultiplierSmallerThanOneExp(x()->scale(), &_x_multiplier, &_x_shift); - quantizeMultiplierSmallerThanOneExp(y()->scale(), &_y_multiplier, &_y_shift); - } - // TODO: enable it only if kernel with dynamic shapes - output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape())); + LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) == + Tensor::element_type(kernel.input2())); + LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::BOOL); } -void NotEqual::execute() const +void execute_kernel_CircleNotEqual(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) { - switch (x()->element_type()) + kernels::TISOKernel kernel(cur_op, runtime_graph); + + switch (Tensor::element_type(kernel.input1())) { - case DataType::FLOAT32: - evalFloat(); - break; case DataType::S64: - evalInteger<int64_t>(); + evalGeneric<int64_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph); break; case DataType::S32: - evalInteger<int32_t>(); + evalGeneric<int32_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph); break; - case DataType::U8: - evalQuantized(); +#ifndef DIS_FLOAT + case DataType::FLOAT32: + evalGeneric<float>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph); break; +#endif // DIS_FLOAT default: assert(false && "Unsupported type."); } } -void NotEqual::evalFloat() const -{ - const auto x_data = getTensorData<float>(x()); - const auto y_data = getTensorData<float>(y()); - auto output_data = getTensorData<bool>(output()); - - tflite::ComparisonParams op_params; - op_params.is_broadcast = x()->shape() != y()->shape(); - - if (op_params.is_broadcast) - { - tflite::reference_ops::Broadcast4DSlowNotEqual(op_params, getTensorShape(x()), x_data, - getTensorShape(y()), y_data, - getTensorShape(output()), output_data); - } - else - { - tflite::reference_ops::NotEqual(op_params, getTensorShape(x()), x_data, getTensorShape(y()), - y_data, getTensorShape(output()), output_data); - } -} - -template <typename T> void NotEqual::evalInteger() const -{ - const auto x_data = getTensorData<T>(x()); - const auto y_data = getTensorData<T>(y()); - auto output_data = getTensorData<bool>(output()); - - tflite::ComparisonParams op_params; - op_params.is_broadcast = x()->shape() != y()->shape(); - - if (op_params.is_broadcast) - { - tflite::reference_ops::Broadcast4DSlowNotEqualNoScaling(op_params, getTensorShape(x()), x_data, - getTensorShape(y()), y_data, - getTensorShape(output()), output_data); - } - else - { - tflite::reference_ops::NotEqualNoScaling(op_params, getTensorShape(x()), x_data, - getTensorShape(y()), y_data, getTensorShape(output()), - output_data); - } -} - -void NotEqual::evalQuantized() const -{ - const auto x_data = getTensorData<uint8_t>(x()); - const auto y_data = getTensorData<uint8_t>(y()); - auto output_data = getTensorData<bool>(output()); - - tflite::ComparisonParams op_params; - op_params.left_shift = 8; - op_params.input1_offset = -x()->zero_point(); // Note the '-' - op_params.input1_shift = _x_shift; - op_params.input1_multiplier = _x_multiplier; - op_params.input2_offset = -y()->zero_point(); // Note the '-' - op_params.input2_shift = _y_shift; - op_params.input2_multiplier = _y_multiplier; - op_params.is_broadcast = x()->shape() != y()->shape(); - - if (op_params.is_broadcast) - { - tflite::reference_ops::Broadcast4DSlowNotEqualWithScaling( - op_params, getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()), - output_data); - } - else - { - tflite::reference_ops::NotEqualWithScaling(op_params, getTensorShape(x()), x_data, - getTensorShape(y()), y_data, - getTensorShape(output()), output_data); - } -} - -} // namespace kernels } // namespace luci_interpreter |