From 83d66c0551e0ecf01316b88fc672f8155fdef939 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=98=A4=ED=98=95=EC=84=9D/On-Device=20Lab=28SR=29/Staff?= =?UTF-8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 16 Apr 2019 08:44:19 +0900 Subject: Move gemmlowp code in cker (#4998) Move gemmlowp code in cker into cker/gemmlowp Add namespace for gemmlowp Signed-off-by: Hyeongseok Oh --- libs/cker/include/cker/FixedPoint.h | 288 -------------------------- libs/cker/include/cker/Utils.h | 9 +- libs/cker/include/cker/gemmlowp/FixedPoint.h | 291 +++++++++++++++++++++++++++ libs/cker/include/cker/operation/SoftMax.h | 16 +- 4 files changed, 304 insertions(+), 300 deletions(-) delete mode 100644 libs/cker/include/cker/FixedPoint.h create mode 100644 libs/cker/include/cker/gemmlowp/FixedPoint.h diff --git a/libs/cker/include/cker/FixedPoint.h b/libs/cker/include/cker/FixedPoint.h deleted file mode 100644 index 653a56d04..000000000 --- a/libs/cker/include/cker/FixedPoint.h +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * Copyright 2015 The Gemmlowp Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_CKER_FXIED_POINT_H__ -#define __NNFW_CKER_FXIED_POINT_H__ - -#include -#include - -namespace nnfw -{ -namespace cker -{ - -inline int32_t RoundingHalfSum(int32_t a, int32_t b) -{ - int64_t a64 = a; - int64_t b64 = b; - int64_t sum = a64 + b64; - int64_t sign = sum >= 0 ? 1 : -1; - return static_cast((sum + sign) / 2); -} - -inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) -{ - bool overflow = a == b && a == std::numeric_limits::min(); - int64_t a_64(a); - int64_t b_64(b); - int64_t ab_64 = a_64 * b_64; - int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); - int32_t ab_x2_high32 = static_cast((ab_64 + nudge) / (1ll << 31)); - return overflow ? std::numeric_limits::max() : ab_x2_high32; -} - -// Correctly-rounded-to-nearest division by a power-of-two. -// Also known as a rounding arithmetic right shift. -inline int32_t RoundingDivideByPOT(int32_t x, int exponent) -{ - assert(exponent >= 0); - assert(exponent <= 31); - const int32_t mask = ((1ll << exponent) - 1); - const int32_t zero = 0; - const int32_t one = 1; - const int32_t remainder = x & mask; - const int32_t threshold = (mask >> 1) + ((x < zero) ? one : zero); - return ((x >> exponent) + ((remainder > threshold) ? one : zero)); -} - -// Returns the product of a run-time integer value by a compile-time power -// of two, with either a positive exponent (equivalent to an arithmetic -// left shift, saturating) or a negative exponent (equivalent to an arithmetic -// right shift, rounding to nearest). -template 0 ? 1 : Exponent < 0 ? -1 : 0)> -struct ImplSaturatingRoundingMultiplyByPOT -{ -}; - -template struct ImplSaturatingRoundingMultiplyByPOT -{ - static int32_t eval(int32_t x) { return x; } -}; - -template struct ImplSaturatingRoundingMultiplyByPOT -{ - static int32_t eval(int32_t x) - { - const int32_t min = (std::numeric_limits::min()); - const int32_t max = (std::numeric_limits::max()); - const int32_t threshold = ((1 << (31 - Exponent)) - 1); - const int32_t zero = 0; - const int32_t one = 1; - - const int32_t positive_mask = ((x > threshold) ? ~zero : zero); - const int32_t negative_mask = ((x < -threshold) ? ~zero : zero); - - int32_t result = (x * (one << Exponent)); - result = (positive_mask ? max : result); - result = (negative_mask ? min : result); - return result; - } -}; - -template struct ImplSaturatingRoundingMultiplyByPOT -{ - static int32_t eval(int32_t x) { return RoundingDivideByPOT(x, -Exponent); } -}; - -template int32_t SaturatingRoundingMultiplyByPOT(int32_t x) -{ - return ImplSaturatingRoundingMultiplyByPOT::eval(x); -} - -template class FixedPoint -{ -public: - static constexpr int kTotalBits = 8 * sizeof(int32_t); - static constexpr int kIntegerBits = tIntegerBits; - static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; - static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); - - static const int32_t ScalarRawMax() { return std::numeric_limits::max(); } - - static FixedPoint FromRaw(int32_t x) - { - FixedPoint retval; - retval.raw() = x; - return retval; - } - - static FixedPoint FromScalarRaw(int32_t x) { return FromRaw(x); } - - template static FixedPoint ConstantPOT() - { - static constexpr int kOffset = kFractionalBits + Exponent; - static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format"); - return FromScalarRaw((int32_t)1 << kOffset); - } - - static FixedPoint Zero() { return FromScalarRaw(0); } - - static FixedPoint One() - { - return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax() : ((int32_t)1 << kFractionalBits)); - } - - int32_t raw() const { return i_; } - int32_t &raw() { return i_; } - -private: - int32_t i_; -}; - -// A FixedPoint multiplication is just a -// SaturatingRoundingDoublingHighMul operation on the underlying -// raw integer values. The IntegerBits simply add up, as is obvious -// from the fact that the range is [-2^IntegerBits, 2^IntegerBits). -template -FixedPoint operator*(FixedPoint a, - FixedPoint b) -{ - FixedPoint c; - c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); - return c; -} - -// Tweaking IntegerBits gives exact multiplication by a power of two. -template -FixedPoint ExactMulByPot(FixedPoint a) -{ - FixedPoint c; - c.raw() = a.raw(); - return c; -} - -template -FixedPoint operator+(FixedPoint a, FixedPoint b) -{ - return FixedPoint::FromRaw((a.raw() + b.raw())); -} -template -FixedPoint operator-(FixedPoint a, FixedPoint b) -{ - return FixedPoint::FromRaw((a.raw() - b.raw())); -} -template -FixedPoint operator&(FixedPoint a, FixedPoint b) -{ - return FixedPoint::FromRaw((a.raw() & b.raw())); -} - -// Rescale changes the number of IntegerBits and updates the underlying -// raw integer value accordingly. -template -FixedPoint Rescale(FixedPoint x) -{ - static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; - FixedPoint result; - result.raw() = SaturatingRoundingMultiplyByPOT(x.raw()); - return result; -} - -// Implementation of exponential function. - -// Returns exp(x) for x in [-1/4, 0). -inline FixedPoint<0> exp_on_interval_between_negative_one_quarter_and_0_excl(FixedPoint<0> a) -{ - typedef FixedPoint<0> F; - const F constant_term = F::FromScalarRaw(RoundingDivideByPOT(1895147668, 0)); - const F constant_1_over_3 = F::FromScalarRaw(RoundingDivideByPOT(715827883, 0)); - // We're evaluating a Taylor expansion around -1/8, so we do the change of - // variable: x = a + 1/8. - // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. - F x = a + F::template ConstantPOT<-3>(); - F x2 = x * x; - F x3 = x2 * x; - F x4 = x2 * x2; - F x4_over_4 = F::FromScalarRaw(SaturatingRoundingMultiplyByPOT<-2>(x4.raw())); - F x4_over_24_plus_x3_over_6_plus_x2_over_2 = F::FromScalarRaw( - SaturatingRoundingMultiplyByPOT<-1>((((x4_over_4 + x3) * constant_1_over_3) + x2).raw())); - return (constant_term + constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); -} - -// Returns exp(x) for x < 0. -template FixedPoint<0> exp_on_negative_values(FixedPoint a) -{ - typedef FixedPoint InputF; - typedef FixedPoint<0> ResultF; - static constexpr int kFractionalBits = InputF::kFractionalBits; - static constexpr int kIntegerBits = InputF::kIntegerBits; - const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); - InputF mask = kOneQuarter - InputF::FromScalarRaw(1); - InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; - ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( - Rescale<0>(a_mod_quarter_minus_one_quarter)); - int32_t remainder = (a_mod_quarter_minus_one_quarter - a).raw(); - - const int32_t zero = 0; - -#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ - if (kIntegerBits > Exponent) \ - { \ - const ResultF kMultiplier = \ - ResultF::FromScalarRaw(RoundingDivideByPOT(FixedPointMultiplier, 0)); \ - static constexpr int kShiftAmount = \ - ((kIntegerBits > Exponent) ? (kFractionalBits + Exponent) : 0); \ - result = ((remainder & (1 << kShiftAmount)) ? (result * kMultiplier) : result); \ - } - - GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); - GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); - GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); - GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); - GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); - GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); - GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); - -#undef GEMMLOWP_EXP_BARREL_SHIFTER - - static constexpr int clampB = ((kIntegerBits > 5) ? (36 - kIntegerBits) : 0); - if (kIntegerBits > 5) - { - const InputF clamp = InputF::FromScalarRaw(RoundingDivideByPOT(-(1 << clampB), 0)); - result.raw() = ((a.raw() < clamp.raw()) ? ResultF::Zero().raw() : result.raw()); - } - - result.raw() = (a.raw() ? result.raw() : ResultF::One().raw()); - return result; -} - -// Returns 1 / (1 + x) for x in (0, 1). -inline FixedPoint<0> one_over_one_plus_x_for_x_in_0_1(FixedPoint<0> a) -{ - typedef FixedPoint<0> F0; - typedef FixedPoint<2> F2; - F0 half_denominator = F0::FromScalarRaw(RoundingHalfSum(a.raw(), F0::One().raw())); - // Newton-Raphson division - // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division - // Refer to that page for the logic behind the 48/17 and 32/17 constants. - const F2 constant_48_over_17 = F2::FromScalarRaw(RoundingDivideByPOT(1515870810, 0)); - const F2 constant_neg_32_over_17 = F2::FromScalarRaw(RoundingDivideByPOT(-1010580540, 0)); - F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; - for (int i = 0; i < 3; i++) - { - F2 half_denominator_times_x = half_denominator * x; - F2 one_minus_half_denominator_times_x = F2::One() - half_denominator_times_x; - x = x + Rescale<2>(x * one_minus_half_denominator_times_x); - } - return Rescale<0>(ExactMulByPot<-1>(x)); -} - -} // namespace cker -} // namespace nnfw - -#endif // __NNFW_CKER_FXIED_POINT_H__ diff --git a/libs/cker/include/cker/Utils.h b/libs/cker/include/cker/Utils.h index af98fd8b0..673423989 100644 --- a/libs/cker/include/cker/Utils.h +++ b/libs/cker/include/cker/Utils.h @@ -21,7 +21,7 @@ #include #include -#include "cker/FixedPoint.h" +#include "cker/gemmlowp/FixedPoint.h" namespace nnfw { @@ -38,14 +38,15 @@ inline int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multip { int left_shift = shift > 0 ? shift : 0; int right_shift = shift > 0 ? 0 : -shift; - return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift); + return gemmlowp::RoundingDivideByPOT( + gemmlowp::SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), + right_shift); } inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(int32_t x, int32_t quantized_multiplier, int left_shift) { - return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier); + return gemmlowp::SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier); } inline int CountLeadingZeros(uint32_t integer_input) diff --git a/libs/cker/include/cker/gemmlowp/FixedPoint.h b/libs/cker/include/cker/gemmlowp/FixedPoint.h new file mode 100644 index 000000000..da9d25bd4 --- /dev/null +++ b/libs/cker/include/cker/gemmlowp/FixedPoint.h @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2015 The Gemmlowp Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_CKER_FXIED_POINT_H__ +#define __NNFW_CKER_FXIED_POINT_H__ + +#include +#include + +namespace nnfw +{ +namespace cker +{ +namespace gemmlowp +{ + +inline int32_t RoundingHalfSum(int32_t a, int32_t b) +{ + int64_t a64 = a; + int64_t b64 = b; + int64_t sum = a64 + b64; + int64_t sign = sum >= 0 ? 1 : -1; + return static_cast((sum + sign) / 2); +} + +inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) +{ + bool overflow = a == b && a == std::numeric_limits::min(); + int64_t a_64(a); + int64_t b_64(b); + int64_t ab_64 = a_64 * b_64; + int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); + int32_t ab_x2_high32 = static_cast((ab_64 + nudge) / (1ll << 31)); + return overflow ? std::numeric_limits::max() : ab_x2_high32; +} + +// Correctly-rounded-to-nearest division by a power-of-two. +// Also known as a rounding arithmetic right shift. +inline int32_t RoundingDivideByPOT(int32_t x, int exponent) +{ + assert(exponent >= 0); + assert(exponent <= 31); + const int32_t mask = ((1ll << exponent) - 1); + const int32_t zero = 0; + const int32_t one = 1; + const int32_t remainder = x & mask; + const int32_t threshold = (mask >> 1) + ((x < zero) ? one : zero); + return ((x >> exponent) + ((remainder > threshold) ? one : zero)); +} + +// Returns the product of a run-time integer value by a compile-time power +// of two, with either a positive exponent (equivalent to an arithmetic +// left shift, saturating) or a negative exponent (equivalent to an arithmetic +// right shift, rounding to nearest). +template 0 ? 1 : Exponent < 0 ? -1 : 0)> +struct ImplSaturatingRoundingMultiplyByPOT +{ +}; + +template struct ImplSaturatingRoundingMultiplyByPOT +{ + static int32_t eval(int32_t x) { return x; } +}; + +template struct ImplSaturatingRoundingMultiplyByPOT +{ + static int32_t eval(int32_t x) + { + const int32_t min = (std::numeric_limits::min()); + const int32_t max = (std::numeric_limits::max()); + const int32_t threshold = ((1 << (31 - Exponent)) - 1); + const int32_t zero = 0; + const int32_t one = 1; + + const int32_t positive_mask = ((x > threshold) ? ~zero : zero); + const int32_t negative_mask = ((x < -threshold) ? ~zero : zero); + + int32_t result = (x * (one << Exponent)); + result = (positive_mask ? max : result); + result = (negative_mask ? min : result); + return result; + } +}; + +template struct ImplSaturatingRoundingMultiplyByPOT +{ + static int32_t eval(int32_t x) { return RoundingDivideByPOT(x, -Exponent); } +}; + +template int32_t SaturatingRoundingMultiplyByPOT(int32_t x) +{ + return ImplSaturatingRoundingMultiplyByPOT::eval(x); +} + +template class FixedPoint +{ +public: + static constexpr int kTotalBits = 8 * sizeof(int32_t); + static constexpr int kIntegerBits = tIntegerBits; + static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; + static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); + + static const int32_t ScalarRawMax() { return std::numeric_limits::max(); } + + static FixedPoint FromRaw(int32_t x) + { + FixedPoint retval; + retval.raw() = x; + return retval; + } + + static FixedPoint FromScalarRaw(int32_t x) { return FromRaw(x); } + + template static FixedPoint ConstantPOT() + { + static constexpr int kOffset = kFractionalBits + Exponent; + static_assert(kOffset < 31, "Constant not exactly representable in this fixed-point format"); + return FromScalarRaw((int32_t)1 << kOffset); + } + + static FixedPoint Zero() { return FromScalarRaw(0); } + + static FixedPoint One() + { + return FromScalarRaw(kIntegerBits == 0 ? ScalarRawMax() : ((int32_t)1 << kFractionalBits)); + } + + int32_t raw() const { return i_; } + int32_t &raw() { return i_; } + +private: + int32_t i_; +}; + +// A FixedPoint multiplication is just a +// SaturatingRoundingDoublingHighMul operation on the underlying +// raw integer values. The IntegerBits simply add up, as is obvious +// from the fact that the range is [-2^IntegerBits, 2^IntegerBits). +template +FixedPoint operator*(FixedPoint a, + FixedPoint b) +{ + FixedPoint c; + c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); + return c; +} + +// Tweaking IntegerBits gives exact multiplication by a power of two. +template +FixedPoint ExactMulByPot(FixedPoint a) +{ + FixedPoint c; + c.raw() = a.raw(); + return c; +} + +template +FixedPoint operator+(FixedPoint a, FixedPoint b) +{ + return FixedPoint::FromRaw((a.raw() + b.raw())); +} +template +FixedPoint operator-(FixedPoint a, FixedPoint b) +{ + return FixedPoint::FromRaw((a.raw() - b.raw())); +} +template +FixedPoint operator&(FixedPoint a, FixedPoint b) +{ + return FixedPoint::FromRaw((a.raw() & b.raw())); +} + +// Rescale changes the number of IntegerBits and updates the underlying +// raw integer value accordingly. +template +FixedPoint Rescale(FixedPoint x) +{ + static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; + FixedPoint result; + result.raw() = SaturatingRoundingMultiplyByPOT(x.raw()); + return result; +} + +// Implementation of exponential function. + +// Returns exp(x) for x in [-1/4, 0). +inline FixedPoint<0> exp_on_interval_between_negative_one_quarter_and_0_excl(FixedPoint<0> a) +{ + typedef FixedPoint<0> F; + const F constant_term = F::FromScalarRaw(RoundingDivideByPOT(1895147668, 0)); + const F constant_1_over_3 = F::FromScalarRaw(RoundingDivideByPOT(715827883, 0)); + // We're evaluating a Taylor expansion around -1/8, so we do the change of + // variable: x = a + 1/8. + // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. + F x = a + F::template ConstantPOT<-3>(); + F x2 = x * x; + F x3 = x2 * x; + F x4 = x2 * x2; + F x4_over_4 = F::FromScalarRaw(SaturatingRoundingMultiplyByPOT<-2>(x4.raw())); + F x4_over_24_plus_x3_over_6_plus_x2_over_2 = F::FromScalarRaw( + SaturatingRoundingMultiplyByPOT<-1>((((x4_over_4 + x3) * constant_1_over_3) + x2).raw())); + return (constant_term + constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); +} + +// Returns exp(x) for x < 0. +template FixedPoint<0> exp_on_negative_values(FixedPoint a) +{ + typedef FixedPoint InputF; + typedef FixedPoint<0> ResultF; + static constexpr int kFractionalBits = InputF::kFractionalBits; + static constexpr int kIntegerBits = InputF::kIntegerBits; + const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); + InputF mask = kOneQuarter - InputF::FromScalarRaw(1); + InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; + ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( + Rescale<0>(a_mod_quarter_minus_one_quarter)); + int32_t remainder = (a_mod_quarter_minus_one_quarter - a).raw(); + + const int32_t zero = 0; + +#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ + if (kIntegerBits > Exponent) \ + { \ + const ResultF kMultiplier = \ + ResultF::FromScalarRaw(RoundingDivideByPOT(FixedPointMultiplier, 0)); \ + static constexpr int kShiftAmount = \ + ((kIntegerBits > Exponent) ? (kFractionalBits + Exponent) : 0); \ + result = ((remainder & (1 << kShiftAmount)) ? (result * kMultiplier) : result); \ + } + + GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); + GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); + GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); + GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); + GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); + GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); + GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); + +#undef GEMMLOWP_EXP_BARREL_SHIFTER + + static constexpr int clampB = ((kIntegerBits > 5) ? (36 - kIntegerBits) : 0); + if (kIntegerBits > 5) + { + const InputF clamp = InputF::FromScalarRaw(RoundingDivideByPOT(-(1 << clampB), 0)); + result.raw() = ((a.raw() < clamp.raw()) ? ResultF::Zero().raw() : result.raw()); + } + + result.raw() = (a.raw() ? result.raw() : ResultF::One().raw()); + return result; +} + +// Returns 1 / (1 + x) for x in (0, 1). +inline FixedPoint<0> one_over_one_plus_x_for_x_in_0_1(FixedPoint<0> a) +{ + typedef FixedPoint<0> F0; + typedef FixedPoint<2> F2; + F0 half_denominator = F0::FromScalarRaw(RoundingHalfSum(a.raw(), F0::One().raw())); + // Newton-Raphson division + // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division + // Refer to that page for the logic behind the 48/17 and 32/17 constants. + const F2 constant_48_over_17 = F2::FromScalarRaw(RoundingDivideByPOT(1515870810, 0)); + const F2 constant_neg_32_over_17 = F2::FromScalarRaw(RoundingDivideByPOT(-1010580540, 0)); + F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; + for (int i = 0; i < 3; i++) + { + F2 half_denominator_times_x = half_denominator * x; + F2 one_minus_half_denominator_times_x = F2::One() - half_denominator_times_x; + x = x + Rescale<2>(x * one_minus_half_denominator_times_x); + } + return Rescale<0>(ExactMulByPot<-1>(x)); +} + +} // namespace gemmlowp +} // namespace cker +} // namespace nnfw + +#endif // __NNFW_CKER_FXIED_POINT_H__ diff --git a/libs/cker/include/cker/operation/SoftMax.h b/libs/cker/include/cker/operation/SoftMax.h index d3082f7a6..322f5d5a2 100644 --- a/libs/cker/include/cker/operation/SoftMax.h +++ b/libs/cker/include/cker/operation/SoftMax.h @@ -20,7 +20,7 @@ #include "cker/Shape.h" #include "cker/Utils.h" -#include "cker/FixedPoint.h" +#include "cker/gemmlowp/FixedPoint.h" #include @@ -89,9 +89,9 @@ inline void Softmax(const SoftmaxParams ¶ms, const Shape &input_shape, // accumulation, but exp(-16) definitely is. static const int kScaledDiffIntegerBits = 5; static const int kAccumulationIntegerBits = 12; - using FixedPointScaledDiff = FixedPoint; - using FixedPointAccum = FixedPoint; - using FixedPoint0 = FixedPoint<0>; + using FixedPointScaledDiff = gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint<0>; const int trailing_dim = input_shape.DimensionsCount() - 1; const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); @@ -115,8 +115,8 @@ inline void Softmax(const SoftmaxParams ¶ms, const Shape &input_shape, input_diff, input_beta_multiplier, input_beta_left_shift); const FixedPointScaledDiff scaled_diff_f8 = FixedPointScaledDiff::FromRaw(input_diff_rescaled); - sum_of_exps = - sum_of_exps + Rescale(exp_on_negative_values(scaled_diff_f8)); + sum_of_exps = sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); } } @@ -144,8 +144,8 @@ inline void Softmax(const SoftmaxParams ¶ms, const Shape &input_shape, FixedPointScaledDiff::FromRaw(input_diff_rescaled); FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); - int32_t unsat_output = - RoundingDivideByPOT((shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + int32_t unsat_output = gemmlowp::RoundingDivideByPOT((shifted_scale * exp_in_0).raw(), + num_bits_over_unit + 31 - 8); output_data[i * depth + c] = static_cast( std::max(std::min(unsat_output, static_cast(255)), static_cast(0))); -- cgit v1.2.3