summary | refs | log | tree | commit | diff
path: root/compiler/luci-interpreter/src/kernels/Add.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci-interpreter/src/kernels/Add.cpp')
-rw-r--r-- compiler/luci-interpreter/src/kernels/Add.cpp | 38
1 files changed, 32 insertions, 6 deletions
diff --git a/compiler/luci-interpreter/src/kernels/Add.cpp b/compiler/luci-interpreter/src/kernels/Add.cpp
index 7381c3849..d7bf3084f 100644
--- a/compiler/luci-interpreter/src/kernels/Add.cpp
+++ b/compiler/luci-interpreter/src/kernels/Add.cpp
@@ -38,8 +38,11 @@ Add::Add(const Tensor *input1, const Tensor *input2, Tensor *output, const AddPa
void Add::configure()
{
LUCI_INTERPRETER_CHECK(input1()->element_type() == input2()->element_type());
+ LUCI_INTERPRETER_CHECK(input1()->element_type() == output()->element_type());
if (input1()->element_type() == DataType::S16)
{
+ LUCI_INTERPRETER_CHECK(input1()->zero_points().size() == 1 &&
+ input2()->zero_points().size() == 1);
LUCI_INTERPRETER_CHECK(input1()->zero_point() == 0 && input2()->zero_point() == 0 &&
output()->zero_point() == 0);
}
@@ -54,6 +57,12 @@ void Add::execute() const
case DataType::FLOAT32:
evalFloat();
break;
+ case DataType::S64:
+ evalInteger<int64_t>();
+ break;
+ case DataType::S32:
+ evalInteger<int32_t>();
+ break;
case DataType::U8:
evalQuantized();
break;
@@ -67,13 +76,8 @@ void Add::execute() const
void Add::evalFloat() const
{
- float activation_min{};
- float activation_max{};
- calculateActivationRange(_params.activation, &activation_min, &activation_max);
-
tflite::ArithmeticParams params{};
- params.float_activation_min = activation_min;
- params.float_activation_max = activation_max;
+ fillArithmeticActivationRange<float>(params, _params.activation);
const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
getTensorShape(input1()), getTensorShape(input2()), &params);
@@ -92,6 +96,28 @@ void Add::evalFloat() const
}
}
+template <typename T> void Add::evalInteger() const // Adds input1() and input2() into output() with T as the element type; execute() dispatches S64 as int64_t and S32 as int32_t.
+{
+ tflite::ArithmeticParams params{}; // value-initialized parameter block handed to the tflite calls below
+ fillArithmeticActivationRange<T>(params, _params.activation); // NOTE(review): presumably stores the fused-activation clamp bounds for T in params; helper is not visible here, confirm
+
+ const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes( // result picks the branch below; &params is passed, so the helper may also write to params
+ getTensorShape(input1()), getTensorShape(input2()), &params);
+
+ if (need_broadcast)
+ {
+ tflite::reference_ops::BroadcastAdd4DSlow( // broadcast path
+ params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
+ getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
+ }
+ else
+ {
+ tflite::reference_ops::Add(params, getTensorShape(input1()), getTensorData<T>(input1()), // non-broadcast path; same argument order as the broadcast call
+ getTensorShape(input2()), getTensorData<T>(input2()),
+ getTensorShape(output()), getTensorData<T>(output()));
+ }
+}
+
void Add::evalQuantized() const
{
const auto input1_scale = static_cast<double>(input1()->scale());