summaryrefslogtreecommitdiff
path: root/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h
diff options
context:
space:
mode:
Diffstat (limited to 'compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h')
-rw-r--r--compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h47
1 files changed, 20 insertions, 27 deletions
diff --git a/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h b/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h
index f7e39248c..96e1d9127 100644
--- a/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h
+++ b/compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h
@@ -56,28 +56,22 @@ inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shap
const int size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < size; i++)
{
- output_data[i] =
- ActivationFunctionWithMinMax(fn(input1_data[i], input2_data[i]),
- params.float_activation_min, params.float_activation_max);
+ output_data[i] = ActivationFunctionWithMinMax(
+ fn(input1_data[i], input2_data[i]), params.float_activation_min, params.float_activation_max);
}
}
template <typename T>
-inline void BroadcastBinaryArithmeticOpSlowQuant8(
- const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data,
- const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data,
- const std::function<T(const BinaryArithmeticOpParam &params, const T &, const T &)> &fn)
+inline typename std::enable_if_t<is_quant8<T>::value> BroadcastBinaryArithmeticOpSlow(
+ const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data,
+ const Shape &input2_shape, const T *input2_data, const Shape &output_shape, T *output_data,
+ const std::function<T(const BinaryArithmeticOpParam &params, const T &, const T &)> &fn)
{
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape);
- if ((params.quantized_activation_min < 0) && (params.quantized_activation_max > 255))
- {
- throw std::runtime_error{"Support only for Quant8."};
- }
-
// Comment from tensorflow lite:
//
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -99,11 +93,10 @@ inline void BroadcastBinaryArithmeticOpSlowQuant8(
{
for (int c = 0; c < extended_output_shape.Dims(3); ++c)
{
- output_data[Offset(extended_output_shape, b, y, x, c)] =
- ActivationFunctionWithMinMax<uint8_t>(
- fn(params, input1_data[SubscriptToIndex(desc1, b, y, x, c)],
- input2_data[SubscriptToIndex(desc2, b, y, x, c)]),
- params.quantized_activation_min, params.quantized_activation_max);
+ output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax<T>(
+ fn(params, input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)]),
+ params.quantized_activation_min, params.quantized_activation_max);
}
}
}
@@ -143,9 +136,9 @@ inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam &param
for (int c = 0; c < extended_output_shape.Dims(3); ++c)
{
output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax<T>(
- fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
- input2_data[SubscriptToIndex(desc2, b, y, x, c)]),
- params.quantized_activation_min, params.quantized_activation_max);
+ fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)]),
+ params.quantized_activation_min, params.quantized_activation_max);
}
}
}
@@ -154,9 +147,9 @@ inline void BroadcastBinaryArithmeticOpSlow(const BinaryArithmeticOpParam &param
template <>
inline void BroadcastBinaryArithmeticOpSlow(
- const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data,
- const Shape &input2_shape, const float *input2_data, const Shape &output_shape,
- float *output_data, const std::function<float(const float &, const float &)> &fn)
+ const BinaryArithmeticOpParam &params, const Shape &input1_shape, const float *input1_data,
+ const Shape &input2_shape, const float *input2_data, const Shape &output_shape,
+ float *output_data, const std::function<float(const float &, const float &)> &fn)
{
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
@@ -171,10 +164,10 @@ inline void BroadcastBinaryArithmeticOpSlow(
{
for (int c = 0; c < extended_output_shape.Dims(3); ++c)
{
- output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax(
- fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
- input2_data[SubscriptToIndex(desc2, b, y, x, c)]),
- params.float_activation_min, params.float_activation_max);
+ output_data[Offset(extended_output_shape, b, y, x, c)] =
+ ActivationFunctionWithMinMax(fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)]),
+ params.float_activation_min, params.float_activation_max);
}
}
}