summaryrefslogtreecommitdiff
path: root/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'onert-micro/luci-interpreter/src/kernels/NotEqual.cpp')
-rw-r--r--onert-micro/luci-interpreter/src/kernels/NotEqual.cpp147
1 files changed, 47 insertions, 100 deletions
diff --git a/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp b/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp
index 304939ee8..92f646f95 100644
--- a/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp
+++ b/onert-micro/luci-interpreter/src/kernels/NotEqual.cpp
@@ -1,6 +1,5 @@
/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,128 +14,76 @@
* limitations under the License.
*/
-#include "kernels/NotEqual.h"
+#include "Builders.h"
#include "kernels/Utils.h"
+#include "TISOKernel.h"
-#include <tensorflow/lite/kernels/internal/reference/comparisons.h>
+#include "PALComparisons.h"
namespace luci_interpreter
{
-namespace kernels
+namespace
{
+// TODO: reduce code duplication with less
+template <typename T>
+void evalGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
+ BaseRuntimeGraph *runtime_graph)
+{
+ auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x));
+ if (x_data == nullptr)
+ x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x));
+
+ assert(x_data != nullptr);
+
+ auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y));
+ if (y_data == nullptr)
+ y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y));
+
+ assert(y_data != nullptr);
+
+ auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));
-NotEqual::NotEqual(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {}
+ luci_interpreter_pal::ComparisonParams op_params;
+ op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);
-void NotEqual::configure()
+ const int64_t flat_size = kernels::getTensorShape(x).flatSize();
+ luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data,
+ luci_interpreter_pal::NotEqualFn);
+}
+
+} // namespace
+
+void configure_kernel_CircleNotEqual(const circle::Operator *cur_op,
+ BaseRuntimeGraph *runtime_graph)
{
- LUCI_INTERPRETER_CHECK(x()->element_type() == y()->element_type());
- LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::BOOL);
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
- if (x()->element_type() == DataType::U8)
- {
- quantizeMultiplierSmallerThanOneExp(x()->scale(), &_x_multiplier, &_x_shift);
- quantizeMultiplierSmallerThanOneExp(y()->scale(), &_y_multiplier, &_y_shift);
- }
- // TODO: enable it only if kernel with dynamic shapes
- output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+ Tensor::element_type(kernel.input2()));
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::BOOL);
}
-void NotEqual::execute() const
+void execute_kernel_CircleNotEqual(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- switch (x()->element_type())
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
+
+ switch (Tensor::element_type(kernel.input1()))
{
- case DataType::FLOAT32:
- evalFloat();
- break;
case DataType::S64:
- evalInteger<int64_t>();
+ evalGeneric<int64_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
break;
case DataType::S32:
- evalInteger<int32_t>();
+ evalGeneric<int32_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
break;
- case DataType::U8:
- evalQuantized();
+#ifndef DIS_FLOAT
+ case DataType::FLOAT32:
+ evalGeneric<float>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
break;
+#endif // DIS_FLOAT
default:
assert(false && "Unsupported type.");
}
}
-void NotEqual::evalFloat() const
-{
- const auto x_data = getTensorData<float>(x());
- const auto y_data = getTensorData<float>(y());
- auto output_data = getTensorData<bool>(output());
-
- tflite::ComparisonParams op_params;
- op_params.is_broadcast = x()->shape() != y()->shape();
-
- if (op_params.is_broadcast)
- {
- tflite::reference_ops::Broadcast4DSlowNotEqual(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data,
- getTensorShape(output()), output_data);
- }
- else
- {
- tflite::reference_ops::NotEqual(op_params, getTensorShape(x()), x_data, getTensorShape(y()),
- y_data, getTensorShape(output()), output_data);
- }
-}
-
-template <typename T> void NotEqual::evalInteger() const
-{
- const auto x_data = getTensorData<T>(x());
- const auto y_data = getTensorData<T>(y());
- auto output_data = getTensorData<bool>(output());
-
- tflite::ComparisonParams op_params;
- op_params.is_broadcast = x()->shape() != y()->shape();
-
- if (op_params.is_broadcast)
- {
- tflite::reference_ops::Broadcast4DSlowNotEqualNoScaling(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data,
- getTensorShape(output()), output_data);
- }
- else
- {
- tflite::reference_ops::NotEqualNoScaling(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data, getTensorShape(output()),
- output_data);
- }
-}
-
-void NotEqual::evalQuantized() const
-{
- const auto x_data = getTensorData<uint8_t>(x());
- const auto y_data = getTensorData<uint8_t>(y());
- auto output_data = getTensorData<bool>(output());
-
- tflite::ComparisonParams op_params;
- op_params.left_shift = 8;
- op_params.input1_offset = -x()->zero_point(); // Note the '-'
- op_params.input1_shift = _x_shift;
- op_params.input1_multiplier = _x_multiplier;
- op_params.input2_offset = -y()->zero_point(); // Note the '-'
- op_params.input2_shift = _y_shift;
- op_params.input2_multiplier = _y_multiplier;
- op_params.is_broadcast = x()->shape() != y()->shape();
-
- if (op_params.is_broadcast)
- {
- tflite::reference_ops::Broadcast4DSlowNotEqualWithScaling(
- op_params, getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
- output_data);
- }
- else
- {
- tflite::reference_ops::NotEqualWithScaling(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data,
- getTensorShape(output()), output_data);
- }
-}
-
-} // namespace kernels
} // namespace luci_interpreter