diff options
Diffstat (limited to 'compiler/luci-interpreter/src/kernels/Add.cpp')
-rw-r--r-- | compiler/luci-interpreter/src/kernels/Add.cpp | 38 |
1 files changed, 32 insertions, 6 deletions
diff --git a/compiler/luci-interpreter/src/kernels/Add.cpp b/compiler/luci-interpreter/src/kernels/Add.cpp index 7381c3849..d7bf3084f 100644 --- a/compiler/luci-interpreter/src/kernels/Add.cpp +++ b/compiler/luci-interpreter/src/kernels/Add.cpp @@ -38,8 +38,11 @@ Add::Add(const Tensor *input1, const Tensor *input2, Tensor *output, const AddPa void Add::configure() { LUCI_INTERPRETER_CHECK(input1()->element_type() == input2()->element_type()); + LUCI_INTERPRETER_CHECK(input1()->element_type() == output()->element_type()); if (input1()->element_type() == DataType::S16) { + LUCI_INTERPRETER_CHECK(input1()->zero_points().size() == 1 && + input2()->zero_points().size() == 1); LUCI_INTERPRETER_CHECK(input1()->zero_point() == 0 && input2()->zero_point() == 0 && output()->zero_point() == 0); } @@ -54,6 +57,12 @@ void Add::execute() const case DataType::FLOAT32: evalFloat(); break; + case DataType::S64: + evalInteger<int64_t>(); + break; + case DataType::S32: + evalInteger<int32_t>(); + break; case DataType::U8: evalQuantized(); break; @@ -67,13 +76,8 @@ void Add::execute() const void Add::evalFloat() const { - float activation_min{}; - float activation_max{}; - calculateActivationRange(_params.activation, &activation_min, &activation_max); - tflite::ArithmeticParams params{}; - params.float_activation_min = activation_min; - params.float_activation_max = activation_max; + fillArithmeticActivationRange<float>(params, _params.activation); const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes( getTensorShape(input1()), getTensorShape(input2()), &params); @@ -92,6 +96,28 @@ void Add::evalFloat() const } } +template <typename T> void Add::evalInteger() const +{ + tflite::ArithmeticParams params{}; + fillArithmeticActivationRange<T>(params, _params.activation); + + const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes( + getTensorShape(input1()), getTensorShape(input2()), &params); + + if (need_broadcast) + { + 
tflite::reference_ops::BroadcastAdd4DSlow( + params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()), + getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output())); + } + else + { + tflite::reference_ops::Add(params, getTensorShape(input1()), getTensorData<T>(input1()), + getTensorShape(input2()), getTensorData<T>(input2()), + getTensorShape(output()), getTensorData<T>(output())); + } +} + void Add::evalQuantized() const { const auto input1_scale = static_cast<double>(input1()->scale()); |