diff options
Diffstat (limited to 'compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h')
-rw-r--r-- | compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h | 47 |
1 files changed, 20 insertions, 27 deletions
diff --git a/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h b/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h index f7e39248c..96e1d9127 100644 --- a/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h +++ b/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h @@ -56,28 +56,22 @@ inline void BinaryArithmeticOp(const BinaryArithmeticOpParam ¶ms, const Shap const int size = MatchingElementsSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < size; i++) { - output_data[i] = - ActivationFunctionWithMinMax(fn(input1_data[i], input2_data[i]), - params.float_activation_min, params.float_activation_max); + output_data[i] = ActivationFunctionWithMinMax( + fn(input1_data[i], input2_data[i]), params.float_activation_min, params.float_activation_max); } } template <typename T> -inline void BroadcastBinaryArithmeticOpSlowQuant8( - const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, - const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, - const std::function<T(const BinaryArithmeticOpParam ¶ms, const T &, const T &)> &fn) +inline typename std::enable_if_t<is_quant8<T>::value> BroadcastBinaryArithmeticOpSlow( + const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const T *input1_data, + const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data, + const std::function<T(const BinaryArithmeticOpParam ¶ms, const T &, const T &)> &fn) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2); const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape); - if ((params.quantized_activation_min < 0) && (params.quantized_activation_max > 255)) - { - throw std::runtime_error{"Support only for Quant8."}; - } - // Comment from tensorflow lite: // // In Tensorflow, the dimensions are canonically named (batch_number, row, @@ -99,11 +93,10 @@ inline void BroadcastBinaryArithmeticOpSlowQuant8( { for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - output_data[Offset(extended_output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax<uint8_t>( - fn(params, input1_data[SubscriptToIndex(desc1, b, y, x, c)], - input2_data[SubscriptToIndex(desc2, b, y, x, c)]), - params.quantized_activation_min, params.quantized_activation_max); + output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax<T>( + fn(params, input1_data[SubscriptToIndex(desc1, b, y, x, c)], + input2_data[SubscriptToIndex(desc2, b, y, x, c)]), + params.quantized_activation_min, params.quantized_activation_max); } } } @@ -143,9 +136,9 @@ inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam ¶m for (int c = 0; c < extended_output_shape.Dims(3); ++c) { output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax<T>( - fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)], - input2_data[SubscriptToIndex(desc2, b, y, x, c)]), - params.quantized_activation_min, params.quantized_activation_max); + fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)], + input2_data[SubscriptToIndex(desc2, b, y, x, c)]), + params.quantized_activation_min, params.quantized_activation_max); } } } @@ -154,9 +147,9 @@ inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam ¶m template <> inline void BroadcastBinaryArithmeticOpSlow( - const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, - const Shape &input2_shape, const float *input2_data, const Shape &output_shape, - float *output_data, const std::function<float(const float &, const float &)> &fn) + const BinaryArithmeticOpParam ¶ms, const Shape &input1_shape, const float *input1_data, + const Shape &input2_shape, const float *input2_data, const Shape &output_shape, + float *output_data, const std::function<float(const float &, const float &)> &fn) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; @@ -171,10 +164,10 @@ inline void BroadcastBinaryArithmeticOpSlow( { for (int c = 0; c < extended_output_shape.Dims(3); ++c) { - output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)], - input2_data[SubscriptToIndex(desc2, b, y, x, c)]), - params.float_activation_min, params.float_activation_max); + output_data[Offset(extended_output_shape, b, y, x, c)] = + ActivationFunctionWithMinMax(fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)], + input2_data[SubscriptToIndex(desc2, b, y, x, c)]), + params.float_activation_min, params.float_activation_max); } } } |