summaryrefslogtreecommitdiff
path: root/compute/cker/include/cker/operation/BinaryArithmeticOps.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute/cker/include/cker/operation/BinaryArithmeticOps.h')
-rw-r--r--compute/cker/include/cker/operation/BinaryArithmeticOps.h128
1 files changed, 73 insertions, 55 deletions
diff --git a/compute/cker/include/cker/operation/BinaryArithmeticOps.h b/compute/cker/include/cker/operation/BinaryArithmeticOps.h
index 60dd02651..27b3fa49a 100644
--- a/compute/cker/include/cker/operation/BinaryArithmeticOps.h
+++ b/compute/cker/include/cker/operation/BinaryArithmeticOps.h
@@ -19,6 +19,8 @@
#define __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__
#include <functional>
+#include "cker/operation/optimized/BinaryArithmeticOps.h"
+#include "cker/operation/reference/BinaryArithmeticOps.h"
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"
@@ -28,69 +30,82 @@ namespace nnfw
namespace cker
{
-struct BinaryArithmeticOpParam
+namespace
{
- // Shape dependent / common to data / op types.
- // BroadcastableOpCategory broadcast_category;
- // uint8 inference params.
- int32_t input1_offset;
- int32_t input2_offset;
- int32_t output_offset;
- int32_t output_multiplier;
- int32_t output_shift;
- // Add / Sub, not Mul, uint8 inference params.
- int32_t left_shift;
- int32_t input1_multiplier;
- int32_t input1_shift;
- int32_t input2_multiplier;
- int32_t input2_shift;
- // uint8, etc, activation params.
- int32_t quantized_activation_min;
- int32_t quantized_activation_max;
- // float activation params.
- float float_activation_min;
- float float_activation_max;
-
- // Processed output dimensions.
- // Let input "a" be the one that broadcasts in the faster-changing dimension.
- // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
- // {b0, b1, b2, b3, b4},
- // broadcast_shape[4] = b0 = a0.
- // broadcast_shape[3] = b1; a1 = 1.
- // broadcast_shape[2] = b2 = a2.
- // broadcast_shape[1] = a3; b3 = 1.
- // broadcast_shape[0] = b4 = a4.
- // int broadcast_shape[5];
-};
+template <typename T>
+const std::function<T(const T &, const T &)> GetBinaryArtithmeticFn(BinaryArithmeticOpType type)
+{
+ switch (type)
+ {
+ case BinaryArithmeticOpType::ADD:
+ {
+ return [](const T &a, const T &b) -> T { return a + b; };
+ }
+ case BinaryArithmeticOpType::MUL:
+ {
+ return [](const T &a, const T &b) -> T { return a * b; };
+ }
+ case BinaryArithmeticOpType::SUB:
+ {
+ return [](const T &a, const T &b) -> T { return a - b; };
+ }
+ case BinaryArithmeticOpType::DIV:
+ {
+ return [](const T &a, const T &b) -> T {
+ if (b == 0)
+ {
+ throw std::runtime_error("Divide by zero");
+ }
+ return a / b;
+ };
+ }
+ default:
+ {
+ assert(false);
+ return nullptr;
+ }
+ }
+}
+} // namespace
template <typename T>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
const T *input1_data, const Shape &input2_shape,
- const T *input2_data, const Shape &output_shape, T *output_data,
- const std::function<T(const T &, const T &)> &fn)
+ const T *input2_data, const Shape &output_shape, T *output_data)
{
- const int32_t flat_size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
- for (int i = 0; i < flat_size; ++i)
- {
- output_data[i] = ActivationFunctionWithMinMax(fn(input1_data[i], input2_data[i]),
- params.quantized_activation_min,
- params.quantized_activation_max);
- }
+ reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
+ output_shape, output_data, GetBinaryArtithmeticFn<T>(params.type));
}
template <>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
const float *input1_data, const Shape &input2_shape,
const float *input2_data, const Shape &output_shape,
- float *output_data,
- const std::function<float(const float &, const float &)> &fn)
+ float *output_data)
{
- const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
- for (int i = 0; i < size; i++)
+ // Supported type is only float now
+ switch (params.type)
{
- output_data[i] =
- ActivationFunctionWithMinMax(fn(input1_data[i], input2_data[i]),
- params.float_activation_min, params.float_activation_max);
+ case nnfw::cker::BinaryArithmeticOpType::ADD:
+ optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
+ output_data);
+ break;
+ case nnfw::cker::BinaryArithmeticOpType::MUL:
+ optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
+ output_data);
+ break;
+ case nnfw::cker::BinaryArithmeticOpType::SUB:
+ optimized::Sub(params, input1_shape, input1_data, input2_shape, input2_data, output_shape,
+ output_data);
+ break;
+ case nnfw::cker::BinaryArithmeticOpType::DIV:
+ reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
+ output_shape, output_data,
+ GetBinaryArtithmeticFn<float>(params.type));
+ break;
+ default:
+ assert(false);
+ break;
}
}
@@ -98,14 +113,15 @@ template <typename T>
inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam &params,
const Shape &input1_shape, const T *input1_data,
const Shape &input2_shape, const T *input2_data,
- const Shape &output_shape, T *output_data,
- const std::function<T(const T &, const T &)> &fn)
+ const Shape &output_shape, T *output_data)
{
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape);
+ const auto fn = GetBinaryArtithmeticFn<T>(params.type);
+
// Comment from tensorflow lite:
//
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -138,16 +154,18 @@ inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam &param
}
template <>
-inline void BroadcastBinaryArithmeticOpSlow(
- const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data,
- const Shape &input2_shape, const float *input2_data, const Shape &output_shape,
- float *output_data, const std::function<float(const float &, const float &)> &fn)
+inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam &params,
+ const Shape &input1_shape, const float *input1_data,
+ const Shape &input2_shape, const float *input2_data,
+ const Shape &output_shape, float *output_data)
{
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape);
+ const auto fn = GetBinaryArtithmeticFn<float>(params.type);
+
for (int b = 0; b < extended_output_shape.Dims(0); ++b)
{
for (int y = 0; y < extended_output_shape.Dims(1); ++y)