diff options
Diffstat (limited to 'compute/cker/include/cker/operation/BinaryArithmeticOps.h')
-rw-r--r-- | compute/cker/include/cker/operation/BinaryArithmeticOps.h | 128 |
1 files changed, 73 insertions, 55 deletions
diff --git a/compute/cker/include/cker/operation/BinaryArithmeticOps.h b/compute/cker/include/cker/operation/BinaryArithmeticOps.h index 60dd02651..27b3fa49a 100644 --- a/compute/cker/include/cker/operation/BinaryArithmeticOps.h +++ b/compute/cker/include/cker/operation/BinaryArithmeticOps.h @@ -19,6 +19,8 @@ #define __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__ #include <functional> +#include "cker/operation/optimized/BinaryArithmeticOps.h" +#include "cker/operation/reference/BinaryArithmeticOps.h" #include "cker/Shape.h" #include "cker/Types.h" #include "cker/Utils.h" @@ -28,69 +30,82 @@ namespace nnfw namespace cker { -struct BinaryArithmeticOpParam +namespace { - // Shape dependent / common to data / op types. - // BroadcastableOpCategory broadcast_category; - // uint8 inference params. - int32_t input1_offset; - int32_t input2_offset; - int32_t output_offset; - int32_t output_multiplier; - int32_t output_shift; - // Add / Sub, not Mul, uint8 inference params. - int32_t left_shift; - int32_t input1_multiplier; - int32_t input1_shift; - int32_t input2_multiplier; - int32_t input2_shift; - // uint8, etc, activation params. - int32_t quantized_activation_min; - int32_t quantized_activation_max; - // float activation params. - float float_activation_min; - float float_activation_max; - - // Processed output dimensions. - // Let input "a" be the one that broadcasts in the faster-changing dimension. - // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and - // {b0, b1, b2, b3, b4}, - // broadcast_shape[4] = b0 = a0. - // broadcast_shape[3] = b1; a1 = 1. - // broadcast_shape[2] = b2 = a2. - // broadcast_shape[1] = a3; b3 = 1. - // broadcast_shape[0] = b4 = a4. - // int broadcast_shape[5]; -}; +template <typename T> +const std::function<T(const T &, const T &)> GetBinaryArtithmeticFn(BinaryArithmeticOpType type) +{ + switch (type) + { + case BinaryArithmeticOpType::ADD: + { + return [](const T &a, const T &b) -> T { return a + b; }; + } + case BinaryArithmeticOpType::MUL: + { + return [](const T &a, const T &b) -> T { return a * b; }; + } + case BinaryArithmeticOpType::SUB: + { + return [](const T &a, const T &b) -> T { return a - b; }; + } + case BinaryArithmeticOpType::DIV: + { + return [](const T &a, const T &b) -> T { + if (b == 0) + { + throw std::runtime_error("Divide by zero"); + } + return a / b; + }; + } + default: + { + assert(false); + return nullptr; + } + } +} +} // namespace template <typename T> inline void BinaryArithmeticOp(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, - const T *input2_data, const Shape &output_shape, T *output_data, - const std::function<T(const T &, const T &)> &fn) + const T *input2_data, const Shape &output_shape, T *output_data) { - const int32_t flat_size = MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int i = 0; i < flat_size; ++i) - { - output_data[i] = ActivationFunctionWithMinMax(fn(input1_data[i], input2_data[i]), - params.quantized_activation_min, - params.quantized_activation_max); - } + reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data, + output_shape, output_data, GetBinaryArtithmeticFn<T>(params.type)); } template <> inline void BinaryArithmeticOp(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, - float *output_data, - const std::function<float(const float &, const float &)> &fn) + float *output_data) { - const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int i = 0; i < size; i++) + // Supported type is only float now + switch (params.type) { - output_data[i] = - ActivationFunctionWithMinMax(fn(input1_data[i], input2_data[i]), - params.float_activation_min, params.float_activation_max); + case nnfw::cker::BinaryArithmeticOpType::ADD: + optimized::Add(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, + output_data); + break; + case nnfw::cker::BinaryArithmeticOpType::MUL: + optimized::Mul(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, + output_data); + break; + case nnfw::cker::BinaryArithmeticOpType::SUB: + optimized::Sub(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, + output_data); + break; + case nnfw::cker::BinaryArithmeticOpType::DIV: + reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data, + output_shape, output_data, + GetBinaryArtithmeticFn<float>(params.type)); + break; + default: + assert(false); + break; } } @@ -98,14 +113,15 @@ template <typename T> inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, - const Shape &output_shape, T *output_data, - const std::function<T(const T &, const T &)> &fn) + const Shape &output_shape, T *output_data) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2); const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape); + const auto fn = GetBinaryArtithmeticFn<T>(params.type); + // Comment from tensorflow lite: // // In Tensorflow, the dimensions are canonically named (batch_number, row, @@ -138,16 +154,18 @@ inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam ¶m } template <> -inline void BroadcastBinaryArithmeticOpSlow( - const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, - const Shape &input2_shape, const float *input2_data, const Shape &output_shape, - float *output_data, const std::function<float(const float &, const float &)> &fn) +inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam ¶ms, + const Shape &input1_shape, const float *input1_data, + const Shape &input2_shape, const float *input2_data, + const Shape &output_shape, float *output_data) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2); const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape); + const auto fn = GetBinaryArtithmeticFn<float>(params.type); + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { for (int y = 0; y < extended_output_shape.Dims(1); ++y) |