Diffstat (limited to 'runtimes/nn/depend/external/gemmlowp')
33 files changed, 10482 insertions, 0 deletions
diff --git a/runtimes/nn/depend/external/gemmlowp/CMakeLists.txt b/runtimes/nn/depend/external/gemmlowp/CMakeLists.txt new file mode 100644 index 000000000..4e4f4b129 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/CMakeLists.txt @@ -0,0 +1,11 @@ + +SET(CUR_INCS + ${CMAKE_CURRENT_SOURCE_DIR}/fixedpoint + ${CMAKE_CURRENT_SOURCE_DIR}/public +) + +SET(INC_DIRS + ${INC_DIRS} + ${CUR_INCS} + PARENT_SCOPE +) diff --git a/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint.h b/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint.h new file mode 100644 index 000000000..e21337f28 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint.h @@ -0,0 +1,779 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// fixedpoint.h: fixed-point arithmetic, with basic operations and +// a few math functions such as tanh. + +#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ +#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ + +#include <cassert> +#include <limits> + +#include "../internal/common.h" + +namespace gemmlowp { + +// Part 1: Low-level integer-arithmetic primitives. +// The implementations here are generic implementations valid for +// scalar types (e.g. std::int32_t). Architecture-specific SIMD types +// (e.g. NEON int32x4_t) may be supported by providing +// specializations for them in separate files. +// +// The purpose of these primitives is two-fold: +// - They will be used to implement higher-level fixed-point +// abstractions, namely the FixedPoint class and its arithmetic +// operators. +// - They will be directly used to implement some more involved +// fixed-point computations, e.g. the fixed-point implementation +// of math functions such as tanh. + +// Some compile-time traits around raw types to handle SIMD aspects: +// number of lanes, underlying scalar type. +template <typename tIntegerType> +struct FixedPointRawTypeTraits {}; + +template <> +struct FixedPointRawTypeTraits<std::int32_t> { + typedef std::int32_t ScalarRawType; + static const int kLanes = 1; +}; + +// Returns a SIMD value duplicating a scalar value across all lanes. +template <typename tRawType> +tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) { + return x; +} + +// Plain bit-wise AND +template <typename tIntegerType> +tIntegerType BitAnd(tIntegerType a, tIntegerType b) { + return a & b; +} + +// Plain bit-wise OR +template <typename tIntegerType> +tIntegerType BitOr(tIntegerType a, tIntegerType b) { + return a | b; +} + +// Plain bit-wise XOR +template <typename tIntegerType> +tIntegerType BitXor(tIntegerType a, tIntegerType b) { + return a ^ b; +} + +// Plain bit-wise NOT +template <typename tIntegerType> +tIntegerType BitNot(tIntegerType a) { + return ~a; +} + +// Integer addition. Not saturating. Overflow is undefined behavior. +template <typename tIntegerType> +tIntegerType Add(tIntegerType a, tIntegerType b) { + return a + b; +} + +// Integer subtraction. Not saturating. 
Overflow is undefined behavior. +template <typename tIntegerType> +tIntegerType Mul(tIntegerType a, tIntegerType b) { + return a * b; +} + +template <typename tIntegerType> +tIntegerType Sub(tIntegerType a, tIntegerType b) { + return a - b; +} + +// Integer unary negative. Not saturating. Overflow is undefined behavior. +template <typename tIntegerType> +tIntegerType Neg(tIntegerType a) { + return -a; +} + +// Integer arithmetic left-shift, equivalent to multiplying with a +// power of two. Not saturating. Overflow is undefined behavior. +template <typename tIntegerType> +tIntegerType ShiftLeft(tIntegerType a, int offset) { + return a << offset; +} + +// Integer arithmetic right-shift. Not rounding. +// Relying on implementation-defined, but in-practice-consistent, +// C++ compiler behavior. +template <typename tIntegerType> +tIntegerType ShiftRight(tIntegerType a, int offset) { + return a >> offset; +} + +// Each bit of the result is set to the corresponding bit of either then_val or +// else_val depending on whether the corresponding bit of if_mask is set. +// Equivalent to the VBSL instruction in ARM NEON. +template <typename tIntegerType> +tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, + tIntegerType else_val) { + return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val)); +} + +// For each input scalar, the corresponding bits of the result are set if the +// input scalar is non-zero. +template <typename tIntegerType> +tIntegerType MaskIfNonZero(tIntegerType a) { + static const tIntegerType zero = 0; + return a ? BitNot(zero) : zero; +} + +// For each input scalar, the corresponding bits of the result are set if the +// input scalar is zero. +template <typename tIntegerType> +tIntegerType MaskIfZero(tIntegerType a) { + return MaskIfNonZero<tIntegerType>(!a); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars are equal. +template <typename tIntegerType> +tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero<tIntegerType>(a == b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars are not equal. +template <typename tIntegerType> +tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero<tIntegerType>(a != b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a > b. +template <typename tIntegerType> +tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { + return MaskIfNonZero<tIntegerType>(a > b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a >= b. +template <typename tIntegerType> +tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero<tIntegerType>(a >= b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a < b. +template <typename tIntegerType> +tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { + return MaskIfNonZero<tIntegerType>(a < b); +} + +// For each pair of input scalars, the corresponding bits of the result are +// set if the input scalars a, b satisfy a <= b. +template <typename tIntegerType> +tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { + return MaskIfNonZero<tIntegerType>(a <= b); +} + +// Returns true if all of the input scalars are nonzero. 
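As an aside from the diff itself: the MaskIf*/SelectUsingMask helpers above form the branchless-select idiom that the rest of this header builds on. The following minimal sketch shows the same idea on a scalar int32; the names in it are local to the sketch, not gemmlowp API.

#include <cstdint>
#include <cstdio>

int main() {
  std::int32_t a = 7, b = 3;
  // A comparison mask is either all-ones or all-zeroes, as MaskIfGreaterThan produces.
  std::int32_t mask = (a > b) ? ~std::int32_t(0) : std::int32_t(0);
  // SelectUsingMask-style pick: (mask & then_val) ^ (~mask & else_val).
  std::int32_t selected = (mask & 100) ^ (~mask & 200);
  std::printf("%d\n", selected);  // prints 100 because a > b
  return 0;
}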
+// This function may currently assume that each of the input scalars has either +// all or none of its bits set. Otherwise, its behavior is currently undefined. +template <typename tIntegerType> +bool All(tIntegerType a) { + return a; +} + +// Returns true if any of the input scalars are nonzero. +// This function may currently assume that each of the input scalars has either +// all or none of its bits set. Otherwise, its behavior is currently undefined. +template <typename tIntegerType> +bool Any(tIntegerType a) { + return a; +} + +// Returns (a+b)/2, rounded to the nearest integer. +// Equivalent to VRHADD in the ARM NEON instruction set. +template <typename IntegerType> +IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { + static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); + return a; +} + +template <> +inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t sum = a64 + b64; + std::int64_t sign = sum >= 0 ? 1 : -1; + return static_cast<std::int32_t>((sum + sign) / 2); +} + +// Returns the integer that represents the product of two fixed-point +// numbers, interpreting all integers as fixed-point values in the +// interval [-1, 1), rounding to the nearest value, and saturating +// -1 * -1 to the maximum value (since 1 is not in the half-open +// interval [-1, 1)). +// +// [The explanation below specializes to std::int32_t for example purpose.] +// +// The mapping between IntegerType and the interval [-1, 1) is unique and +// implied by IntegerType, which is assumed to be signed. For example, +// for IntegerType==std::int32_t, the mapping is +// real_value = integer_value / 2^31. +// So in this case, and leaving aside rounding and saturating, this +// function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to +// (a * b) / 2^31. +// +// The 'doubling' part in the name of this function comes from the fact that +// this operation is very close to a "multiply-high" operation, keeping only +// the top half bits, except that that would be effectively computing +// (a * b) / 2^32, +// so here we are computing 2x that, since +// 1/2^31 = 2 * 1/2^32. +// The idea is to use all of the available 32 bits in the destination int32 +// value. +// +// [End of the explanation specializing to int32.] +// +// This is equivalent to the VQRDMULH instruction in ARM NEON. +template <typename IntegerType> +IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { + static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); + return a; +} + +// This function implements the same computation as the ARMv7 NEON VQRDMULH +// instruction. +template <> +inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, + std::int32_t b) { + bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min(); + std::int64_t a_64(a); + std::int64_t b_64(b); + std::int64_t ab_64 = a_64 * b_64; + std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); + std::int32_t ab_x2_high32 = + static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31)); + return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; +} + +// Correctly-rounded-to-nearest division by a power-of-two. +// Also known as a rounding arithmetic right shift. 
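For reference, here is a minimal standalone sketch (not part of the diff) of the Q0.31 semantics described above for SaturatingRoundingDoublingHighMul: the product of two values interpreted in [-1, 1), rounded, doubled into the high word, and saturated for -1 * -1. The helper name SRDHM is local to the sketch.

#include <cstdint>
#include <cstdio>
#include <limits>

static std::int32_t SRDHM(std::int32_t a, std::int32_t b) {
  // Same formula as the scalar specialization in the header.
  bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
  std::int64_t ab = static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  std::int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  std::int32_t high = static_cast<std::int32_t>((ab + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : high;
}

int main() {
  std::int32_t half = 1 << 30;                 // 0.5 in Q0.31
  std::printf("%d\n", SRDHM(half, half));      // 1 << 29, i.e. 0.25 in Q0.31
  std::int32_t minus_one = std::numeric_limits<std::int32_t>::min();
  std::printf("%d\n", SRDHM(minus_one, minus_one));  // saturates to INT32_MAX
  return 0;
}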
+template <typename IntegerType> +inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { + using ScalarIntegerType = + typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; + static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, + "Currently only supporting int32 scalar and SIMD types"); + assert(exponent >= 0); + assert(exponent <= 31); + const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); + const IntegerType zero = Dup<IntegerType>(0); + const IntegerType one = Dup<IntegerType>(1); + const IntegerType remainder = BitAnd(x, mask); + const IntegerType threshold = + Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); + return Add(ShiftRight(x, exponent), + BitAnd(MaskIfGreaterThan(remainder, threshold), one)); +} + +// Returns the product of a run-time integer value by a compile-time power +// of two, with either a positive exponent (equivalent to an arithmetic +// left shift, saturating) or a negative exponent (equivalent to an arithmetic +// right shift, rounding to nearest). +template <int Exponent, typename IntegerType, + int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> +struct ImplSaturatingRoundingMultiplyByPOT {}; + +template <int Exponent, typename IntegerType> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { + static IntegerType eval(IntegerType x) { return x; } +}; + +template <int Exponent, typename IntegerType> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { + static IntegerType eval(IntegerType x) { + using ScalarIntegerType = + typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; + static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, + "Currently only supporting int32 scalar and SIMD types"); + const IntegerType min = + Dup<IntegerType>(std::numeric_limits<std::int32_t>::min()); + const IntegerType max = + Dup<IntegerType>(std::numeric_limits<std::int32_t>::max()); + + const std::int32_t threshold = ((1 << (31 - Exponent)) - 1); + const IntegerType positive_mask = + MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); + const IntegerType negative_mask = + MaskIfLessThan(x, Dup<IntegerType>(-threshold)); + + IntegerType result = ShiftLeft(x, Exponent); + result = SelectUsingMask(positive_mask, max, result); + result = SelectUsingMask(negative_mask, min, result); + return result; + } +}; + +template <int Exponent, typename IntegerType> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { + static IntegerType eval(IntegerType x) { + return RoundingDivideByPOT<IntegerType>(x, -Exponent); + } +}; + +template <int Exponent, typename IntegerType> +IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { + return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); +} + +// Part 2: the FixedPoint class. + +// A FixedPoint object represents a fixed-point value stored in the underlying +// integer type tRawType, if tRawType is a plain scalar integer type. +// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which +// case a FixedPoint object represents a corresponding SIMD vector of fixed +// point values. +// +// tIntegerBits describes the range of the fixed-point format: if +// tIntegerBits == m then the range of representable values is the half-open +// interval [-2^m; 2^m) where the open boundary on the right side means that +// 2^m is not representable (how close the maximum representable value is to +// it, depends on bit-depth of tRawType). 
+// +// In "Q format notation", +// https://en.wikipedia.org/wiki/Q_(number_format) +// we are describing the format +// Qm.n +// where +// m = tIntegerBits +// and +// n = NumberOfBits(tRawType) - (m + 1) +// Note that the (m + 1) in the above line is because we adopt the convention +// that we count the integer bits exclusively of the sign bit; so (m + 1) is +// the total number of integer bits inclusive of the sign bit. +// +// Accordingly, the number of integral representable values in our range +// [-2^m ; 2^m) +// is equal to 2^(m+1). +template <typename tRawType, int tIntegerBits> +class FixedPoint { + public: + typedef tRawType RawType; + + typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; + typedef typename RawTypeTraits::ScalarRawType ScalarRawType; + + static const int kTotalBits = 8 * sizeof(ScalarRawType); + static const int kIntegerBits = tIntegerBits; + static const int kFractionalBits = kTotalBits - 1 - kIntegerBits; + static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, + "bad IntegerBits"); + + typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; + + static const ScalarRawType ScalarRawMin() { + return std::numeric_limits<ScalarRawType>::min(); + } + + static const ScalarRawType ScalarRawMax() { + return std::numeric_limits<ScalarRawType>::max(); + } + + static const ScalarRawType RawMin() { + return VectorFromScalar(ScalarRawMin()); + } + + static const ScalarRawType RawMax() { + return VectorFromScalar(ScalarRawMax()); + } + + static FixedPoint FromRaw(RawType x) { + FixedPoint retval; + retval.raw() = x; + return retval; + } + + static FixedPoint FromScalarRaw(ScalarRawType x) { + FixedPoint retval; + retval.raw() = Dup<RawType>(x); + return retval; + } + + static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { + return FromScalarRaw(x.raw()); + } + + template <int Exponent> + static FixedPoint ConstantPOT() { + static const int kOffset = kFractionalBits + Exponent; + static_assert( + kOffset < 31, + "Constant not exactly representable in this fixed-point format"); + return FromScalarRaw(ScalarRawType(1) << kOffset); + } + + static FixedPoint Zero() { return FromScalarRaw(0); } + + static FixedPoint One() { + return FromScalarRaw(kIntegerBits == 0 + ? ScalarRawMax() + : (ScalarRawType(1) << kFractionalBits)); + } + + static FixedPoint FromDouble(double x) { + const double min_bound = static_cast<double>(ScalarRawMin()); + const double max_bound = static_cast<double>(ScalarRawMax()); + return FromScalarRaw(static_cast<std::int32_t>(std::min( + std::max(round(x * static_cast<double>(1ll << kFractionalBits)), + min_bound), + max_bound))); + } + + RawType raw() const { return i_; } + RawType& raw() { return i_; } + + private: + RawType i_; +}; + +// Part 3: implementation of arithmetic operators for the +// FixedPoint class, and a few related functions. + +// A FixedPoint multiplication is just a +// SaturatingRoundingDoublingHighMul operation on the underlying +// raw integer values. The IntegerBits simply add up, as is obvious +// from the fact that the range is [-2^IntegerBits, 2^IntegerBits). +template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> +FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( + FixedPoint<tRawType, tIntegerBits_a> a, + FixedPoint<tRawType, tIntegerBits_b> b) { + FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; + c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); + return c; +} + +// Tweaking IntegerBits gives exact multiplication by a power of two. 
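To make the Qm.n discussion above concrete, here is a usage sketch (not part of the diff). It assumes the gemmlowp source tree is on the include path so that fixedpoint/fixedpoint.h and its internal dependencies resolve.

#include <cstdint>
#include <cstdio>
#include "fixedpoint/fixedpoint.h"

int main() {
  // Q2.29: two integer bits (plus the sign bit), 29 fractional bits.
  typedef gemmlowp::FixedPoint<std::int32_t, 2> F2;
  F2 a = F2::FromDouble(1.25);
  F2 b = F2::FromDouble(-0.5);
  // Multiplying adds the integer bits: Q2.29 * Q2.29 -> Q4.27.
  gemmlowp::FixedPoint<std::int32_t, 4> c = a * b;
  std::printf("%f\n", gemmlowp::ToDouble(c));  // approximately -0.625
  // Rescale back to Q2.29 (a saturating/rounding power-of-two shift of the raw value).
  F2 d = gemmlowp::Rescale<2>(c);
  std::printf("%f\n", gemmlowp::ToDouble(d));
  return 0;
}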
+template <int tExponent, typename tRawType, int tIntegerBits> +FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( + FixedPoint<tRawType, tIntegerBits> a) { + FixedPoint<tRawType, tExponent + tIntegerBits> c; + c.raw() = a.raw(); + return c; +} + +// If we want to leave IntegerBits fixed, then multiplication +// by a power of two has to be saturating/rounding, not exact anymore. +template <int tExponent, typename tRawType, int tIntegerBits> +FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( + FixedPoint<tRawType, tIntegerBits> a) { + return FixedPoint<tRawType, tIntegerBits>::FromRaw( + SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); +} + +// Generic arithmetic operators. + +#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ + template <typename tRawType, int tIntegerBits> \ + FixedPoint<tRawType, tIntegerBits> FuncName( \ + FixedPoint<tRawType, tIntegerBits> a) { \ + return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ + } + +#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ + template <typename tRawType, int tIntegerBits> \ + FixedPoint<tRawType, tIntegerBits> FuncName( \ + FixedPoint<tRawType, tIntegerBits> a, \ + FixedPoint<tRawType, tIntegerBits> b) { \ + return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ + ImplFuncName(a.raw(), b.raw())); \ + } + +MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) +MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) +MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) +MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) +MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) +MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) +MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) +MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) + +#undef MAKE_FIXEDPOINT_UNARY_FUNC +#undef MAKE_FIXEDPOINT_BINARY_FUNC + +#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ + template <typename tRawType, int tIntegerBits> \ + tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ + return FuncName(a.raw()); \ + } + +#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ + template <typename tRawType, int tIntegerBits> \ + tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \ + FixedPoint<tRawType, tIntegerBits> b) { \ + return FuncName(a.raw(), b.raw()); \ + } + +MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) +MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) +MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) + +#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW +#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW + +template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, tIntegerBits> SelectUsingMask( + tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, + FixedPoint<tRawType, tIntegerBits> else_val) { + return FixedPoint<tRawType, tIntegerBits>::FromRaw( + SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); +} + +template <typename tRawType, int tIntegerBits> +bool operator==(FixedPoint<tRawType, tIntegerBits> a, + FixedPoint<tRawType, tIntegerBits> b) { + return All(MaskIfEqual(a.raw(), b.raw())); +} + +template <typename tRawType, int tIntegerBits> +bool operator!=(FixedPoint<tRawType, tIntegerBits> a, + FixedPoint<tRawType, tIntegerBits> b) { + 
return !(a == b); +} + +// Conversion to floating-point. +template <typename tRawType, int tIntegerBits> +double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { + static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, + "not applicable to SIMD types"); + typedef FixedPoint<tRawType, tIntegerBits> F; + return x.raw() / static_cast<double>(1ll << F::kFractionalBits); +} + +// Rescale changes the number of IntegerBits and updates the underlying +// raw integer value accordingly. +template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> +FixedPoint<tRawType, tIntegerBitsDst> Rescale( + FixedPoint<tRawType, tIntegerBitsSrc> x) { + static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst; + FixedPoint<tRawType, tIntegerBitsDst> result; + result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); + return result; +} + +// CheckedFixedPointConstant allows to specify fixed-point constants +// initialized as real numbers, in a way that does not compile floating-point +// arithmetic in production code, yet still checks agreement with the +// floating-point expressions when asserts are enabled. +#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS +template <typename FixedPointType> +FixedPointType CheckedFixedPointConstant( + typename FixedPointType::ScalarRawType raw_value, double double_value) { + typedef typename FixedPointType::RawType RawType; + const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); + assert(result == FixedPointType::FromDouble(double_value)); + return result; +} +#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ + DoubleValue) \ + (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue)) + +#else +#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ + DoubleValue) \ + (FixedPointType::FromScalarRaw(ScalarRawValue)) +#endif + +// Implementation of exponential function. + +// Returns exp(x) for x in [-1/4, 0). +template <typename tRawType> +FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( + FixedPoint<tRawType, 0> a) { + typedef FixedPoint<tRawType, 0> F; + const F constant_term = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); + const F constant_1_over_3 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); + // We're evaluating a Taylor expansion around -1/8, so we do the change of + // variable: x = a + 1/8. + // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. + F x = a + F::template ConstantPOT<-3>(); + F x2 = x * x; + F x3 = x2 * x; + F x4 = x2 * x2; + F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); + F x4_over_24_plus_x3_over_6_plus_x2_over_2 = + SaturatingRoundingMultiplyByPOT<-1>( + ((x4_over_4 + x3) * constant_1_over_3) + x2); + return constant_term + + constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2); +} + +// Returns exp(x) for x < 0. 
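A side note on the raw values passed to GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT above and below: for a Q0.31 constant the raw integer is round(real_value * 2^31). The following sketch (not part of the diff) reproduces two of the constants used in the exponential code.

#include <cmath>
#include <cstdio>

int main() {
  // 1/3 in Q0.31 -> 715827883, as used for constant_1_over_3.
  std::printf("%.0f\n", std::round((1.0 / 3.0) * 2147483648.0));
  // exp(-1/8) in Q0.31 -> 1895147668, as used for constant_term.
  std::printf("%.0f\n", std::round(std::exp(-1.0 / 8.0) * 2147483648.0));
  return 0;
}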
+template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, 0> exp_on_negative_values( + FixedPoint<tRawType, tIntegerBits> a) { + typedef FixedPoint<tRawType, tIntegerBits> InputF; + typedef FixedPoint<tRawType, 0> ResultF; + static const int kFractionalBits = InputF::kFractionalBits; + static const int kIntegerBits = InputF::kIntegerBits; + static const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); + InputF mask = kOneQuarter - InputF::FromScalarRaw(1); + InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; + ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( + Rescale<0>(a_mod_quarter_minus_one_quarter)); + tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); + +#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ + if (kIntegerBits > Exponent) { \ + const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ + ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ + static constexpr int kShiftAmount = \ + kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \ + result = SelectUsingMask( \ + MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \ + result * kMultiplier, result); \ + } + + GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); + GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); + GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); + GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); + GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); + GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); + GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); + +#undef GEMMLOWP_EXP_BARREL_SHIFTER + + if (kIntegerBits > 5) { + static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0; + const InputF clamp = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); + result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); + } + + result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); + return result; +} + +// Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)). + +// Returns (1 - x) / (1 + x) for x in (0, 1). +template <typename tRawType> +FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( + FixedPoint<tRawType, 0> a) { + typedef FixedPoint<tRawType, 0> F0; + typedef FixedPoint<tRawType, 2> F2; + F0 half_denominator = RoundingHalfSum(a, F0::One()); + // Newton-Raphson division + // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division + // Refer to that page for the logic behind the 48/17 and 32/17 constants. + const F2 constant_48_over_17 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); + const F2 constant_neg_32_over_17 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); + F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; + for (int i = 0; i < 3; i++) { + F2 half_denominator_times_x = half_denominator * x; + F2 one_minus_half_denominator_times_x = + F2::One() - half_denominator_times_x; + x = x + Rescale<2>(x * one_minus_half_denominator_times_x); + } + return Rescale<0>(x - F2::One()); +} + +// Returns -tanh(x) for x < 0. +template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, 0> neg_tanh_on_negative_values( + FixedPoint<tRawType, tIntegerBits> a) { + return one_minus_x_over_one_plus_x_for_x_in_0_1( + exp_on_negative_values(ExactMulByPot<1>(a))); +} + +// Returns tanh(x) for any x. 
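The loop in one_minus_x_over_one_plus_x_for_x_in_0_1 above is the Newton-Raphson reciprocal iteration, and the 48/17 and -32/17 constants give the standard initial estimate for a denominator in [0.5, 1). A double-precision sketch (not part of the diff) of the same recurrence:

#include <cstdio>

int main() {
  double d = 0.8;                             // plays the role of half_denominator
  double x = 48.0 / 17.0 - 32.0 / 17.0 * d;   // initial estimate of 1/d
  for (int i = 0; i < 3; i++) {
    x = x + x * (1.0 - d * x);                // the error roughly squares each iteration
  }
  std::printf("%.12f vs %.12f\n", x, 1.0 / d);
  return 0;
}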
+template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { + typedef FixedPoint<tRawType, tIntegerBits> InputF; + typedef FixedPoint<tRawType, 0> ResultF; + tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); + tRawType mask_if_zero = MaskIfZero(a); + InputF n = SelectUsingMask(mask_if_negative, a, -a); + ResultF t = neg_tanh_on_negative_values(n); + return SelectUsingMask(mask_if_zero, ResultF::Zero(), + SelectUsingMask(mask_if_negative, -t, t)); +} + +// Implementation of logistic function. + +// Returns 1 / (1 + x) for x in (0, 1). +template <typename tRawType> +FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1( + FixedPoint<tRawType, 0> a) { + typedef FixedPoint<tRawType, 0> F0; + typedef FixedPoint<tRawType, 2> F2; + F0 half_denominator = RoundingHalfSum(a, F0::One()); + // Newton-Raphson division + // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division + // Refer to that page for the logic behind the 48/17 and 32/17 constants. + const F2 constant_48_over_17 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); + const F2 constant_neg_32_over_17 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); + F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; + for (int i = 0; i < 3; i++) { + F2 half_denominator_times_x = half_denominator * x; + F2 one_minus_half_denominator_times_x = + F2::One() - half_denominator_times_x; + x = x + Rescale<2>(x * one_minus_half_denominator_times_x); + } + return Rescale<0>(ExactMulByPot<-1>(x)); +} + +// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. +template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, 0> logistic_on_positive_values( + FixedPoint<tRawType, tIntegerBits> a) { + return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); +} + +// Returns logistic(x) = 1 / (1 + exp(-x)) for any x. +template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { + typedef FixedPoint<tRawType, tIntegerBits> InputF; + typedef FixedPoint<tRawType, 0> ResultF; + tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); + tRawType mask_if_zero = MaskIfZero(a); + InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); + ResultF result_if_positive = logistic_on_positive_values(abs_input); + ResultF result_if_negative = ResultF::One() - result_if_positive; + const ResultF one_half = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); + return SelectUsingMask(mask_if_zero, one_half, + SelectUsingMask(mask_if_positive, result_if_positive, + result_if_negative)); +} + +} // end namespace gemmlowp + +#ifdef GEMMLOWP_NEON +#include "./fixedpoint_neon.h" +#elif defined(GEMMLOWP_SSE4) +#include "./fixedpoint_sse.h" +#endif + +#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint_neon.h b/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint_neon.h new file mode 100644 index 000000000..8b23de274 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint_neon.h @@ -0,0 +1,175 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// fixedpoint_neon.h: optimized NEON specializations of the templates +// in fixedpoint.h. + +#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_ +#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_ + +#include <arm_neon.h> + +namespace gemmlowp { + +template <> +struct FixedPointRawTypeTraits<int32x4_t> { + typedef std::int32_t ScalarRawType; + static const int kLanes = 4; +}; + +template <> +inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) { + return vandq_s32(a, b); +} + +template <> +inline int32x4_t BitOr(int32x4_t a, int32x4_t b) { + return vorrq_s32(a, b); +} + +template <> +inline int32x4_t BitXor(int32x4_t a, int32x4_t b) { + return veorq_s32(a, b); +} + +template <> +inline int32x4_t BitNot(int32x4_t a) { + return veorq_s32(a, vdupq_n_s32(-1)); +} + +template <> +inline int32x4_t Add(int32x4_t a, int32x4_t b) { + return vaddq_s32(a, b); +} + +template <> +inline int32x4_t Sub(int32x4_t a, int32x4_t b) { + return vsubq_s32(a, b); +} + +template <> +inline int32x4_t Neg(int32x4_t a) { + return vnegq_s32(a); +} + +template <> +inline int32x4_t ShiftLeft(int32x4_t a, int offset) { + return vshlq_s32(a, vdupq_n_s32(offset)); +} + +template <> +inline int32x4_t ShiftRight(int32x4_t a, int offset) { + return vshlq_s32(a, vdupq_n_s32(-offset)); +} + +template <> +inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val, + int32x4_t else_val) { + return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val); +} + +template <> +inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) { + return vreinterpretq_s32_u32(vceqq_s32(a, b)); +} + +template <> +inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) { + return BitNot(MaskIfEqual(a, b)); +} + +template <> +inline int32x4_t MaskIfZero(int32x4_t a) { + return MaskIfEqual(a, vdupq_n_s32(0)); +} + +template <> +inline int32x4_t MaskIfNonZero(int32x4_t a) { + return vreinterpretq_s32_u32(vtstq_s32(a, a)); +} + +template <> +inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) { + return vreinterpretq_s32_u32(vcgtq_s32(a, b)); +} + +template <> +inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) { + return vreinterpretq_s32_u32(vcgeq_s32(a, b)); +} + +template <> +inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) { + return vreinterpretq_s32_u32(vcltq_s32(a, b)); +} + +template <> +inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) { + return vreinterpretq_s32_u32(vcleq_s32(a, b)); +} + +template <> +inline bool All(int32x4_t a) { + a = vandq_s32(a, vextq_s32(a, a, 1)); + a = vandq_s32(a, vextq_s32(a, a, 2)); + return vgetq_lane_s32(a, 0); +} + +template <> +inline bool Any(int32x4_t a) { + a = vorrq_s32(a, vextq_s32(a, a, 1)); + a = vorrq_s32(a, vextq_s32(a, a, 2)); + return vgetq_lane_s32(a, 0); +} + +template <> +inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) { + return vrhaddq_s32(a, b); +} + +template <> +inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) { + return vqrdmulhq_s32(a, b); +} + +template <> +inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) { + const int32x4_t shift_vec = vdupq_n_s32(-exponent); 
+ const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> { + static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); } +}; + +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> { + static int32x4_t eval(int32x4_t x) { + const int32x4_t fixup = vshrq_n_s32(x, 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshrq_n_s32(fixed_up_x, -Exponent); + } +}; + +template <> +inline int32x4_t Dup<int32x4_t>(std::int32_t x) { + return vdupq_n_s32(x); +} + +} // end namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint_sse.h b/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint_sse.h new file mode 100644 index 000000000..3f2654d22 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/fixedpoint/fixedpoint_sse.h @@ -0,0 +1,218 @@ +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// fixedpoint_SSE.h: optimized SSE specializations of the templates +// in fixedpoint.h. 
+ +#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_ +#define GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_ + +#include <smmintrin.h> +#include "fixedpoint.h" + +namespace gemmlowp { + +template <> +struct FixedPointRawTypeTraits<__m128i> { + typedef std::int32_t ScalarRawType; + static const int kLanes = 4; +}; + +template <> +inline __m128i BitAnd(__m128i a, __m128i b) { + return _mm_and_si128(a, b); +} + +template <> +inline __m128i BitOr(__m128i a, __m128i b) { + return _mm_or_si128(a, b); +} + +template <> +inline __m128i BitXor(__m128i a, __m128i b) { + return _mm_xor_si128(a, b); +} + +template <> +inline __m128i BitNot(__m128i a) { + return _mm_andnot_si128(a, _mm_set1_epi32(-1)); +} + +template <> +inline __m128i Add(__m128i a, __m128i b) { + return _mm_add_epi32(a, b); +} + +template <> +inline __m128i Mul(__m128i a, __m128i b) { + return _mm_mullo_epi32(a, b); +} + +template <> +inline __m128i Sub(__m128i a, __m128i b) { + return _mm_sub_epi32(a, b); +} + +template <> +inline __m128i Neg(__m128i a) { + return _mm_sign_epi32(a, _mm_set1_epi32(-1)); +} + +template <> +inline __m128i ShiftLeft(__m128i a, int offset) { + return _mm_slli_epi32(a, offset); +} + +template <> +inline __m128i ShiftRight(__m128i a, int offset) { + return _mm_srai_epi32(a, offset); +} + +template <> +inline __m128i SelectUsingMask(__m128i if_mask, __m128i then_val, + __m128i else_val) { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(else_val), + _mm_castsi128_ps(then_val), + _mm_castsi128_ps(if_mask))); +} + +template <> +inline __m128i MaskIfEqual(__m128i a, __m128i b) { + return _mm_cmpeq_epi32(a, b); +} + +template <> +inline __m128i MaskIfNotEqual(__m128i a, __m128i b) { + return BitNot(MaskIfEqual(a, b)); +} + +template <> +inline __m128i MaskIfZero(__m128i a) { + return MaskIfEqual(a, _mm_set1_epi32(0)); +} + +template <> +inline __m128i MaskIfNonZero(__m128i a) { + return MaskIfNotEqual(a, _mm_set1_epi32(0)); +} + +template <> +inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) { + return _mm_cmpgt_epi32(a, b); +} + +template <> +inline __m128i MaskIfLessThan(__m128i a, __m128i b) { + return _mm_cmplt_epi32(a, b); +} + +template <> +inline __m128i MaskIfGreaterThanOrEqual(__m128i a, __m128i b) { + return BitNot(MaskIfLessThan(a, b)); +} + +template <> +inline __m128i MaskIfLessThanOrEqual(__m128i a, __m128i b) { + return BitNot(MaskIfGreaterThan(a, b)); +} + +/* Assumptions: + - All and Any are used on masks. + - masks are all_ones for true lanes, all_zeroes otherwise. +Hence, All means all 128bits set, and Any means any bit set. 
+*/ + +template <> +inline bool All(__m128i a) { + return _mm_testc_si128(a, a); +} + +template <> +inline bool Any(__m128i a) { + return BitNot(_mm_testz_si128(a, a)); +} + +template <> +inline __m128i RoundingHalfSum(__m128i a, __m128i b) { + /* __m128i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */ + /* We divide the inputs before the add to avoid the overflow and costly test + */ + /* of checking if an overflow occured on signed add */ + /* round_bit_mask = _mm_set1_epi32(1); */ + /* a_over_2 = _mm_srai_epi32(a, 1); */ + /* b_over_2 = _mm_srai_epi32(b, 1); */ + /* sum = Add(a_over_2, b_over_2); */ + /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */ + /* return Add(sum, round_bit); */ + + /* Other possibility detecting overflow and xor the sign if an overflow + * happened*/ + __m128i one, sign_bit_mask, sum, rounded_half_sum, overflow, result; + one = _mm_set1_epi32(1); + sign_bit_mask = _mm_set1_epi32(0x80000000); + sum = Add(a, b); + rounded_half_sum = _mm_srai_epi32(Add(sum, one), 1); + overflow = + BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)), + sign_bit_mask); + result = BitXor(rounded_half_sum, overflow); + return result; +} + +template <> +inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) { + __m128i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3; + __m128i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded; + __m128i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result; + __m128i nudge; + + // saturation only happen if a == b == INT_MIN + min = _mm_set1_epi32(std::numeric_limits<std::int32_t>::min()); + saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min)); + + // a = a0 | a1 | a2 | a3 + // b = b0 | b1 | b2 | b3 + a0_a2 = a; + a1_a3 = _mm_srli_si128(a, 4); + b0_b2 = b; + b1_b3 = _mm_srli_si128(b, 4); + + a0b0_a2b2 = _mm_mul_epi32(a0_a2, b0_b2); + a1b1_a3b3 = _mm_mul_epi32(a1_a3, b1_b3); + + // do the rounding and take into account that it will be doubled + nudge = _mm_set1_epi64x(1 << 30); + a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge); + a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge); + + // do the doubling + a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1); + a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1); + + // get the high part of the products + result = _mm_blend_epi16(_mm_srli_si128(a0b0_a2b2_rounded_2x, 4), + a1b1_a3b3_rounded_2x, 0xcc); + + // saturate those which overflowed + return SelectUsingMask(saturation_mask, min, result); +} + +template <> +inline __m128i Dup<__m128i>(std::int32_t x) { + return _mm_set1_epi32(x); +} + +} // end namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/allocator.h b/runtimes/nn/depend/external/gemmlowp/internal/allocator.h new file mode 100644 index 000000000..da325a4c4 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/allocator.h @@ -0,0 +1,220 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// allocator.h: a buffer allocator that allows avoiding most of the +// malloc/free overhead, by: +// 1. Requiring all N allocations to be reserved in advance, and +// then commited at once, turning N allocations into 1. +// 2. Being persistent, the allocated storage is reused across commits, +// and only reallocated as needed when the commit size gets larger. +// +// This is driven by Android-specific needs: +// 1. On Android, the default (Bionic) allocator tends to aggressively +// unmap pages, which means that malloc/free can be surprisingly expensive. +// 2. On Android, stack allocations with alloca() can't be as large as on +// desktop platforms. +// +// General usage: +// 1. Reserve blocks by calling Reserve(), which returns a Handle. +// 2. Call Commit() once. +// 3. Now it is possible to get pointers to allocated buffers by calling +// GetPointer(). +// 4. Call Decommit() once. +// 5. The allocator is now reverted to its original state, except that +// it retained its allocated storage, so the next Commit() will be faster. +// The allocated storage is only freed when the Allocator object is +// destroyed. + +#ifndef GEMMLOWP_INTERNAL_ALLOCATOR_H_ +#define GEMMLOWP_INTERNAL_ALLOCATOR_H_ + +#include "common.h" + +#if defined(__ANDROID__) +#include <android/api-level.h> +// The 18 here should be 16, but has to be 18 for now due +// to a Google-internal issue. +#if __ANDROID_API__ < 18 +#include <malloc.h> +#define GEMMLOWP_USE_MEMALIGN +#endif +// posix_memalign is missing on some 4.1 x86 devices +#if __ANDROID_API__ == 18 +#ifdef GEMMLOWP_X86_32 +#include <malloc.h> +#define GEMMLOWP_USE_MEMALIGN +#endif +#endif +#endif + +namespace gemmlowp { + +enum class TypeId : std::uint8_t { Uint8, Int8, Uint16, Int16, Uint32, Int32 }; + +template <typename T> +struct GetTypeIdImpl {}; + +template <typename T> +inline TypeId GetTypeId() { + return GetTypeIdImpl<T>::Value; +} + +template <typename T> +struct GetTypeIdImpl<const T> : GetTypeIdImpl<T> {}; + +#define GEMMLOWP_REGISTER_TYPEID(type_, id) \ + template <> \ + struct GetTypeIdImpl<type_> { \ + static const TypeId Value = TypeId::id; \ + }; + +GEMMLOWP_REGISTER_TYPEID(std::uint8_t, Uint8) +GEMMLOWP_REGISTER_TYPEID(std::int8_t, Int8) +GEMMLOWP_REGISTER_TYPEID(std::uint16_t, Uint16) +GEMMLOWP_REGISTER_TYPEID(std::int16_t, Int16) +GEMMLOWP_REGISTER_TYPEID(std::uint32_t, Uint32) +GEMMLOWP_REGISTER_TYPEID(std::int32_t, Int32) + +class Allocator { + public: + Allocator() + : committed_(false), + storage_size_(0), + storage_(nullptr), + reserved_blocks_(0), + reserved_bytes_(0), + generation_(0) {} + + ~Allocator() { + assert(!committed_); + assert(!reserved_blocks_); + DeallocateStorage(); + } + + // Alignment of allocated blocks. + static const std::size_t kAlignment = kDefaultCacheLineSize; + + // This is all we need so far, and since the usage pattern is fixed, + // there is no point in allowing more until we need to. 
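A usage sketch (not part of the diff) of the Reserve/Commit/GetPointer/Decommit protocol described in the header comment above. It assumes the gemmlowp source tree is on the include path; the function name is illustrative.

#include <cstdint>
#include "internal/allocator.h"

void PackBuffersExample() {
  gemmlowp::Allocator allocator;
  // 1. Reserve every block up front; no memory is allocated yet.
  gemmlowp::Allocator::Handle lhs = allocator.Reserve<std::uint8_t>(1024);
  gemmlowp::Allocator::Handle acc = allocator.Reserve<std::int32_t>(256);
  // 2. Commit once: a single aligned allocation backs all reserved blocks.
  allocator.Commit();
  // 3. Pointers are only valid between Commit() and Decommit().
  std::uint8_t* lhs_ptr = allocator.GetPointer<std::uint8_t>(lhs);
  std::int32_t* acc_ptr = allocator.GetPointer<std::int32_t>(acc);
  (void)lhs_ptr;
  (void)acc_ptr;
  // 4. Decommit: the storage is retained, so the next Commit() is cheap.
  allocator.Decommit();
}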
+ static const std::size_t kMaxBlocks = 5; + + void Commit() { + assert(!committed_); + + if (reserved_bytes_ > storage_size_) { + DeallocateStorage(); + storage_size_ = RoundUpToPowerOfTwo(reserved_bytes_); +#ifdef GEMMLOWP_USE_MEMALIGN + storage_ = memalign(kAlignment, storage_size_); +#else + if (posix_memalign(&storage_, kAlignment, storage_size_)) { + storage_ = nullptr; + } +#endif + } + + ReleaseBuildAssertion(!storage_size_ || storage_, "allocation failure"); + committed_ = true; + } + + void Decommit() { + assert(committed_); + committed_ = false; + generation_++; + + reserved_blocks_ = 0; + reserved_bytes_ = 0; + } + + // See generation_ + typedef std::size_t generation_t; + + // A handle on a reserved block. The user obtains + // one by calling Reserve() and, after committing, + // passes it to GetPointer(). + class Handle { + std::uint8_t index_; + generation_t generation_; + TypeId type_; + + friend class Allocator; + }; + + // Reserves a block sized for n elements of type T, and + // returns a handle to it. Must be called before committing. + template <typename T> + Handle Reserve(std::size_t n) { + assert(!committed_ && "can't reserve blocks while committed"); + assert(reserved_blocks_ < kMaxBlocks && + "didn't expect to allocate this many blocks"); + const std::size_t bytes = RoundUp<kAlignment>(n * sizeof(T)); + const std::size_t offset = reserved_bytes_; + const std::size_t index = reserved_blocks_; + + reserved_blocks_offsets_[index] = offset; + Handle h; + h.index_ = index; + h.generation_ = generation_; + h.type_ = GetTypeId<T>(); + + reserved_blocks_++; + reserved_bytes_ += bytes; + + return h; + } + + // Returns the pointer to the allocated buffer for the given handle. + // Must be called after committing. + template <typename T> + T* GetPointer(const Handle& h) const { + assert(committed_ && "can't get block pointers unless committed"); + assert(h.index_ < reserved_blocks_ && + "bad handle, points to inexistant block"); + assert(h.generation_ == generation_ && + "handle from earlier generation, have decommitted since"); + assert(h.type_ == GetTypeId<T>() && "type mismatch"); + std::size_t offset = reserved_blocks_offsets_[h.index_]; + std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(storage_) + offset; + return reinterpret_cast<T*>(addr); + } + + private: + void DeallocateStorage() { + assert(!committed_); + free(storage_); + storage_size_ = 0; + } + + // Set to true by Commit() and to false by Decommit(). Initially false. + bool committed_; + + // The actually allocated storage size and buffer pointer. + std::size_t storage_size_; + mutable void* storage_; + + // The number of blocks that have been reserved by Reserve(). + std::size_t reserved_blocks_; + // The number of bytes that have been reserved by Reserve(). + std::size_t reserved_bytes_; + // The offsets of reserved blocks into the storage buffer. + std::size_t reserved_blocks_offsets_[kMaxBlocks]; + + // The 'generation' is incremented on Decommit() and allows catching + // bad GetPointer() calls still referring to a previous commit. + generation_t generation_; +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_ALLOCATOR_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/block_params.h b/runtimes/nn/depend/external/gemmlowp/internal/block_params.h new file mode 100644 index 000000000..b2fc3ff78 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/block_params.h @@ -0,0 +1,174 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// block_params.h: Logic to choose L1 and L2 block sizes +// to optimize cache-friendliness. + +#ifndef GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_ +#define GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_ + +#include "common.h" + +namespace gemmlowp { + +// A BlockParams instance contains a full description of all the block size +// parameters to be used by a Gemm. +// There are two nested levels of block subdivisions: first a subdivision +// into large blocks that should fit in last-level cache (what we call L2 here) +// and then another subdivision into smaller blocks that should fit in +// L1 cache. There is then actually a third level of subdivision to fit +// in registers, but we are not concerned with that here. +struct BlockParams { + // L1 block parameters determine the size of small blocks that should + // fit in L1 cache. + int l1_rows; + int l1_cols; + int l1_depth; + + // L2 block parameters determine the size of larger blocks that should + // fit in L2 cache. + int l2_rows; + int l2_cols; + int l2_depth; + + template <typename KernelFormat> + void Init(int rows, int cols, int depth, int num_threads, + int l1_bytes_to_use, int l2_bytes_to_use, float l2_rhs_factor) { + FindL2BlockSizes<KernelFormat>(rows, cols, depth, num_threads, + l2_bytes_to_use, l2_rhs_factor, + &l2_rows, &l2_cols, &l2_depth); + FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth, + l1_bytes_to_use, + &l1_rows, &l1_cols, &l1_depth); + } + + template <typename KernelFormat> + static void FindL2BlockSizes(int rows, int cols, int depth, int num_threads, + int l2_bytes_to_use, float l2_rhs_factor, + int* out_l2_rows, int* out_l2_cols, + int* out_l2_depth) { + int l2_rows = 0; + int l2_cols = 0; + int l2_depth = 0; + // No L2 blocking in the depth dimension at the moment. + // Too much loss of accuracy due to storing intermediate results in + // low precision. + // However, we still want to round l2_depth up to the next multiple + // of register size, so as to avoid having to special-case unaligned depths. + l2_depth = RoundUp<kRegisterSize>(depth); + + { + int max_cache_friendly_l2_cols = std::max( + 1, static_cast<int>(l2_rhs_factor * (l2_bytes_to_use / l2_depth))); + int min_l2_cols_blocks = + std::max(1, CeilQuotient(cols, max_cache_friendly_l2_cols)); + l2_cols = + RoundUp<KernelFormat::kCols>(CeilQuotient(cols, min_l2_cols_blocks)); + } + + // No L2 blocking in the row dimension if l2_rhs_factor is 1.0 as the row + // dimension concerns only the LHS. Blocking only RHS matrix for L2 enhances + // the performance on x86. 
+ if (l2_rhs_factor == 1.0f) { + l2_rows = RoundUp<KernelFormat::kRows>(rows); + } else { + int max_cache_friendly_l2_rows = + std::max(1, (l2_bytes_to_use - l2_depth * l2_cols) / + (num_threads * (l2_depth + 4 * l2_cols))); + int min_l2_rows_blocks = + std::max(1, CeilQuotient(rows, max_cache_friendly_l2_rows)); + l2_rows = + RoundUp<KernelFormat::kRows>(CeilQuotient(rows, min_l2_rows_blocks)); + } + + *out_l2_rows = l2_rows; + *out_l2_cols = l2_cols; + *out_l2_depth = l2_depth; + } + + template <typename KernelFormat> + static void FindL1BlockSizes(int rows, int cols, int depth, + int l1_bytes_to_use, int* out_l1_rows, + int* out_l1_cols, int* out_l1_depth) { + int l1_rows = 0; + int l1_cols = 0; + int l1_depth = 0; + + // L2 block sizes should already be multiples of kernel block sizes. + assert(rows % KernelFormat::kRows == 0); + assert(cols % KernelFormat::kCols == 0); + assert(depth % KernelFormat::kDepth == 0); + + // No L1 blocking in the columns dimension at the moment. + // Thought not to be needed. Similar to Eigen. + l1_cols = cols; + + { + int max_cache_friendly_l1_depth = std::max( + 1, (l1_bytes_to_use - 4 * KernelFormat::kRows * KernelFormat::kCols) / + (KernelFormat::kRows + KernelFormat::kCols)); + int min_l1_depth_blocks = + std::max(1, CeilQuotient(depth, max_cache_friendly_l1_depth)); + l1_depth = + RoundUp<kRegisterSize>(CeilQuotient(depth, min_l1_depth_blocks)); + } + + { + int max_cache_friendly_l1_rows = + std::max(1, l1_bytes_to_use / (l1_depth + 4 * l1_cols)); + int min_l1_rows_blocks = + std::max(1, CeilQuotient(rows, max_cache_friendly_l1_rows)); + l1_rows = + RoundUp<KernelFormat::kRows>(CeilQuotient(rows, min_l1_rows_blocks)); + } + + *out_l1_rows = l1_rows; + *out_l1_cols = l1_cols; + *out_l1_depth = l1_depth; + } +}; + +// A SideBlockParams instance contains only the block params relevant to +// one side (LHS or RHS), expressed in terms of 'width' instead of +// rows/colums. See the explanation in kernel.h: in the LHS, 'width' means +// the number of rows, while in the RHS, 'width' means the number of columns. +// That allows us to write generic code that applies to either LHS or RHS. +struct SideBlockParams { + // L1 block parameters determine the size of small blocks that should + // fit in L1 cache. + int l1_width; + int l1_depth; + + // L2 block parameters determine the size of larger blocks that should + // fit in L2 cache. + int l2_width; + int l2_depth; +}; + +enum class Side { Lhs, Rhs }; + +inline void GetSideBlockParams(Side side, SideBlockParams* side_block_params, + const BlockParams& block_params) { + side_block_params->l1_width = + side == Side::Lhs ? block_params.l1_rows : block_params.l1_cols; + side_block_params->l2_width = + side == Side::Lhs ? block_params.l2_rows : block_params.l2_cols; + + side_block_params->l1_depth = block_params.l1_depth; + side_block_params->l2_depth = block_params.l2_depth; +} + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_BLOCK_PARAMS_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/common.h b/runtimes/nn/depend/external/gemmlowp/internal/common.h new file mode 100644 index 000000000..511809d28 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/common.h @@ -0,0 +1,256 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// common.h: contains stuff that's used throughout gemmlowp +// and should always be available. + +#ifndef GEMMLOWP_INTERNAL_COMMON_H_ +#define GEMMLOWP_INTERNAL_COMMON_H_ + +#include <pthread.h> + +#include <algorithm> +#include <cassert> +#include <cmath> +#include <cstdlib> + +#include "../profiling/instrumentation.h" + +// Our inline assembly path assume GCC/Clang syntax. +// Native Client doesn't seem to support inline assembly(?). +#if defined(__GNUC__) && !defined(__native_client__) +#define GEMMLOWP_ALLOW_INLINE_ASM +#endif + +// Define macro statement that avoids inlining for GCC. +// For non-GCC, define as empty macro. +#if defined(__GNUC__) +#define GEMMLOWP_NOINLINE __attribute__((noinline)) +#else +#define GEMMLOWP_NOINLINE +#endif + +// Detect ARM, 32-bit or 64-bit +#ifdef __arm__ +#define GEMMLOWP_ARM_32 +#endif + +#ifdef __aarch64__ +#define GEMMLOWP_ARM_64 +#endif + +#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64) +#define GEMMLOWP_ARM +#endif + +// Detect x86, 32-bit or 64-bit +#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386) +#define GEMMLOWP_X86_32 +#endif + +#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64) +#define GEMMLOWP_X86_64 +#endif + +#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64) +#define GEMMLOWP_X86 +#endif + +// Some of our optimized paths use inline assembly and for +// now we don't bother enabling some other optimized paths using intrinddics +// where we can't use inline assembly paths. +#ifdef GEMMLOWP_ALLOW_INLINE_ASM + +// Detect NEON. It's important to check for both tokens. +#if (defined __ARM_NEON) || (defined __ARM_NEON__) +#define GEMMLOWP_NEON +#endif + +// Convenience NEON tokens for 32-bit or 64-bit +#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32) +#define GEMMLOWP_NEON_32 +#endif + +#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64) +#define GEMMLOWP_NEON_64 +#endif + +// Detect SSE. +#ifdef __SSE4_1__ +#define GEMMLOWP_SSE4 +#endif + +#ifdef __SSE3__ +#define GEMMLOWP_SSE3 +#endif + +// Convenience SSE4 tokens for 32-bit or 64-bit +#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) +#define GEMMLOWP_SSE4_32 +#endif + +#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32) +#define GEMMLOWP_SSE3_32 +#endif + +#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) +#define GEMMLOWP_SSE4_64 +#endif + +#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64) +#define GEMMLOWP_SSE3_64 +#endif + +#endif // GEMMLOWP_ALLOW_INLINE_ASM + +// Detect Android. Don't conflate with ARM - we care about tuning +// for non-ARM Android devices too. This can be used in conjunction +// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs. +#if defined(__ANDROID__) +#define GEMMLOWP_ANDROID +#endif + +namespace gemmlowp { + +// Standard cache line size. Useful to optimize alignment and +// prefetches. Ideally we would query this at runtime, however +// 64 byte cache lines are the vast majority, and even if it's +// wrong on some device, it will be wrong by no more than a 2x factor, +// which should be acceptable. 
+const int kDefaultCacheLineSize = 64; + +// Default L1 and L2 data cache sizes. +// The L1 cache size is assumed to be for each core. +// The L2 cache size is assumed to be shared among all cores. What +// we call 'L2' here is effectively top-level cache. +// +// On x86, we should ideally query this at +// runtime. On ARM, the instruction to query this is privileged and +// Android kernels do not expose it to userspace. Fortunately, the majority +// of ARM devices have roughly comparable values: +// Nexus 5: L1 16k, L2 1M +// Android One: L1 32k, L2 512k +// The following values are equal to or somewhat lower than that, and were +// found to perform well on both the Nexus 5 and Android One. +// Of course, these values are in principle too low for typical x86 CPUs +// where we should set the L2 value to (L3 cache size / number of cores) at +// least. +// +#if defined(GEMMLOWP_ARM) && defined(__APPLE__) +// iPhone/iPad +const int kDefaultL1CacheSize = 48 * 1024; +const int kDefaultL2CacheSize = 2 * 1024 * 1024; +#elif defined(GEMMLOWP_ARM) || defined(GEMMLOWP_ANDROID) +// Other ARM or ARM-like hardware (Android implies ARM-like) so here it's OK +// to tune for ARM, although on x86 Atom we might be able to query +// cache sizes at runtime, which would be better. +const int kDefaultL1CacheSize = 16 * 1024; +const int kDefaultL2CacheSize = 384 * 1024; +#elif defined(GEMMLOWP_X86_64) +// x86-64 and not Android. Therefore, likely desktop-class x86 hardware. +// Thus we assume larger cache sizes, though we really should query +// them at runtime. +const int kDefaultL1CacheSize = 32 * 1024; +const int kDefaultL2CacheSize = 4 * 1024 * 1024; +#elif defined(GEMMLOWP_X86_32) +// x86-32 and not Android. Same as x86-64 but less bullish. +const int kDefaultL1CacheSize = 32 * 1024; +const int kDefaultL2CacheSize = 2 * 1024 * 1024; +#else +// Less common hardware. Maybe some unusual or older or embedded thing. +// Assume smaller caches, but don't depart too far from what we do +// on ARM/Android to avoid accidentally exposing unexpected behavior. +const int kDefaultL1CacheSize = 16 * 1024; +const int kDefaultL2CacheSize = 256 * 1024; +#endif + +// The proportion of the cache that we intend to use for storing +// RHS blocks. This should be between 0 and 1, and typically closer to 1, +// as we typically want to use most of the L2 cache for storing a large +// RHS block. +#if defined(GEMMLOWP_X86) +// For IA, use the entire L2 cache for the RHS matrix. LHS matrix is not blocked +// for L2 cache. +const float kDefaultL2RhsFactor = 1.00f; +#else +const float kDefaultL2RhsFactor = 0.75f; +#endif + +// The number of bytes in a SIMD register. This is used to determine +// the dimensions of PackingRegisterBlock so that such blocks can +// be efficiently loaded into registers, so that packing code can +// work within registers as much as possible. +// In the non-SIMD generic fallback code, this is just a generic array +// size, so any size would work there. Different platforms may set this +// to different values but must ensure that their own optimized packing paths +// are consistent with this value. +const int kRegisterSize = 16; + +// Hints the CPU to prefetch the cache line containing ptr. +inline void Prefetch(const void* ptr) { +#if defined GEMMLOWP_ARM_64 && defined GEMMLOWP_ALLOW_INLINE_ASM + // Aarch64 has very detailed prefetch instructions, that compilers + // can't know how to map __builtin_prefetch to, and as a result, don't, + // leaving __builtin_prefetch a no-op on this architecture. 
+ // For our purposes, "pldl1keep" is usually what we want, meaning: + // "prefetch for load, into L1 cache, using each value multiple times". + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) : ); +#elif defined \ + __GNUC__ // Clang and GCC define __GNUC__ and have __builtin_prefetch. + __builtin_prefetch(ptr); +#else + (void)ptr; +#endif +} + +// Returns the runtime argument rounded down to the nearest multiple of +// the fixed Modulus. +template <unsigned Modulus, typename Integer> +Integer RoundDown(Integer i) { + return i - (i % Modulus); +} + +// Returns the runtime argument rounded up to the nearest multiple of +// the fixed Modulus. +template <unsigned Modulus, typename Integer> +Integer RoundUp(Integer i) { + return RoundDown<Modulus>(i + Modulus - 1); +} + +// Returns the quotient a / b rounded up ('ceil') to the nearest integer. +template <typename Integer> +Integer CeilQuotient(Integer a, Integer b) { + return (a + b - 1) / b; +} + +// Returns the argument rounded up to the nearest power of two. +template <typename Integer> +Integer RoundUpToPowerOfTwo(Integer n) { + Integer i = n - 1; + i |= i >> 1; + i |= i >> 2; + i |= i >> 4; + i |= i >> 8; + i |= i >> 16; + return i + 1; +} + +template <int N> +struct IsPowerOfTwo { + static const bool value = !(N & (N - 1)); +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_COMMON_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/compute.h b/runtimes/nn/depend/external/gemmlowp/internal/compute.h new file mode 100644 index 000000000..bbc9e2a0e --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/compute.h @@ -0,0 +1,104 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// compute.h: the central stage of the Gemm computation, operates +// on already-packed LHS and RHS blocks and calls the Gemm kernel +// to compute a block of the product. 
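+//
+// For intuition about the integer helpers from common.h that are used below:
+// Compute() first rounds the depth up to the kernel's depth granularity with
+// RoundUp<Format::kDepth>(depth); with illustrative values,
+// RoundUp<4>(10) == 12, RoundDown<4>(10) == 8 and CeilQuotient(10, 4) == 3.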
+ +#ifndef GEMMLOWP_INTERNAL_COMPUTE_H_ +#define GEMMLOWP_INTERNAL_COMPUTE_H_ + +#include "block_params.h" +#include "kernel.h" +#include "pack.h" + +namespace gemmlowp { + +template <typename PackedLhs, typename PackedRhs, typename PackedResult> +class ComputeImpl { + typedef typename PackedLhs::KernelSideFormat KernelLhsFormat; + typedef typename PackedRhs::KernelSideFormat KernelRhsFormat; + typedef KernelFormat<KernelLhsFormat, KernelRhsFormat> Format; + + const KernelBase& kernel_; + const BlockParams& block_params_; + + PackedResult* const packed_result_; + const PackedLhs& packed_lhs_; + const PackedRhs& packed_rhs_; + + public: + ComputeImpl(const KernelBase& _kernel, const BlockParams& _block_params, + PackedResult* _packed_result, const PackedLhs& _packed_lhs, + const PackedRhs& _packed_rhs) + : kernel_(_kernel), + block_params_(_block_params), + packed_result_(_packed_result), + packed_lhs_(_packed_lhs), + packed_rhs_(_packed_rhs) {} + + void Compute(int depth) { + depth = RoundUp<Format::kDepth>(depth); + assert(depth <= block_params_.l2_depth); + for (int d = 0; d < depth; d += block_params_.l1_depth) { + int ds = std::min(block_params_.l1_depth, depth - d); + + for (int r = 0; r < block_params_.l2_rows; r += block_params_.l1_rows) { + int rs = std::min(block_params_.l1_rows, block_params_.l2_rows - r); + + ComputeL1(r, rs, 0, block_params_.l2_cols, d, ds); + } + } + } + + private: + void ComputeRun(int start_row, int start_col, int start_depth, + int depth) GEMMLOWP_NOINLINE { + packed_lhs_.seek_run(start_row, start_depth); + packed_rhs_.seek_run(start_col, start_depth); + auto packed_result_block = packed_result_->Map().block( + start_row, start_col, Format::kRows, Format::kCols); + kernel_.Run(packed_result_block.data(), packed_result_block.rows_stride(), + packed_result_block.cols_stride(), packed_lhs_.current_data(), + packed_rhs_.current_data(), start_depth, depth); + } + + void ComputeL1(int start_row, int rows, int start_col, int cols, + int start_depth, int depth) { + assert(rows % Format::kRows == 0); + assert(cols % Format::kCols == 0); + assert(depth % Format::kDepth == 0); + + for (int c = 0; c < cols; c += Format::kCols) { + for (int r = 0; r < rows; r += Format::kRows) { + ComputeRun(start_row + r, start_col + c, start_depth, depth); + } + } + } +}; + +template <typename PackedLhs, typename PackedRhs, typename PackedResult> +void Compute(const KernelBase& kernel, const BlockParams& block_params, + PackedResult* packed_result, const PackedLhs& packed_lhs, + const PackedRhs& packed_rhs, int depth) { + ScopedProfilingLabel label("compute"); + ComputeImpl<PackedLhs, PackedRhs, PackedResult> impl( + kernel, block_params, packed_result, packed_lhs, packed_rhs); + + impl.Compute(depth); +} + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_COMPUTE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/dispatch_gemm_shape.h b/runtimes/nn/depend/external/gemmlowp/internal/dispatch_gemm_shape.h new file mode 100644 index 000000000..0be0bf360 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/dispatch_gemm_shape.h @@ -0,0 +1,189 @@ +// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape + +#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ +#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ + +#include "../internal/kernel_default.h" +#include "../public/map.h" +#include "../public/output_stages.h" +#include "multi_thread_gemm.h" + +namespace gemmlowp { + +template <typename T> +struct TransposeImpl { + typedef T DstType; + static T Run(const T& t) { return t; } +}; + +template <typename T> +using TransposeType = typename TransposeImpl<T>::DstType; + +template <typename T> +TransposeType<T> Transpose(const T& t) { + return TransposeImpl<T>::Run(t); +} + +template <MapOrder Order> +struct TransposeMapOrder { + static constexpr MapOrder Value = + Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor; +}; + +template <VectorShape Shape> +struct TransposeVectorShape { + static constexpr VectorShape Value = + Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row; +}; + +template <typename Scalar, VectorShape Shape> +struct TransposeImpl<VectorMap<Scalar, Shape>> { + typedef VectorMap<Scalar, Shape> SrcType; + static constexpr VectorShape TransposedShape = + TransposeVectorShape<Shape>::Value; + typedef VectorMap<Scalar, TransposedShape> DstType; + static DstType Run(const SrcType& src) { + return DstType(src.data(), src.size()); + } +}; + +template <typename Scalar, MapOrder Order> +struct TransposeImpl<MatrixMap<Scalar, Order>> { + typedef MatrixMap<Scalar, Order> SrcType; + static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value; + typedef MatrixMap<Scalar, TransposedOrder> DstType; + static DstType Run(const SrcType& src) { + return DstType(src.data(), src.cols(), src.rows(), src.stride()); + } +}; + +template <VectorShape Shape> +struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> { + typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType; + static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value; + typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType; + static DstType Run(const SrcType& src) { + DstType dst; + dst.result_shift = src.result_shift; + dst.result_offset = Transpose(src.result_offset); + dst.result_mult_int = Transpose(src.result_mult_int); + return dst; + } +}; + +template <typename VectorMapType> +struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> { + typedef OutputStageBiasAddition<VectorMapType> SrcType; + typedef TransposeType<VectorMapType> TransposedVectorMapType; + typedef OutputStageBiasAddition<TransposedVectorMapType> DstType; + static DstType Run(const SrcType& src) { + DstType dst; + dst.bias_vector = Transpose(src.bias_vector); + return dst; + } +}; + +// TODO(benoitjacob) - does anyone understand C++ variadic templates? +// How to use them to implement TransposeTuple? Note: there are lots +// of answers on StackOverflow but they seem to all involve either +// C++14/C++17 (we can only use C++11) or lots of abstract nonsense. 
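+//
+// One possible C++11-only sketch (illustrative only, kept as a comment; the
+// helper names IndexSeq, MakeIndexSeq and TransposeTupleImpl are hypothetical,
+// not part of gemmlowp):
+//
+//   template <std::size_t...> struct IndexSeq {};
+//   template <std::size_t N, std::size_t... Is>
+//   struct MakeIndexSeq : MakeIndexSeq<N - 1, N - 1, Is...> {};
+//   template <std::size_t... Is>
+//   struct MakeIndexSeq<0, Is...> { typedef IndexSeq<Is...> Type; };
+//
+//   // Expand the tuple over an index pack and Transpose() each element.
+//   template <typename... Ts, std::size_t... Is>
+//   std::tuple<TransposeType<Ts>...> TransposeTupleImpl(
+//       const std::tuple<Ts...>& t, IndexSeq<Is...>) {
+//     return std::make_tuple(Transpose(std::get<Is>(t))...);
+//   }
+//
+//   template <typename... Ts>
+//   std::tuple<TransposeType<Ts>...> TransposeTuple(
+//       const std::tuple<Ts...>& t) {
+//     return TransposeTupleImpl(t,
+//                               typename MakeIndexSeq<sizeof...(Ts)>::Type());
+//   }
+//
+// Until something like that is adopted, the fixed-arity overloads below cover
+// tuples of up to 6 elements.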
+inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; } + +template <typename T0> +std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) { + return std::make_tuple(Transpose(std::get<0>(t))); +} + +template <typename T0, typename T1> +std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple( + const std::tuple<T0, T1>& t) { + return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t))); +} + +template <typename T0, typename T1, typename T2> +std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>> +TransposeTuple(const std::tuple<T0, T1, T2>& t) { + return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), + Transpose(std::get<2>(t))); +} + +template <typename T0, typename T1, typename T2, typename T3> +std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, + TransposeType<T3>> +TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) { + return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), + Transpose(std::get<2>(t)), Transpose(std::get<3>(t))); +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4> +std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, + TransposeType<T3>, TransposeType<T4>> +TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) { + return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), + Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), + Transpose(std::get<4>(t))); +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename T5> +std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, + TransposeType<T3>, TransposeType<T4>, TransposeType<T5>> +TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) { + return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), + Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), + Transpose(std::get<4>(t)), Transpose(std::get<5>(t))); +} + +template <typename InputScalar, typename OutputScalar, typename BitDepthParams, + MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder, + typename LhsOffset, typename RhsOffset, typename OutputPipelineType, + typename GemmContextType> +void DispatchGemmShape(GemmContextType* context, + const MatrixMap<const InputScalar, LhsOrder>& lhs, + const MatrixMap<const InputScalar, RhsOrder>& rhs, + MatrixMap<OutputScalar, ResultOrder>* result, + const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, + const OutputPipelineType& output_pipeline) { + assert(lhs.cols() == rhs.rows()); + + int rows = result->rows(); + int cols = result->cols(); + int depth = lhs.cols(); + + if (rows == 0 || cols == 0 || depth == 0) { + // Vacuous GEMM, return early to avoid having to deal with + // zero sizes below. 
+ return; + } + + if (rows < cols) { + auto transposed_result_map = Transpose(*result); + return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>( + context, Transpose(rhs), Transpose(lhs), &transposed_result_map, + Transpose(rhs_offset), Transpose(lhs_offset), + TransposeTuple(output_pipeline)); + } + + typedef DefaultKernel<BitDepthParams> Kernel; + MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar, + BitDepthParams>(context, Kernel(), lhs, rhs, result, + lhs_offset, rhs_offset, output_pipeline); +} + +} // end namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/kernel.h b/runtimes/nn/depend/external/gemmlowp/internal/kernel.h new file mode 100644 index 000000000..4d006af92 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/kernel.h @@ -0,0 +1,234 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// kernel.h: general definitions for kernels. + +#ifndef GEMMLOWP_INTERNAL_KERNEL_H_ +#define GEMMLOWP_INTERNAL_KERNEL_H_ + +#include "../public/bit_depth.h" +#include "common.h" + +namespace gemmlowp { + +// Explanation of general gemmlowp terminology +// =========================================== +// +// We use the following abbreviations: +// LHS = "left-hand side" +// RHS = "right-hand side" +// Sometimes when referring to either LHS or RHS, we just say a "Side". +// +// In a matrix product of a MxK matrix times a KxN matrix, +// we call K the 'depth'. Note that M is the number of rows +// of the result (and of the LHS), and N is the number of columns +// of the result (and of the RHS). +// +// In each of the LHS and RHS matrices, we call 'width' the +// other dimension, besides the depth. So in the LHS, 'width' +// is the number of rows, while in the RHS, 'width' is the number +// of columns. +// +// So in the LHS MxK matrix, the depth is K and the width in M. +// And in the RHS KxN matrix, the depth is K and the width in N. +// +// This is illustrated in this picture: +// +// RHS width +// <-----------------> +// +-----------------+ ^ +// | RHS | | Depth +// +-----------------+ v +// ^ +--+ +-----------------+ +// | |L | | | +// LHS width | |H | | Result | +// | |S | | | +// v +--+ +-----------------+ +// <--> +// Depth + +// Explanation of gemmlowp kernel formats and "cells" +// ================================================== +// +// Kernels operate on small LHS and RHS blocks that fit in registers. +// These blocks are stored contiguously in memory, but not always +// in a traditional column-major or row-major order; instead, +// they consist of a number of sub-blocks, which we call "cells", +// that are stored in column-major or row-major order. However, +// what really matters to us is not so much rows vs columns, but +// rather width vs depth. So we refer to "width-major" and "depth-major" +// storage orders. 
In the LHS, width-major means row-major,
+// while in the RHS, width-major means column-major.
+// There is also a third possibility, "diagonal order",
+// which is unused at the moment.
+//
+// We aim to treat both sides, LHS and RHS, on an equal footing,
+// so we call them both 'sides'. A KernelFormat thus is just a pair
+// of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
+// contains a CellFormat and a number of cells; cells are only ever
+// stacked in the width dimension, which means stacked vertically in the
+// LHS and stacked horizontally in the RHS.
+//
+// Example
+// =======
+//
+// Let's work out the data layout expected by a kernel having the
+// following format (the struct names here are defined below in this file):
+//
+//   KernelFormat<
+//     KernelSideFormat<CellFormat<3, 4>, 3>,
+//     KernelSideFormat<CellFormat<5, 4>, 2>
+//   >
+//
+// The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
+// 3 cells, each cell having dimensions (width=3, depth=4), laid out in
+// DepthMajor order (the default value, see CellFormat). In the LHS,
+// DepthMajor means column-major, so the LHS cells are of size 3x4 in
+// column-major order, so the LHS layout is:
+//
+//   0  3  6  9
+//   1  4  7  10
+//   2  5  8  11
+//   12 15 18 21
+//   13 16 19 22
+//   14 17 20 23
+//   24 27 30 33
+//   25 28 31 34
+//   26 29 32 35
+//
+// The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
+// 2 cells each having dimensions (width=5, depth=4), laid out in
+// DepthMajor order (the default value, see CellFormat). In the RHS,
+// DepthMajor means row-major, so the RHS cells are of size 4x5 in
+// row-major order, so the RHS layout is:
+//
+//   0  1  2  3  4  20 21 22 23 24
+//   5  6  7  8  9  25 26 27 28 29
+//   10 11 12 13 14 30 31 32 33 34
+//   15 16 17 18 19 35 36 37 38 39
+
+// CellOrder enumerates the possible storage orders (=layouts) for
+// a cell (see explanation above).
+enum class CellOrder { DepthMajor, WidthMajor, Diagonal };
+
+// CellFormat describes how data is laid
+// out in a cell. That is, a CellOrder together with actual dimensions.
+template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
+struct CellFormat {
+  static const int kWidth = tWidth;
+  static const int kDepth = tDepth;
+  static const CellOrder kOrder = tOrder;
+
+  static const int kSize = kWidth * kDepth;
+};
+
+// KernelSideFormat describes how data is laid out in a kernel side
+// (i.e. LHS or RHS). That is, a CellFormat together with a number of
+// cells. These cells are always stacked in the Width dimension.
+// For example, in the LHS case, the Width dimension is the rows dimension,
+// so we're saying that in the LHS, cells are stacked vertically.
+// We never stack cells in the Depth dimension.
+template <typename tCellFormat, int tCells>
+struct KernelSideFormat {
+  typedef tCellFormat Cell;
+  static const int kCells = tCells;
+  static const int kWidth = kCells * Cell::kWidth;
+  static const int kDepth = Cell::kDepth;
+  typedef std::uint8_t Scalar;
+};
+
+template <typename tCellFormat, int tCells>
+struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
+  typedef std::int8_t Scalar;
+};
+
+// KernelFormat describes fully the input data layout that a kernel expects.
+// It consists of two KernelSideFormat's, one for LHS and one for RHS.
+template <typename tLhs, typename tRhs> +struct KernelFormat { + typedef tLhs Lhs; + typedef tRhs Rhs; + + static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, ""); + static const int kDepth = Lhs::Cell::kDepth; + static const int kRows = Lhs::Cell::kWidth * Lhs::kCells; + static const int kCols = Rhs::Cell::kWidth * Rhs::kCells; +}; + +inline const char* CellOrderName(CellOrder o) { + switch (o) { + case CellOrder::DepthMajor: + return "DepthMajor"; + case CellOrder::WidthMajor: + return "WidthMajor"; + case CellOrder::Diagonal: + return "Diagonal"; + default: + assert(false); + return nullptr; + } +} + +// Returns the offset into a cell, at which a given coefficient is stored. +template <typename CellFormat> +inline int OffsetIntoCell(int w, int d) { + switch (CellFormat::kOrder) { + case CellOrder::DepthMajor: + return w + d * CellFormat::kWidth; + case CellOrder::WidthMajor: + return d + w * CellFormat::kDepth; + case CellOrder::Diagonal: + assert(CellFormat::kWidth == CellFormat::kDepth); + static const int size = CellFormat::kWidth; + return ((size + w - d) * size + d) % (size * size); + default: + assert(false); + return 0; + } +} + +// KernelBase is the virtual base class below all kernels. +// The idea is that we don't need to templatize all our code on the exact +// kernel type; we only need to templatize on kernel format. Kernels +// sharing the same format can thus share the same packing/unpacking code. +struct KernelBase { + virtual const char* Name() const = 0; + + // This is the kernel implementation. We use the word 'run' consistently + // throughout gemmlowp to mean an inner loop, the implementation of which + // is to be provided by a separate optimized function. + virtual void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const = 0; + + virtual ~KernelBase() {} +}; + +template <typename KernelScalarType> +struct ZeroPointInputValue {}; + +template <> +struct ZeroPointInputValue<std::uint8_t> { + static constexpr std::uint8_t kValue = 0; +}; + +template <> +struct ZeroPointInputValue<std::int8_t> { + static constexpr std::uint8_t kValue = 128; +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_KERNEL_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/kernel_default.h b/runtimes/nn/depend/external/gemmlowp/internal/kernel_default.h new file mode 100644 index 000000000..7ed55b83d --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/kernel_default.h @@ -0,0 +1,109 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// kernel_default.h: Chooses default GEMM and GEMV kernels for the +// host platform. 
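+//
+// For example (illustrative, assuming the usual full 8-bit operand ranges
+// [0, 255]): the maximum product 255 * 255 = 65025 is not less than 4096 and
+// the LHS minimum value is not greater than 0, so the traits below resolve to
+// DefaultKernelImpl<false, false>, i.e. the generic kernel for the target
+// (for instance NEON_32_Kernel12x4Depth2 on 32-bit NEON).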
+ +#ifndef GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_ +#define GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_ + +#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#endif + +#include "../public/bit_depth.h" +#include "common.h" +#include "kernel_reference.h" + +namespace gemmlowp { + +template <bool MaxProductIsLessThan4096, + bool LhsAlwaysNonzero> +struct DefaultKernelImpl {}; + +// Partial specialization implementing the logic that if we want to use +// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall +// back to a generic kernel not taking advantage of LhsAlwaysNonzero. +template <bool LhsAlwaysNonzero> +struct DefaultKernelImpl<true, LhsAlwaysNonzero> + : DefaultKernelImpl<false, LhsAlwaysNonzero> {}; + +// Partial specialization implementing the logic that if we want to use +// a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we +// fall back to a generic kernel not taking advantage of +// MaxProductIsLessThan4096. +template <bool MaxProductIsLessThan4096> +struct DefaultKernelImpl<MaxProductIsLessThan4096, true> + : DefaultKernelImpl<MaxProductIsLessThan4096, false> {}; + +template <typename BitDepthParams> +struct DefaultKernel + : DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue * + BitDepthParams::RhsRange::kMaxValue < + 4096), + (BitDepthParams::LhsRange::kMinValue > 0)> {}; + +} // end namespace gemmlowp + +#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \ + LhsAlwaysNonzero, Kernel) \ + namespace gemmlowp { \ + template <> \ + struct DefaultKernelImpl<MaxProductIsLessThan4096, \ + LhsAlwaysNonzero> : Kernel {}; \ + } + +#if defined GEMMLOWP_NEON_32 +#include "kernel_neon.h" +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(true, false, + NEON_32_Kernel12x4Depth2Assuming12BitProducts) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, + NEON_32bit_GEMM_Int8Operands_LhsNonzero) +#elif defined GEMMLOWP_NEON_64 +#include "kernel_neon.h" +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, + NEON_64bit_GEMM_Int8Operands_LhsNonzero) +#elif defined GEMMLOWP_SSE4_32 +#include "kernel_sse.h" +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2) +#elif defined GEMMLOWP_SSE4_64 +#include "kernel_sse.h" +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2) +#else +#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#if defined __ARM_ARCH_5TE__ +// SIMD is not available on this platform. The slow fallback will be used. +// Don't require GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK because there's nothing +// the user can do about it. +#else +#error \ + "SIMD not enabled, you'd be getting a slow software fallback. Consider \ +enabling SIMD extensions (for example using -msse4 if you're on modern x86). \ +If that's not an option, and you would like to continue with the \ +slow fallback, define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK." 
+#endif +#endif +#include "kernel_reference.h" +namespace gemmlowp { +typedef ReferenceKernel<KernelFormat< + KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > > + DefaultReferenceKernel; +} +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel) +#endif + +#endif // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/kernel_neon.h b/runtimes/nn/depend/external/gemmlowp/internal/kernel_neon.h new file mode 100644 index 000000000..5c253babe --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/kernel_neon.h @@ -0,0 +1,1619 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// kernel_neon.h: a collection of NEON optimized kernels. +// Check in kernel_default.h which one(s) are actually used by default. +// Others are mere experiments; they are still covered by tests +// in case they might be useful some day. + +#ifndef GEMMLOWP_INTERNAL_KERNEL_NEON_H_ +#define GEMMLOWP_INTERNAL_KERNEL_NEON_H_ + +#include "kernel.h" + +#include <arm_neon.h> +#include <cassert> + +namespace gemmlowp { + +// The kernels here are specifically arm 32bit assembly, not arm 64bit. +#ifdef GEMMLOWP_NEON_32 + +// Our main GEMM kernel. +struct NEON_32_Kernel12x4Depth2 : KernelBase { + typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>, + KernelSideFormat<CellFormat<4, 2>, 1> > + Format; + + const char* Name() const override { return "NEON, 12x4, depth 2"; } + + // TODO(benoitjacob): reorder function arguments so dst comes last + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + ScopedProfilingLabel label("optimized kernel (NEON 12x4)"); + +// For iOS assembler, the %= style of local labels cause compilation errors, +// so use numerical ones instead. See +// http://stackoverflow.com/questions/3898435/labels-in-gcc-inline-assembly +// If you add any labels, remember to undef them at the end. +#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1" +#define GEMMLOWP_LABEL_BEFORE_LOOP "2" +#define GEMMLOWP_LABEL_LOOP "3" +#define GEMMLOWP_LABEL_AFTER_LOOP "4" + + assert(dst_row_stride == 1); + asm volatile( + // Overview of register layout: + // + // A 2x4 cell of Rhs is stored in 16bit in d0--d1 (q0). + // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in d2--d7 + // (q1--q3). + // A 12x4 block of accumulators is stored in 32bit in q4--q15. 
+ // + // +-----+-----+-----+-----+ + // |d0[0]|d0[1]|d0[2]|d0[3]| + // Rhs +-----+-----+-----+-----+ + // |d1[0]|d1[1]|d1[2]|d1[3]| + // +-----+-----+-----+-----+ + // + // | | | | | + // + // Lhs | | | | | + // + // +--+--+ - - - - +-----+-----+-----+-----+ + // |d2|d3| | q4 | q5 | q6 | q7 | + // |d2|d3| | q4 | q5 | q6 | q7 | + // |d2|d3| | q4 | q5 | q6 | q7 | + // |d2|d3| | q4 | q5 | q6 | q7 | + // +--+--+ - - - - +-----+-----+-----+-----+ + // |d4|d5| | q8 | q9 | q10 | q11 | + // |d4|d5| | q8 | q9 | q10 | q11 | + // |d4|d5| | q8 | q9 | q10 | q11 | + // |d4|d5| | q8 | q9 | q10 | q11 | + // +--+--+ - - - - +-----+-----+-----+-----+ + // |d6|d7| | q12 | q13 | q14 | q15 | + // |d6|d7| | q12 | q13 | q14 | q15 | + // |d6|d7| | q12 | q13 | q14 | q15 | + // |d6|d7| | q12 | q13 | q14 | q15 | + // +--+--+ - - - - +-----+-----+-----+-----+ + // + // Accumulator + + // Load 1 Rhs cell of size 2x4 + "vld1.8 {d0}, [%[rhs_ptr]]!\n" + // Load 3 Lhs cells of size 4x2 each + "vld1.8 {d2}, [%[lhs_ptr]]!\n" + "vld1.8 {d4}, [%[lhs_ptr]]!\n" + "vld1.8 {d6}, [%[lhs_ptr]]!\n" + + // Check if start_depth==0 to decide whether we will clear + // accumulators or load existing accumulators. + "cmp %[start_depth], #0\n" + + // Multiply dst_col_stride by 4 == sizeof(int32) to use + // it as a byte offset below. + "lsl %[dst_col_stride], #2\n" + + "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS + "f\n" + + // Load accumulators (start_depth != 0) + "mov r1, %[dst_ptr]\n" + "subs %[run_depth], #2\n" + "mov r0, r1\n" + "vld1.32 {d8, d9}, [r0]!\n" + "add r1, %[dst_col_stride]\n" + "vld1.32 {d16, d17}, [r0]!\n" + "vld1.32 {d24, d25}, [r0]\n" + "mov r0, r1\n" + "vld1.32 {d10, d11}, [r0]!\n" + "add r1, %[dst_col_stride]\n" + "vld1.32 {d18, d19}, [r0]!\n" + "vld1.32 {d26, d27}, [r0]\n" + "mov r0, r1\n" + "vld1.32 {d12, d13}, [r0]!\n" + "add r1, %[dst_col_stride]\n" + "vld1.32 {d20, d21}, [r0]!\n" + "vld1.32 {d28, d29}, [r0]\n" + "mov r0, r1\n" + "vld1.32 {d14, d15}, [r0]!\n" + "vld1.32 {d22, d23}, [r0]!\n" + "vld1.32 {d30, d31}, [r0]\n" + + "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n" + + GEMMLOWP_LABEL_CLEAR_ACCUMULATORS + ":\n" + + // Clear accumulators (start_depth == 0) + "vmov.s32 q4, #0\n" + "subs %[run_depth], #2\n" + "vmov.s32 q8, q4\n" + "vmov.s32 q12, q4\n" + "vmov.s32 q5, q4\n" + "vmov.s32 q9, q4\n" + "vmov.s32 q13, q4\n" + "vmov.s32 q6, q4\n" + "vmov.s32 q10, q4\n" + "vmov.s32 q14, q4\n" + "vmov.s32 q7, q4\n" + "vmov.s32 q11, q4\n" + "vmov.s32 q15, q4\n" + + GEMMLOWP_LABEL_BEFORE_LOOP + ":\n" + + // If there are only two levels of depth, skip the loop. + "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n" + + GEMMLOWP_LABEL_LOOP + ":\n" + // Expand Lhs/Rhs cells to 16 bit. + // Note: moving theses vmovls further down to allow for + // longer data pipelining helps a little on A57 but is + // harmful on A53 --- It looks as if A53 doesn't like + // interleaving vmovl's into the vmlal's. 
+ "vmovl.u8 q0, d0\n" + "vmovl.u8 q1, d2\n" + "vmovl.u8 q2, d4\n" + "vmovl.u8 q3, d6\n" + + // Multiply-accumulate, level of depth 0 + "vmlal.u16 q4, d2, d0[0]\n" + "vmlal.u16 q5, d2, d0[1]\n" + "vmlal.u16 q6, d2, d0[2]\n" + "vmlal.u16 q7, d2, d0[3]\n" + "vldr d2, [%[lhs_ptr]]\n" + "vmlal.u16 q8, d4, d0[0]\n" + "vmlal.u16 q9, d4, d0[1]\n" + "vmlal.u16 q10, d4, d0[2]\n" + "vmlal.u16 q11, d4, d0[3]\n" + "vldr d4, [%[lhs_ptr], #8]\n" + "vmlal.u16 q12, d6, d0[0]\n" + "vmlal.u16 q13, d6, d0[1]\n" + "vmlal.u16 q14, d6, d0[2]\n" + "vmlal.u16 q15, d6, d0[3]\n" + "vldr d6, [%[lhs_ptr], #16]\n" + "vldr d0, [%[rhs_ptr]]\n" + + // Multiply-accumulate, level of depth 1 + "vmlal.u16 q4, d3, d1[0]\n" + "vmlal.u16 q5, d3, d1[1]\n" + "add %[lhs_ptr], #24\n" + "vmlal.u16 q6, d3, d1[2]\n" + "vmlal.u16 q7, d3, d1[3]\n" + "add %[rhs_ptr], #8\n" + "vmlal.u16 q8, d5, d1[0]\n" + "vmlal.u16 q9, d5, d1[1]\n" + "subs %[run_depth], #2\n" + "vmlal.u16 q10, d5, d1[2]\n" + "vmlal.u16 q11, d5, d1[3]\n" + "vmlal.u16 q12, d7, d1[0]\n" + "vmlal.u16 q13, d7, d1[1]\n" + "vmlal.u16 q14, d7, d1[2]\n" + "vmlal.u16 q15, d7, d1[3]\n" + + "bne " GEMMLOWP_LABEL_LOOP "b\n" + + GEMMLOWP_LABEL_AFTER_LOOP + ":\n" + + // Do remaining arithmetic for the last 2 levels of depth. + + // Expand Lhs/Rhs cells to 16 bit. + "vmovl.u8 q0, d0\n" + "vmovl.u8 q1, d2\n" + "vmovl.u8 q2, d4\n" + "vmovl.u8 q3, d6\n" + + // Multiply-accumulate, level of depth 0 + "vmlal.u16 q4, d2, d0[0]\n" + "vmlal.u16 q5, d2, d0[1]\n" + "vmlal.u16 q6, d2, d0[2]\n" + "vmlal.u16 q7, d2, d0[3]\n" + "vmlal.u16 q8, d4, d0[0]\n" + "vmlal.u16 q9, d4, d0[1]\n" + "vmlal.u16 q10, d4, d0[2]\n" + "vmlal.u16 q11, d4, d0[3]\n" + "vmlal.u16 q12, d6, d0[0]\n" + "vmlal.u16 q13, d6, d0[1]\n" + "vmlal.u16 q14, d6, d0[2]\n" + "vmlal.u16 q15, d6, d0[3]\n" + + // Multiply-accumulate, level of depth 1 + "vmlal.u16 q4, d3, d1[0]\n" + "vmlal.u16 q5, d3, d1[1]\n" + "vmlal.u16 q6, d3, d1[2]\n" + "vmlal.u16 q7, d3, d1[3]\n" + "vmlal.u16 q8, d5, d1[0]\n" + "vmlal.u16 q9, d5, d1[1]\n" + "vmlal.u16 q10, d5, d1[2]\n" + "vmlal.u16 q11, d5, d1[3]\n" + "vmlal.u16 q12, d7, d1[0]\n" + "vmlal.u16 q13, d7, d1[1]\n" + "vmlal.u16 q14, d7, d1[2]\n" + "vmlal.u16 q15, d7, d1[3]\n" + + // Store accumulators + "mov r1, %[dst_ptr]\n" + "mov r0, r1\n" + "vst1.32 {d8, d9}, [r0]!\n" + "add r1, %[dst_col_stride]\n" + "vst1.32 {d16, d17}, [r0]!\n" + "vst1.32 {d24, d25}, [r0]\n" + "mov r0, r1\n" + "vst1.32 {d10, d11}, [r0]!\n" + "add r1, %[dst_col_stride]\n" + "vst1.32 {d18, d19}, [r0]!\n" + "vst1.32 {d26, d27}, [r0]\n" + "mov r0, r1\n" + "vst1.32 {d12, d13}, [r0]!\n" + "add r1, %[dst_col_stride]\n" + "vst1.32 {d20, d21}, [r0]!\n" + "vst1.32 {d28, d29}, [r0]\n" + "mov r0, r1\n" + "vst1.32 {d14, d15}, [r0]!\n" + "vst1.32 {d22, d23}, [r0]!\n" + "vst1.32 {d30, d31}, [r0]\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), + [run_depth] "+r"(run_depth) + : // inputs + [start_depth] "r"(start_depth), + [dst_col_stride] "r"(dst_col_stride) + : // clobbers + "cc", "memory", "r0", "r1", + // note: someone on internet says that quad registers are + // unsupported in the clobber list! 
+ "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", + "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", + "d31"); +#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS +#undef GEMMLOWP_LABEL_BEFORE_LOOP +#undef GEMMLOWP_LABEL_LOOP +#undef GEMMLOWP_LABEL_AFTER_LOOP + } +}; + +struct NEON_32_Kernel12x4Depth2Assuming12BitProducts : KernelBase { + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1> > + Format; + + const char* Name() const override { + return "NEON, 12x4, depth 2, assuming 12-bit products"; + } + + // TODO(benoitjacob): reorder function arguments so dst comes last + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + ScopedProfilingLabel label( + "optimized kernel (NEON 12x4, assuming 12-bit products)"); + assert(dst_row_stride == 1); + +// See comments above for why we need local numerical labels in our asm. +#define GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS "1" +#define GEMMLOWP_LOAD_GLOBAL_ACCUMULATORS_NEON_32_KERNEL_12X4_DEPTH2_12BIT "2" +#define GEMMLOWP_LABEL_32 "3" +#define GEMMLOWP_LABEL_24 "4" +#define GEMMLOWP_LABEL_16 "5" +#define GEMMLOWP_LABEL_8 "6" +#define GEMMLOWP_LABEL_2 "7" + + // This kernel is special in that it uses local 16-bit accumulators. + // Because it assumes that each product fits in 12 bits, it can accumulate + // 16 products into a local 16-bit accumulator without risking overflow. + // At that point, it must accumulate these local 16-bit accumulators back + // into global 32-bit accumulators, which have to be stored in memory for + // lack of register space. 
+ // This 12x4 block of global accumulators is laid out as 3 cells of size 4x4 + // stored in diagonal-major order like this for the first 4x4 cell: + // + // 0 4 8 12 + // 13 1 5 9 + // 10 14 2 6 + // 7 11 15 3 + // + // and likewise for the 2nd cell (16--31) and 3rd cell (32--47) + std::int32_t global_accumulators[3 * 4 * 4]; + asm volatile( + // Compute stride between consecutive columns, in bytes + "mov r0, #4\n" // multiply by 4 = sizeof(int32) + "mul %[dst_col_stride], r0\n" + + "cmp %[start_depth], #0\n" + "bne" + " " GEMMLOWP_LOAD_GLOBAL_ACCUMULATORS_NEON_32_KERNEL_12X4_DEPTH2_12BIT + "f\n" + + // If start_depth==0, we need to clear our global accumulators + "mov r0, %[global_accumulators]\n" + "vmov.s32 q8, #0\n" + "vmov.s32 q9, q8\n" + "vst1.32 {d16,d17,d18,d19}, [r0]!\n" + "vst1.32 {d16,d17,d18,d19}, [r0]!\n" + "vst1.32 {d16,d17,d18,d19}, [r0]!\n" + "vst1.32 {d16,d17,d18,d19}, [r0]!\n" + "vst1.32 {d16,d17,d18,d19}, [r0]!\n" + "vst1.32 {d16,d17,d18,d19}, [r0]!\n" + "b " GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS + "f\n" + + // If start_depth!=0, we need to load our existing global accumulators + GEMMLOWP_LOAD_GLOBAL_ACCUMULATORS_NEON_32_KERNEL_12X4_DEPTH2_12BIT + ":\n" + // Load global accumulators from destination matrix, column-major + "mov r1, %[dst_ptr]\n" + "mov r0, %[dst_col_stride]\n" + "sub r0, #32\n" + "vld1.32 {d0,d1}, [r1]!\n" + "vld1.32 {d8,d9}, [r1]!\n" + "vld1.32 {d16,d17}, [r1], r0\n" + "vld1.32 {d2,d3}, [r1]!\n" + "vld1.32 {d10,d11}, [r1]!\n" + "vld1.32 {d18,d19}, [r1], r0\n" + "vld1.32 {d4,d5}, [r1]!\n" + "vld1.32 {d12,d13}, [r1]!\n" + "vld1.32 {d20,d21}, [r1], r0\n" + "vld1.32 {d6,d7}, [r1]!\n" + "vld1.32 {d14,d15}, [r1]!\n" + "vld1.32 {d22,d23}, [r1], r0\n" + // Now we need to convert the global accumulator registers to + // 4x4-block-wise diagonal-major order. What we effectively want to do + // is to rotate the rows, however the accumulators are stored in + // column-major order in registers. So we achieve this by + // transposing, rotating the registers, and transposing again each + // 4x4 block. + // + // Transpose 3 4x4 blocks separately + "vtrn.32 q0, q1\n" + "vtrn.32 q2, q3\n" + "vswp d1, d4\n" + "vswp d3, d6\n" + "vtrn.32 q4, q5\n" + "vtrn.32 q6, q7\n" + "vswp d9, d12\n" + "vswp d11, d14\n" + "vtrn.32 q8, q9\n" + "vtrn.32 q10, q11\n" + "vswp d17, d20\n" + "vswp d19, d22\n" + // Rotate the registers + "vext.32 q1, q1, q1, #1\n" + "vext.32 q2, q2, q2, #2\n" + "vext.32 q3, q3, q3, #3\n" + "vext.32 q5, q5, q5, #1\n" + "vext.32 q6, q6, q6, #2\n" + "vext.32 q7, q7, q7, #3\n" + "vext.32 q9, q9, q9, #1\n" + "vext.32 q10, q10, q10, #2\n" + "vext.32 q11, q11, q11, #3\n" + // Transpose again and store into our global accumulators + // buffer. These two operations are done at once using vst4. + "mov r0, %[global_accumulators]\n" + "vst4.32 {d0,d2,d4,d6}, [r0]!\n" + "vst4.32 {d1,d3,d5,d7}, [r0]!\n" + "vst4.32 {d8,d10,d12,d14}, [r0]!\n" + "vst4.32 {d9,d11,d13,d15}, [r0]!\n" + "vst4.32 {d16,d18,d20,d22}, [r0]!\n" + "vst4.32 {d17,d19,d21,d23}, [r0]!\n" + + /* Main loop */ + + GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS + ":\n" + +// Overview of register layout: +// +// Registers q4--q16 are the local 16-bit accumulators. +// However, each entry in the result matrix is represented +// by *two* local 16-bit accumulators: one for even levels +// of depth and one for odd levels of depth. These correspond +// to the scalars at even and odd indices within each q-register. 
+// Thus we effectively use 32 bits of register space for each +// entry in the result matrix. The accumulators register layout +// is the same as was described above for the global 32-bit +// accumulators (3 cells of size 4x4 in diagonal-major order) +// with the only difference that instead of 32bit values we have +// pairs of 16bit values. +// +// A 2x4 cell of Rhs is stored in 8bit in d0. +// A 12x2 block of 3 4x2 cells Lhs is stored in 8bit in d1--d3. +// +// +--------+--------+--------+--------+ +// |d0[0] |d0[2] |d0[4] |d0[6] | +// Rhs +--------+--------+--------+--------+ +// |d0[1] |d0[3] |d0[5] |d0[7] | +// +--------+--------+--------+--------+ +// +// | | | | | +// +// Lhs | | | | | +// +// +-----+-----+ - - - +--------+--------+--------+--------+ +// |d1[0]|d1[1]| |q4[0,1] |q5[0,1] |q6[0,1] |q7[0,1] | +// |d1[2]|d1[3]| |q7[2,3] |q4[2,3] |q5[2,3] |q6[2,3] | +// |d1[4]|d1[5]| |q6[4,5] |q7[4,5] |q4[4,5] |q5[4,5] | +// |d1[6]|d1[7]| |q5[6,7] |q6[6,7] |q7[6,7] |q4[6,7] | +// +-----+-----+ - - - +--------+--------+--------+--------+ +// |d2[0]|d2[1]| |q8[0,1] |q8[0,1] |q8[0,1] |q8[0,1] | +// |d2[2]|d2[3]| |q9[2,3] |q9[2,3] |q9[2,3] |q9[2,3] | +// |d2[4]|d2[5]| |q10[4,5]|q10[4,5]|q10[4,5]|q10[4,5]| +// |d2[6]|d2[7]| |q11[6,7]|q11[6,7]|q11[6,7]|q11[6,7]| +// +-----+-----+ - - - +--------+--------+--------+--------+ +// |d3[0]|d3[1]| |q12[0,1]|q12[0,1]|q12[0,1]|q12[0,1]| +// |d3[2]|d3[3]| |q13[2,3]|q13[2,3]|q13[2,3]|q13[2,3]| +// |d3[4]|d3[5]| |q14[4,5]|q14[4,5]|q14[4,5]|q14[4,5]| +// |d3[6]|d3[7]| |q15[6,7]|q15[6,7]|q15[6,7]|q15[6,7]| +// +-----+-----+ - - - +--------+--------+--------+--------+ +// +// Local 16-bit accumulators +// Note: 2 scalars per matrix entry + +#define GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH \ + /* Load 3 Lhs cells of size 4x2 */ \ + "vld1.8 {d1,d2,d3}, [%[lhs_ptr]:64]!\n" \ + \ + /* Load 1 Rhs cell of size 2x4 */ \ + "vld1.8 {d0}, [%[rhs_ptr]:64]!\n" \ + \ + /* Multiply-accumulate */ \ + "vmlal.u8 q4, d1, d0\n" \ + "vmlal.u8 q8, d2, d0\n" \ + "vmlal.u8 q12, d3, d0\n" \ + "vext.8 d0, d0, d0, #2\n" \ + "vmlal.u8 q5, d1, d0\n" \ + "vmlal.u8 q9, d2, d0\n" \ + "vmlal.u8 q13, d3, d0\n" \ + "vext.8 d0, d0, d0, #2\n" \ + "vmlal.u8 q6, d1, d0\n" \ + "vmlal.u8 q10, d2, d0\n" \ + "vmlal.u8 q14, d3, d0\n" \ + "vext.8 d0, d0, d0, #2\n" \ + "vmlal.u8 q7, d1, d0\n" \ + "vmlal.u8 q11, d2, d0\n" \ + "vmlal.u8 q15, d3, d0\n" \ + \ + "sub %[run_depth], #2\n" + +#define GEMMLOWP_ACCUMULATE_8_LEVELS_OF_DEPTH \ + GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH \ + GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH \ + GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH \ + GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH + + // Clear local 16-bit accumulators + "vmov.s32 q4, #0\n" + "vmov.s32 q5, q4\n" + "vmov.s32 q6, q4\n" + "vmov.s32 q7, q4\n" + "vmov.s32 q8, q4\n" + "vmov.s32 q9, q4\n" + "vmov.s32 q10, q4\n" + "vmov.s32 q11, q4\n" + "vmov.s32 q12, q4\n" + "vmov.s32 q13, q4\n" + "vmov.s32 q14, q4\n" + "vmov.s32 q15, q4\n" + + // Select a suitable number of depth levels + // to process at this iteration. TODO (benoitjacob) I guess that + // someone who really knows asm should make this a jump table. 
+ "cmp %[run_depth], #32\n" + "bge " GEMMLOWP_LABEL_32 + "f\n" + "cmp %[run_depth], #24\n" + "bge " GEMMLOWP_LABEL_24 + "f\n" + "cmp %[run_depth], #16\n" + "bge " GEMMLOWP_LABEL_16 + "f\n" + "cmp %[run_depth], #8\n" + "bge " GEMMLOWP_LABEL_8 + "f\n" + "b " GEMMLOWP_LABEL_2 "f\n" + + GEMMLOWP_LABEL_32 + ":\n" GEMMLOWP_ACCUMULATE_8_LEVELS_OF_DEPTH GEMMLOWP_LABEL_24 + ":\n" GEMMLOWP_ACCUMULATE_8_LEVELS_OF_DEPTH GEMMLOWP_LABEL_16 + ":\n" GEMMLOWP_ACCUMULATE_8_LEVELS_OF_DEPTH GEMMLOWP_LABEL_8 + ":\n" GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH + GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH + GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH GEMMLOWP_LABEL_2 + ":\n" GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH + + // Accumulate the local accumulators into the global accumulators. + // This is about summing adjacent pairs of 16-bit scalars into + // single 32-bit scalars, so we use pairwise long addition (vpadal). + "mov r0, %[global_accumulators]\n" + "mov r1, %[global_accumulators]\n" + "vld1.32 {d0,d1,d2,d3}, [r0]!\n" + "vld1.32 {d4,d5,d6,d7}, [r0]!\n" + "vpadal.u16 q0, q4\n" + "vpadal.u16 q1, q5\n" + "vpadal.u16 q2, q6\n" + "vpadal.u16 q3, q7\n" + "vst1.32 {d0,d1,d2,d3}, [r1]!\n" + "vst1.32 {d4,d5,d6,d7}, [r1]!\n" + "vld1.32 {d0,d1,d2,d3}, [r0]!\n" + "vld1.32 {d4,d5,d6,d7}, [r0]!\n" + "vpadal.u16 q0, q8\n" + "vpadal.u16 q1, q9\n" + "vpadal.u16 q2, q10\n" + "vpadal.u16 q3, q11\n" + "vst1.32 {d0,d1,d2,d3}, [r1]!\n" + "vst1.32 {d4,d5,d6,d7}, [r1]!\n" + "vld1.32 {d0,d1,d2,d3}, [r0]!\n" + "vld1.32 {d4,d5,d6,d7}, [r0]!\n" + "vpadal.u16 q0, q12\n" + "vpadal.u16 q1, q13\n" + "vpadal.u16 q2, q14\n" + "vpadal.u16 q3, q15\n" + "vst1.32 {d0,d1,d2,d3}, [r1]!\n" + "vst1.32 {d4,d5,d6,d7}, [r1]!\n" + + // Loop. + "cmp %[run_depth], #0\n" + "bne " GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS + "b\n" + +#undef GEMMLOWP_CLEAR_LOCAL_ACCUMULATORS +#undef GEMMLOWP_ACCUMULATE_8_LEVELS_OF_DEPTH +#undef GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH +#undef GEMMLOWP_ADD_TO_GLOBAL_ACCUMULATORS + + /* end of main loop */ + + // Store the global accumulators to the destination matrix + // (column-major) + // This is the reverse of the steps that we followed at the beginning + // when we load the global accumulators from the destination matrix. + // The problem is the same: how to convert 4x4 blocks + // between column-major and diagonal-major orders. + // Like above, we do this by rotating rows, and we achieve that by + // tranposing, rotating columns, and transposing again. + // + // Load and transpose 4x4 blocks of global accumulators + // These two steps are done at once by the vld4 instruction. 
+ "mov r0, %[global_accumulators]\n" + "vld4.32 {d0,d2,d4,d6}, [r0]!\n" + "vld4.32 {d1,d3,d5,d7}, [r0]!\n" + "vld4.32 {d8,d10,d12,d14}, [r0]!\n" + "vld4.32 {d9,d11,d13,d15}, [r0]!\n" + "vld4.32 {d16,d18,d20,d22}, [r0]!\n" + "vld4.32 {d17,d19,d21,d23}, [r0]!\n" + // Rotate the rows of each 4x4 block + "vext.32 q1, q1, q1, #3\n" + "vext.32 q2, q2, q2, #2\n" + "vext.32 q3, q3, q3, #1\n" + "vext.32 q5, q5, q5, #3\n" + "vext.32 q6, q6, q6, #2\n" + "vext.32 q7, q7, q7, #1\n" + "vext.32 q9, q9, q9, #3\n" + "vext.32 q10, q10, q10, #2\n" + "vext.32 q11, q11, q11, #1\n" + // Transpose again each 4x4 block + "vtrn.32 q0, q1\n" + "vtrn.32 q2, q3\n" + "vswp d1, d4\n" + "vswp d3, d6\n" + "vtrn.32 q4, q5\n" + "vtrn.32 q6, q7\n" + "vswp d9, d12\n" + "vswp d11, d14\n" + "vtrn.32 q8, q9\n" + "vtrn.32 q10, q11\n" + "vswp d17, d20\n" + "vswp d19, d22\n" + // Store into the column-major destination matrix + "mov r1, %[dst_ptr]\n" + "mov r0, %[dst_col_stride]\n" + "sub r0, #32\n" + "vst1.32 {d0,d1}, [r1]!\n" + "vst1.32 {d8,d9}, [r1]!\n" + "vst1.32 {d16,d17}, [r1], r0\n" + "vst1.32 {d2,d3}, [r1]!\n" + "vst1.32 {d10,d11}, [r1]!\n" + "vst1.32 {d18,d19}, [r1], r0\n" + "vst1.32 {d4,d5}, [r1]!\n" + "vst1.32 {d12,d13}, [r1]!\n" + "vst1.32 {d20,d21}, [r1], r0\n" + "vst1.32 {d6,d7}, [r1]!\n" + "vst1.32 {d14,d15}, [r1]!\n" + "vst1.32 {d22,d23}, [r1], r0\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), + [run_depth] "+r"(run_depth) + : // inputs + [start_depth] "r"(start_depth), [dst_col_stride] "r"(dst_col_stride), + [global_accumulators] "r"(&global_accumulators[0]) + : // clobbers + "cc", "memory", "r0", "r1", + // note: someone on internet says that quad registers are + // unsupported in the clobber list! + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", + "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", + "d31"); +#undef GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS +#undef GEMMLOWP_LOAD_GLOBAL_ACCUMULATORS_NEON_32_KERNEL_12X4_DEPTH2_12BIT +#undef GEMMLOWP_LABEL_32 +#undef GEMMLOWP_LABEL_24 +#undef GEMMLOWP_LABEL_16 +#undef GEMMLOWP_LABEL_8 +#undef GEMMLOWP_LABEL_2 + } +}; + +struct NEON_32bit_GEMM_Int8Operands_LhsNonzero : KernelBase { + typedef KernelFormat< + KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormatInt8<CellFormat<2, 16, CellOrder::WidthMajor>, 1> > + Format; + const char* Name() const override { + return "NEON, 4x2, depth 16, accumulating two within signed int16"; + } + + // TODO(benoitjacob): reorder function arguments so dst comes last + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { +#define GEMMLOWP_LABEL_AFTER_LOOP "1" +#define GEMMLOWP_LABEL_LOOP "2" +#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" +#define GEMMLOWP_LABEL_STORE "4" + asm volatile( + // Multiply dst_col_stride by 4 == sizeof(int32) to use + // it as a byte offset below. + "lsl %[dst_col_stride], %[dst_col_stride], #2\n" + + // Overview of register layout: + // + // A 2x16 block of Rhs is stored in 8 bit in d0--d3. + // A 4x16 block of Lhs is stored in 8 bit in d4--d7. That is only + // half of the register space required, so we loop over these registers + // twice. Only half of it, a 2x16 block, is stored in d4--d7 at + // any given time. 
+ // + // A 4x2 block of accumulators is stored in q8--q15 (as 4x32 bit + // components which need to be horizontally-added at the end) + // + // The Lhs vectors are multiplied by the Rhs vectors with a widening + // multiply over the 8 first levels of depth, producing int16x8 + // vectors of products for each position in the accumulator matrix. + // Here comes the special trick: since the operands are signed int8, + // their range being [ -2^7 , 2^7 ), their products are in range + // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values + // without any risk of overflowing int16. + // We thus proceed with the 8 next levels of depth, multiplying + // again Lhs by Rhs, accumulating into this existing int16x8 vector. + // + // Only then, having processed 16 levels of depth, do we need to + // horizontally add these int16x8 accumulators into the final + // int32x4 accumulators. + // + // As we do not have enough registers to store all 16 int16x8 + // temporary-16bit-accumulators, we have them cycle through q4--q7. + // + // + // Register layout (ignoring the q4--q7 temporary 16bit accumulators): + // + // +----+----+ + // | d0 | d2 | + // | . | . | + // | . | . | + // | . | . | + // Rhs +----+----+ + // | d1 | d3 | + // | . | . | + // | . | . | + // | . | . | + // +----+----+ + // + // | | | + // + // Lhs | | | + // + // +--------+--------+ - - - - +----+----+ + // | d4 ... | d5 ... | | q8 | q9 | + // | d6 ... | d7 ... | | q10| q11| + // | d4 ... | d5 ... | | q12| q13| + // | d6 ... | d7 ... | | q14| q15| + // +--------+--------+ - - - - +----+----+ + // + // Accumulator + // + + // Clear accumulators, and, interleaved with it, + // initial loads of the first loop iteration, + // taken out of the loop so that in the loop itself we have + // optimal streaming of data from memory. + "vldr d0, [%[rhs_ptr], #0]\n" + "vmov.i32 q8, #0\n" + "vldr d4, [%[lhs_ptr], #0]\n" + "vmov.i32 q9, #0\n" + "vldr d2, [%[rhs_ptr], #16]\n" + "vmov.i32 q10, q8\n" + "vldr d6, [%[lhs_ptr], #16]\n" + "vmov.i32 q11, q8\n" + "vldr d1, [%[rhs_ptr], #8]\n" + "vmov.i32 q12, q8\n" + "vldr d5, [%[lhs_ptr], #8]\n" + "vmov.i32 q13, q8\n" + "vldr d3, [%[rhs_ptr], #24]\n" + "vmov.i32 q14, q8\n" + "vldr d7, [%[lhs_ptr], #24]\n" + "vmov.i32 q15, q8\n" + + // General loop. + GEMMLOWP_LABEL_LOOP + ":\n" + + // Multiply 8 first levels of depth. + "vmull.s8 q4, d0, d4\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + "vmull.s8 q5, d2, d4\n" + "vldr d4, [%[lhs_ptr], #32]\n" + "vmull.s8 q6, d0, d6\n" + "vmull.s8 q7, d2, d6\n" + "vldr d6, [%[lhs_ptr], #48]\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "vmlal.s8 q4, d1, d5\n" + "vmlal.s8 q5, d3, d5\n" + "vldr d5, [%[lhs_ptr], #40]\n" + "vmlal.s8 q6, d1, d7\n" + "vmlal.s8 q7, d3, d7\n" + "vldr d7, [%[lhs_ptr], #56]\n" + + // Add pairwise, accumulate into 32-bit accumulators. + "vpadal.s16 q8, q4\n" + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "vpadal.s16 q9, q5\n" + "subs %[run_depth], %[run_depth], #16\n" + "vpadal.s16 q10, q6\n" + "vpadal.s16 q11, q7\n" + + "beq " GEMMLOWP_LABEL_AFTER_LOOP + "f\n" + + // Multiply first half. 
+ "vmull.s8 q4, d0, d4\n" + "vmull.s8 q5, d2, d4\n" + "vldr d4, [%[lhs_ptr], #0]\n" + "vmull.s8 q6, d0, d6\n" + "vldr d0, [%[rhs_ptr], #0]\n" + "vmull.s8 q7, d2, d6\n" + "vldr d2, [%[rhs_ptr], #16]\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "vmlal.s8 q4, d1, d5\n" + "vldr d6, [%[lhs_ptr], #16]\n" + "vmlal.s8 q5, d3, d5\n" + "vldr d5, [%[lhs_ptr], #8]\n" + "vmlal.s8 q6, d1, d7\n" + "vldr d1, [%[rhs_ptr], #8]\n" + "vmlal.s8 q7, d3, d7\n" + "vldr d3, [%[rhs_ptr], #24]\n" + + // Add pairwise, accumulate into 32-bit accumulators. + "vpadal.s16 q12, q4\n" + "vldr d7, [%[lhs_ptr], #24]\n" + "vpadal.s16 q13, q5\n" + "vpadal.s16 q14, q6\n" + "vpadal.s16 q15, q7\n" + + "b " GEMMLOWP_LABEL_LOOP "b\n" + + GEMMLOWP_LABEL_AFTER_LOOP + ":\n" + + // Multiply first half. + "vmull.s8 q4, d0, d4\n" + "vmull.s8 q5, d2, d4\n" + "vmull.s8 q6, d0, d6\n" + "vmull.s8 q7, d2, d6\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "vmlal.s8 q4, d1, d5\n" + "vmlal.s8 q5, d3, d5\n" + "vmlal.s8 q6, d1, d7\n" + "vmlal.s8 q7, d3, d7\n" + + // Add pairwise, accumulate into 32-bit accumulators. + "vpadal.s16 q12, q4\n" + "vpadal.s16 q13, q5\n" + "vpadal.s16 q14, q6\n" + "vpadal.s16 q15, q7\n" + "cmp %[start_depth], #0\n" + + // Reduce 32bit accumulators horizontally. + "vpadd.s32 d0, d16, d17\n" + "vpadd.s32 d1, d18, d19\n" + "vpadd.s32 d2, d20, d21\n" + "vpadd.s32 d3, d22, d23\n" + "vpadd.s32 d4, d24, d25\n" + "vpadd.s32 d5, d26, d27\n" + "vpadd.s32 d6, d28, d29\n" + "vpadd.s32 d7, d30, d31\n" + + "bne " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES + "f\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "vpadd.s32 d8, d0, d2\n" + "vpadd.s32 d9, d4, d6\n" + "vpadd.s32 d10, d1, d3\n" + "vpadd.s32 d11, d5, d7\n" + + "b " GEMMLOWP_LABEL_STORE "f\n" + + GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES + ":\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise), + // and load destination values from memory. 
+ "mov r0, %[dst_ptr]\n" + "vld1.32 {d16, d17}, [r0], %[dst_col_stride]\n" + "vpadd.s32 d8, d0, d2\n" + "vpadd.s32 d9, d4, d6\n" + "vld1.32 {d18, d19}, [r0]\n" + "vpadd.s32 d10, d1, d3\n" + "vpadd.s32 d11, d5, d7\n" + + // Add horizontally-reduced accumulators into + // the values loaded from memory + "vadd.s32 q4, q8, q4\n" + "vadd.s32 q5, q9, q5\n" + + GEMMLOWP_LABEL_STORE + ":\n" + // Store back into memory + "mov r0, %[dst_ptr]\n" + "vst1.32 {d8, d9}, [r0], %[dst_col_stride]\n" + "vst1.32 {d10, d11}, [r0]\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth) + : // inputs + [start_depth] "r"(start_depth), + [dst_col_stride] "r"(dst_col_stride) + : // clobbers + "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", + "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", + "d28", "d29", "d30", "d31"); +#undef GEMMLOWP_LABEL_LOOP +#undef GEMMLOWP_LABEL_AFTER_LOOP +#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES +#undef GEMMLOWP_LABEL_STORE + } +}; + +#endif // GEMMLOWP_NEON_32 + +// The kernels here are specifically arm 64bit assembly, not arm 32bit. +#ifdef GEMMLOWP_NEON_64 + +struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase { + typedef KernelFormat< + KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > + Format; + const char* Name() const override { + return "NEON, 4x4, depth 16, accumulating two within signed int16"; + } + + // TODO(benoitjacob): reorder function arguments so dst comes last + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { +#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1" +#define GEMMLOWP_LABEL_LOOP "2" +#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" +#define GEMMLOWP_LABEL_STORE "4" + asm volatile( + // Clear accumulators, and, interleaved with it, + // initial loads of the first loop iteration, + // taken out of the loop so that in the loop itself we have + // optimal streaming of data from memory. + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + "dup v16.4s, wzr\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "dup v17.4s, wzr\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + "dup v18.4s, wzr\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "dup v19.4s, wzr\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "dup v20.4s, wzr\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + "dup v21.4s, wzr\n" + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + "dup v22.4s, wzr\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "dup v23.4s, wzr\n" + "dup v24.4s, wzr\n" + "dup v25.4s, wzr\n" + "dup v26.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + // Multiply dst_col_stride by 4 == sizeof(int32) to use + // it as a byte offset below. + "lsl %[dst_col_stride], %[dst_col_stride], #2\n" + + // Initial arithmetic of the first loop iteration, + // taken out of the loop so that in the loop itself we have + // optimal streaming of data from memory. 
+ "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + "subs %[run_depth], %[run_depth], #16\n" + + // If the loop depth is only 16, then we can skip the general loop + // and go straight to the final part of the code. + "beq " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n" + + // General loop. + GEMMLOWP_LABEL_LOOP + ":\n" + + // Overview of register layout: + // + // A 4x16 block of Rhs is stored in 8 bit in v0--v3. + // A 4x16 block of Lhs is stored in 8 bit in v4--v7. + // + // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit + // components which need to be horizontally-added at the end) + // + // The Lhs vectors are multiplied by the Rhs vectors with a widening + // multiply over the 8 first levels of depth, producing int16x8 + // vectors of products for each position in the accumulator matrix. + // Here comes the special trick: since the operands are signed int8, + // their range being [ -2^7 , 2^7 ), their products are in range + // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values + // without any risk of overflowing int16. + // We thus proceed with the 8 next levels of depth, multiplying + // again Lhs by Rhs, accumulating into this existing int16x8 vector. + // + // Only then, having processed 16 levels of depth, do we need to + // horizontally add these int16x8 accumulators into the final + // int32x4 accumulators. + // + // As we do not have enough registers to store all 16 int16x8 + // temporary-16bit-accumulators, we have them cycle through v8--v15. + // + // + // Register layout (ignoring the v8--v15 temporary 16bit accumulators): + // + // +--------+--------+--------+--------+ + // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] | + // Rhs +--------+--------+--------+--------+ + // | ... | ... | ... | ... | + // +--------+--------+--------+--------| + // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]| + // +--------+--------+--------+--------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +-------+-----+--------+ - - +--------+--------+--------+--------+ + // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s | + // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s | + // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s | + // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s | + // +-------+--------------+ - - +--------+--------+--------+--------+ + // + // Accumulator + // + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. 
+ "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + "sadalp v24.4s, v8.8h\n" + "smull v8.8h, v0.8b, v4.8b\n" + "sadalp v25.4s, v9.8h\n" + "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v4.8b\n" + "sadalp v26.4s, v10.8h\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Loop. Decrement loop index (depth) by 16, since we just handled + // 16 levels of depth. Do this subs a bit before the end of the loop + // for better dispatch on A57. + "subs %[run_depth], %[run_depth], #16\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + "bne " GEMMLOWP_LABEL_LOOP "b\n" + + // Final code for the last 16 levels of depth. + // There is nothing to load anymore, only some arithmetic to finish. + GEMMLOWP_LABEL_AFTER_LOOP_LAST16 + ":\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. 
This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "sadalp v25.4s, v9.8h\n" + "sadalp v26.4s, v10.8h\n" + "sadalp v27.4s, v11.8h\n" + "sadalp v28.4s, v12.8h\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // Reduce 32bit accumulators horizontally. + "addp v0.4s, v16.4s, v20.4s\n" + "addp v2.4s, v17.4s, v21.4s\n" + "addp v4.4s, v18.4s, v22.4s\n" + "addp v6.4s, v19.4s, v23.4s\n" + "addp v1.4s, v24.4s, v28.4s\n" + "addp v3.4s, v25.4s, v29.4s\n" + "addp v5.4s, v26.4s, v30.4s\n" + "addp v7.4s, v27.4s, v31.4s\n" + + "cmp %[start_depth], #0\n" + "bne " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES + "f\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v12.4s, v0.4s, v1.4s\n" + "addp v13.4s, v2.4s, v3.4s\n" + "addp v14.4s, v4.4s, v5.4s\n" + "addp v15.4s, v6.4s, v7.4s\n" + + "b " GEMMLOWP_LABEL_STORE "f\n" + + GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES + ":\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise), + // and load destination values from memory. + "mov x0, %[dst_ptr]\n" + "ld1 {v12.16b}, [x0], %[dst_col_stride]\n" + "addp v8.4s, v0.4s, v1.4s\n" + "ld1 {v13.16b}, [x0], %[dst_col_stride]\n" + "addp v9.4s, v2.4s, v3.4s\n" + "ld1 {v14.16b}, [x0], %[dst_col_stride]\n" + "addp v10.4s, v4.4s, v5.4s\n" + "ld1 {v15.16b}, [x0]\n" + "addp v11.4s, v6.4s, v7.4s\n" + + // Add horizontally-reduced accumulators into + // the values loaded from memory + "add v12.4s, v12.4s, v8.4s\n" + "add v13.4s, v13.4s, v9.4s\n" + "add v14.4s, v14.4s, v10.4s\n" + "add v15.4s, v15.4s, v11.4s\n" + + GEMMLOWP_LABEL_STORE + ":\n" + // Store back into memory + "mov x0, %[dst_ptr]\n" + "st1 {v12.16b}, [x0], %[dst_col_stride]\n" + "st1 {v13.16b}, [x0], %[dst_col_stride]\n" + "st1 {v14.16b}, [x0], %[dst_col_stride]\n" + "st1 {v15.16b}, [x0]\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth), + [dst_col_stride] "+r"(dst_col_stride) + : // inputs + [start_depth] "r"(start_depth) + : // clobbers + "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31"); +#undef GEMMLOWP_LABEL_LOOP +#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16 +#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES +#undef GEMMLOWP_LABEL_STORE + } +}; + + +// Our main GEMM kernel. 
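+// For reference, the "12x8, depth 2" shape follows directly from the Format
+// typedef below: the Lhs side packs 3 cells of CellFormat<4, 2>, i.e.
+// 3 * 4 = 12 rows, the Rhs side packs 2 cells of CellFormat<4, 2>, i.e.
+// 2 * 4 = 8 columns, and the cell depth of 2 is why the loop below consumes
+// run_depth in steps of 2.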
+struct NEON_64_Kernel12x8Depth2 : KernelBase { + typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>, + KernelSideFormat<CellFormat<4, 2>, 2> > + Format; + + const char* Name() const override { return "NEON, 12x8, depth 2"; } + + // TODO(benoitjacob): reorder function arguments so dst comes last + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + ScopedProfilingLabel label("optimized kernel (NEON 12x8)"); +// See comments above for why we need local numerical labels in our asm. +#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1" +#define GEMMLOWP_LABEL_BEFORE_LOOP "2" +#define GEMMLOWP_LABEL_LOOP "3" +#define GEMMLOWP_LABEL_AFTER_LOOP "4" + + assert(dst_row_stride == 1); + asm volatile( + // Load 1 Rhs cell of size 2x8 + "ld1 {v5.8b}, [%[rhs_ptr]], #8\n" + "ld1 {v6.8b}, [%[rhs_ptr]], #8\n" + + // Load 3 Lhs cells of size 4x2 each + "ld1 {v2.8b}, [%[lhs_ptr]], #8\n" + "ld1 {v3.8b}, [%[lhs_ptr]], #8\n" + "ld1 {v4.8b}, [%[lhs_ptr]], #8\n" + + // Multiply dst_col_stride by 4 == sizeof(int32) to use + // it as a byte offset below. + "lsl %[dst_col_stride], %[dst_col_stride], #2\n" + + "cmp %[start_depth], #0\n" + "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS + "f\n" + + // Load accumulators + "mov x1, %[dst_ptr]\n" + "mov x0, x1\n" + "ld1 {v8.16b}, [x0], #16\n" + "subs %[run_depth], %[run_depth], #2\n" + "ld1 {v16.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v24.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v9.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v17.16b}, [x0], #16\n" + "ld1 {v25.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v10.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v18.16b}, [x0], #16\n" + "ld1 {v26.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v11.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v19.16b}, [x0], #16\n" + "ld1 {v27.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v12.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v20.16b}, [x0], #16\n" + "ld1 {v28.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v13.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v21.16b}, [x0], #16\n" + "ld1 {v29.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v14.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v22.16b}, [x0], #16\n" + "ld1 {v30.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v15.16b}, [x0], #16\n" + "ld1 {v23.16b}, [x0], #16\n" + "ld1 {v31.16b}, [x0]\n" + + "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n" + + GEMMLOWP_LABEL_CLEAR_ACCUMULATORS + ":\n" + + // Clear accumulator registers (see layout below) + "dup v8.4s, wzr\n" + "subs %[run_depth], %[run_depth], #2\n" + "dup v9.4s, wzr\n" + "dup v10.4s, wzr\n" + "dup v11.4s, wzr\n" + "dup v12.4s, wzr\n" + "dup v13.4s, wzr\n" + "dup v14.4s, wzr\n" + "dup v15.4s, wzr\n" + "dup v16.4s, wzr\n" + "dup v17.4s, wzr\n" + "dup v18.4s, wzr\n" + "dup v19.4s, wzr\n" + "dup v20.4s, wzr\n" + "dup v21.4s, wzr\n" + "dup v22.4s, wzr\n" + "dup v23.4s, wzr\n" + "dup v24.4s, wzr\n" + "dup v25.4s, wzr\n" + "dup v26.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + GEMMLOWP_LABEL_BEFORE_LOOP + ":\n" + + "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n" + + GEMMLOWP_LABEL_LOOP + ":\n" + + // Overview of register layout: + // + // A 2x8 block of 2 2x4 cells of Rhs is stored in 16bit in v0--v1. + // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in v2--v4. 
+ // A 12x8 block of accumulators is stored in 32bit in v8--v31. + // + // +--------+--------+-----+--------+--------+ + // |v0.h[0] |v0.h[1] | ... |v1.h[2] |v1.h[3] | + // Rhs +--------+--------+-----+--------+--------+ + // |v0.h[4] |v0.h[5] | ... |v1.h[6] |v1.h[7] | + // +--------+--------+-----+--------+--------+ + // + // | | | | | | + // + // Lhs | | | | | | + // + // +-------+-------+ - - +--------+--------+-----+--------+--------+ + // |v2.h[0]|v2.h[4]| |v8.s[0] |v9.s[0] | ... |v14.s[0]|v15.s[0]| + // |v2.h[1]|v2.h[5]| |v8.s[1] |v9.s[1] | ... |v14.s[1]|v15.s[1]| + // |v2.h[2]|v2.h[6]| |v8.s[2] |v9.s[2] | ... |v14.s[2]|v15.s[2]| + // |v2.h[3]|v2.h[7]| |v8.s[3] |v9.s[3] | ... |v14.s[3]|v15.s[3]| + // +-------+-------+ - - +--------+--------+-----+--------+--------+ + // |v3.h[0]|v3.h[4]| |v16.s[0]|v17.s[0]| ... |v22.s[0]|v23.s[0]| + // |v3.h[1]|v3.h[5]| |v16.s[1]|v17.s[1]| ... |v22.s[1]|v23.s[1]| + // |v3.h[2]|v3.h[6]| |v16.s[2]|v17.s[2]| ... |v22.s[2]|v23.s[2]| + // |v3.h[3]|v3.h[7]| |v16.s[3]|v17.s[3]| ... |v22.s[3]|v23.s[3]| + // +-------+-------+ - - +--------+--------+-----+--------+--------+ + // |v4.h[0]|v4.h[4]| |v24.s[0]|v25.s[0]| ... |v30.s[0]|v31.s[0]| + // |v4.h[1]|v4.h[5]| |v24.s[1]|v25.s[1]| ... |v30.s[1]|v31.s[1]| + // |v4.h[2]|v4.h[6]| |v24.s[2]|v25.s[2]| ... |v30.s[2]|v31.s[2]| + // |v4.h[3]|v4.h[7]| |v24.s[3]|v25.s[3]| ... |v30.s[3]|v31.s[3]| + // +-------+-------+ - - +--------+--------+-----+--------+--------+ + // + // Accumulator + + // Expand Lhs/Rhs cells to 16 bit. + "uxtl v0.8h, v5.8b\n" + "ld1 {v5.8b}, [%[rhs_ptr]], #8\n" + "uxtl v1.8h, v6.8b\n" + "ld1 {v6.8b}, [%[rhs_ptr]], #8\n" + "uxtl v2.8h, v2.8b\n" + "uxtl v3.8h, v3.8b\n" + "uxtl v4.8h, v4.8b\n" + + // Multiply-accumulate, top third + "umlal v8.4s, v2.4h, v0.h[0]\n" + "umlal v9.4s, v2.4h, v0.h[1]\n" + "umlal v10.4s, v2.4h, v0.h[2]\n" + "umlal v11.4s, v2.4h, v0.h[3]\n" + "umlal v12.4s, v2.4h, v1.h[0]\n" + "umlal v13.4s, v2.4h, v1.h[1]\n" + "umlal v14.4s, v2.4h, v1.h[2]\n" + "umlal v15.4s, v2.4h, v1.h[3]\n" + "umlal2 v8.4s, v2.8h, v0.h[4]\n" + "umlal2 v9.4s, v2.8h, v0.h[5]\n" + "umlal2 v10.4s, v2.8h, v0.h[6]\n" + "umlal2 v11.4s, v2.8h, v0.h[7]\n" + "umlal2 v12.4s, v2.8h, v1.h[4]\n" + "umlal2 v13.4s, v2.8h, v1.h[5]\n" + "umlal2 v14.4s, v2.8h, v1.h[6]\n" + "umlal2 v15.4s, v2.8h, v1.h[7]\n" + "ld1 {v2.8b}, [%[lhs_ptr]], #8\n" + + // Multiply-accumulate, middle third + "umlal v16.4s, v3.4h, v0.h[0]\n" + "umlal v17.4s, v3.4h, v0.h[1]\n" + "umlal v18.4s, v3.4h, v0.h[2]\n" + "umlal v19.4s, v3.4h, v0.h[3]\n" + "umlal v20.4s, v3.4h, v1.h[0]\n" + "umlal v21.4s, v3.4h, v1.h[1]\n" + "umlal v22.4s, v3.4h, v1.h[2]\n" + "umlal v23.4s, v3.4h, v1.h[3]\n" + "umlal2 v16.4s, v3.8h, v0.h[4]\n" + "umlal2 v17.4s, v3.8h, v0.h[5]\n" + "umlal2 v18.4s, v3.8h, v0.h[6]\n" + "umlal2 v19.4s, v3.8h, v0.h[7]\n" + "umlal2 v20.4s, v3.8h, v1.h[4]\n" + "umlal2 v21.4s, v3.8h, v1.h[5]\n" + "umlal2 v22.4s, v3.8h, v1.h[6]\n" + "umlal2 v23.4s, v3.8h, v1.h[7]\n" + "ld1 {v3.8b}, [%[lhs_ptr]], #8\n" + + "subs %[run_depth], %[run_depth], #2\n" + + // Multiply-accumulate, bottom third + "umlal v24.4s, v4.4h, v0.h[0]\n" + "umlal v25.4s, v4.4h, v0.h[1]\n" + "umlal v26.4s, v4.4h, v0.h[2]\n" + "umlal v27.4s, v4.4h, v0.h[3]\n" + "umlal v28.4s, v4.4h, v1.h[0]\n" + "umlal v29.4s, v4.4h, v1.h[1]\n" + "umlal v30.4s, v4.4h, v1.h[2]\n" + "umlal v31.4s, v4.4h, v1.h[3]\n" + "umlal2 v24.4s, v4.8h, v0.h[4]\n" + "umlal2 v25.4s, v4.8h, v0.h[5]\n" + "umlal2 v26.4s, v4.8h, v0.h[6]\n" + "umlal2 v27.4s, v4.8h, v0.h[7]\n" + "umlal2 v28.4s, v4.8h, v1.h[4]\n" + "umlal2 
v29.4s, v4.8h, v1.h[5]\n" + "umlal2 v30.4s, v4.8h, v1.h[6]\n" + "umlal2 v31.4s, v4.8h, v1.h[7]\n" + "ld1 {v4.8b}, [%[lhs_ptr]], #8\n" + + "bne " GEMMLOWP_LABEL_LOOP "b\n" + + GEMMLOWP_LABEL_AFTER_LOOP + ":\n" + + // Expand Lhs/Rhs cells to 16 bit. + "uxtl v0.8h, v5.8b\n" + "uxtl v1.8h, v6.8b\n" + "uxtl v2.8h, v2.8b\n" + "uxtl v3.8h, v3.8b\n" + "uxtl v4.8h, v4.8b\n" + + // Multiply-accumulate, level of depth 0 + "umlal v8.4s, v2.4h, v0.h[0]\n" + "umlal v9.4s, v2.4h, v0.h[1]\n" + "umlal v10.4s, v2.4h, v0.h[2]\n" + "umlal v11.4s, v2.4h, v0.h[3]\n" + "umlal v12.4s, v2.4h, v1.h[0]\n" + "umlal v13.4s, v2.4h, v1.h[1]\n" + "umlal v14.4s, v2.4h, v1.h[2]\n" + "umlal v15.4s, v2.4h, v1.h[3]\n" + "umlal v16.4s, v3.4h, v0.h[0]\n" + "umlal v17.4s, v3.4h, v0.h[1]\n" + "umlal v18.4s, v3.4h, v0.h[2]\n" + "umlal v19.4s, v3.4h, v0.h[3]\n" + "umlal v20.4s, v3.4h, v1.h[0]\n" + "umlal v21.4s, v3.4h, v1.h[1]\n" + "umlal v22.4s, v3.4h, v1.h[2]\n" + "umlal v23.4s, v3.4h, v1.h[3]\n" + "umlal v24.4s, v4.4h, v0.h[0]\n" + "umlal v25.4s, v4.4h, v0.h[1]\n" + "umlal v26.4s, v4.4h, v0.h[2]\n" + "umlal v27.4s, v4.4h, v0.h[3]\n" + "umlal v28.4s, v4.4h, v1.h[0]\n" + "umlal v29.4s, v4.4h, v1.h[1]\n" + "umlal v30.4s, v4.4h, v1.h[2]\n" + "umlal v31.4s, v4.4h, v1.h[3]\n" + + // Multiply-accumulate, level of depth 1 + "umlal2 v8.4s, v2.8h, v0.h[4]\n" + "umlal2 v9.4s, v2.8h, v0.h[5]\n" + "umlal2 v10.4s, v2.8h, v0.h[6]\n" + "umlal2 v11.4s, v2.8h, v0.h[7]\n" + "umlal2 v12.4s, v2.8h, v1.h[4]\n" + "umlal2 v13.4s, v2.8h, v1.h[5]\n" + "umlal2 v14.4s, v2.8h, v1.h[6]\n" + "umlal2 v15.4s, v2.8h, v1.h[7]\n" + "umlal2 v16.4s, v3.8h, v0.h[4]\n" + "umlal2 v17.4s, v3.8h, v0.h[5]\n" + "umlal2 v18.4s, v3.8h, v0.h[6]\n" + "umlal2 v19.4s, v3.8h, v0.h[7]\n" + "umlal2 v20.4s, v3.8h, v1.h[4]\n" + "umlal2 v21.4s, v3.8h, v1.h[5]\n" + "umlal2 v22.4s, v3.8h, v1.h[6]\n" + "umlal2 v23.4s, v3.8h, v1.h[7]\n" + "umlal2 v24.4s, v4.8h, v0.h[4]\n" + "umlal2 v25.4s, v4.8h, v0.h[5]\n" + "umlal2 v26.4s, v4.8h, v0.h[6]\n" + "umlal2 v27.4s, v4.8h, v0.h[7]\n" + "umlal2 v28.4s, v4.8h, v1.h[4]\n" + "umlal2 v29.4s, v4.8h, v1.h[5]\n" + "umlal2 v30.4s, v4.8h, v1.h[6]\n" + "umlal2 v31.4s, v4.8h, v1.h[7]\n" + + // Store accumulators + "mov x1, %[dst_ptr]\n" + "mov x0, x1\n" + "st1 {v8.16b}, [x0], #16\n" + "subs %[run_depth], %[run_depth], #2\n" + "st1 {v16.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v24.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v9.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v17.16b}, [x0], #16\n" + "st1 {v25.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v10.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v18.16b}, [x0], #16\n" + "st1 {v26.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v11.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v19.16b}, [x0], #16\n" + "st1 {v27.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v12.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v20.16b}, [x0], #16\n" + "st1 {v28.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v13.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v21.16b}, [x0], #16\n" + "st1 {v29.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v14.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v22.16b}, [x0], #16\n" + "st1 {v30.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v15.16b}, [x0], #16\n" + "st1 {v23.16b}, [x0], #16\n" + "st1 {v31.16b}, [x0]\n" +#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS +#undef GEMMLOWP_LABEL_BEFORE_LOOP +#undef GEMMLOWP_LABEL_LOOP +#undef GEMMLOWP_LABEL_AFTER_LOOP + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + 
[dst_ptr] "+r"(dst_ptr), + [run_depth] "+r"(run_depth) + : // inputs + [start_depth] "r"(start_depth), + [dst_col_stride] "r"(dst_col_stride) + : // clobbers + "cc", "memory", "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); + } +}; + +#endif // GEMMLOWP_NEON_64 + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_KERNEL_NEON_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/kernel_reference.h b/runtimes/nn/depend/external/gemmlowp/internal/kernel_reference.h new file mode 100644 index 000000000..3458c6a99 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/kernel_reference.h @@ -0,0 +1,118 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// kernel_reference.h: a reference kernel for CPU architectures where we don't +// have optimized kernels yet. Also useful for testing, as it's templatized +// to have any arbitrary format, allowing tests to cover all sorts of corner +// cases. + +#ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ +#define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ + +#include "kernel.h" + +#include <cstdio> +#include <cstring> + +namespace gemmlowp { + +// This kernel is templatized in an arbitrary Format template parameter, +// allowing it to have any arbitrary format. +template <typename tFormat> +struct ReferenceKernel : KernelBase { + typedef tFormat Format; + + const char* Name() const override { + static char buf[256]; + snprintf(buf, sizeof(buf), + "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)", + Format::Lhs::kCells, Format::Lhs::Cell::kWidth, + Format::Lhs::Cell::kDepth, + CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells, + Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth, + CellOrderName(Format::Rhs::Cell::kOrder)); + return buf; + } + + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + std::int32_t accumulator[Format::kRows * Format::kCols]; + memset(accumulator, 0, sizeof(accumulator)); + + const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth); + + // The outer loop is over the depth dimension. + for (int dc = 0; dc < run_depth_cells; dc++) { + // The next two loops are over cells of the Lhs (stacked vertically), + // and over cells of the Rhs (stacked horizontally). 
+ for (int rc = 0; rc < Format::Lhs::kCells; rc++) { + const std::uint8_t* lhs_cell_ptr = + lhs_ptr + (dc * Format::Lhs::kCells + rc) * + Format::Lhs::Cell::kWidth * Format::kDepth; + for (int cc = 0; cc < Format::Rhs::kCells; cc++) { + const std::uint8_t* rhs_cell_ptr = + rhs_ptr + (dc * Format::Rhs::kCells + cc) * + Format::Rhs::Cell::kWidth * Format::kDepth; + + // Now we are inside one cell of the Lhs and inside one cell + // of the Rhs, so the remaining inner loops are just + // traditional three loops of matrix multiplication. + for (int di = 0; di < Format::kDepth; di++) { + for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) { + for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) { + const std::uint8_t* lhs_coeff_ptr = + lhs_cell_ptr + + OffsetIntoCell<typename Format::Lhs::Cell>(ri, di); + const std::uint8_t* rhs_coeff_ptr = + rhs_cell_ptr + + OffsetIntoCell<typename Format::Rhs::Cell>(ci, di); + std::int32_t* accumulator_coeff_ptr = + accumulator + (ri + rc * Format::Lhs::Cell::kWidth) + + (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows; + *accumulator_coeff_ptr += + std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr); + } + } + } + } + } + } + + if (start_depth == 0) { + // start_depth == 0 means we haven't accumulated anything yet, so we need + // to overwrite the accumulator, as it hasn't been initialized to zero. + for (int r = 0; r < Format::kRows; r++) { + for (int c = 0; c < Format::kCols; c++) { + dst_ptr[r * dst_row_stride + c * dst_col_stride] = + accumulator[r + c * Format::kRows]; + } + } + } else { + // We have already accumulated stuff, so we need to continue accumulating + // instead of just overwriting. + for (int r = 0; r < Format::kRows; r++) { + for (int c = 0; c < Format::kCols; c++) { + dst_ptr[r * dst_row_stride + c * dst_col_stride] += + accumulator[r + c * Format::kRows]; + } + } + } + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/kernel_sse.h b/runtimes/nn/depend/external/gemmlowp/internal/kernel_sse.h new file mode 100644 index 000000000..b879fd7c1 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/kernel_sse.h @@ -0,0 +1,517 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// kernel_SSE.h: a collection of Intel SSE optimized kernels. +// Check in kernel_default.h which one(s) are actually used by default. +// Others are mere experiments; they are still covered by tests +// in case they might be useful some day. 
+// + +#ifndef GEMMLOWP_INTERNAL_KERNEL_SSE_H_ +#define GEMMLOWP_INTERNAL_KERNEL_SSE_H_ + +#include "kernel.h" + +#include <string.h> +#include <cassert> + +namespace gemmlowp { + +#ifdef GEMMLOWP_SSE4_32 +struct SSE4_32_Kernel4x4Depth2 : KernelBase { + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1>, + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1> > + Format; + + const char* Name() const override { return "SSE, 4x4, depth 2"; } + + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + ScopedProfilingLabel label("optimized kernel"); + assert(dst_row_stride == 1); + std::int32_t run_depth_cells = run_depth / Format::kDepth; + /* Main loop */ + + // A 2x4 cell of Rhs is stored in 16bit in xmm1 . + // A 4x2 block Lhs is stored in 16bit in xmm0. + // A 4x4 block of accumulators is stored in 32bit in xmm4--xmm7. + // + // +-------+-------+-------+-------+ + // |xmm1[0]|xmm1[2]|xmm1[4]|xmm1[6]| + // Rhs +-------+---------------+-------+ + // |xmm1[1]|xmm1[3]|xmm1[5]|xmm1[7]| + // +-------+-------+-------+-------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +--+--+ - - - - +-------+-------+-------+-------+ + // |xmm0 | | xmm4 | xmm5 | xmm6 | xmm7 | + // |xmm0 | (Iter1) | xmm4 | xmm5 | xmm6 | xmm7 | + // |xmm0 | | xmm4 | xmm5 | xmm6 | xmm7 | + // |xmm0 | | xmm4 | xmm5 | xmm6 | xmm7 | + // +--+--+ - - - - +-------+-------+-------+-------+ + // + // Accumulator + + asm volatile( + + // set accumulators to zero. + "pxor %%xmm4 , %%xmm4 \n\t" + "pxor %%xmm5 , %%xmm5 \n\t" + "pxor %%xmm6 , %%xmm6 \n\t" + "pxor %%xmm7 , %%xmm7 \n\t" + + "movl %[run_depth_cells], %%eax\n\t" + "subl $2, %%eax\n\t" + "js outerLoop1%=\n\t" + + // Loop for K unrolled by 4 + "outerLoop2%=:\n\t" + + // K = 1,2 + // RHS cell to xmm1 + "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" + + // LHS cell + "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm4 \n\t" + "paddd %%xmm3, %%xmm5 \n\t" + + "prefetcht0 0x80(%[lhs_ptr]) \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + + "prefetcht0 0x80(%[rhs_ptr]) \n\t" + + // K = 3,4 + // RHS cell to xmm1 + "pmovzxbw 0x08(%[rhs_ptr]), %%xmm1\n\t" + + "paddd %%xmm2, %%xmm6 \n\t" + "paddd %%xmm3, %%xmm7 \n\t" + + // LHS cell + "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm4 \n\t" + "paddd %%xmm3, %%xmm5 \n\t" + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + + "addl $0x10, %[lhs_ptr] \n\t" + "addl $0x10, %[rhs_ptr] \n\t" + + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm3, %%xmm7 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "paddd %%xmm2, %%xmm6 \n\t" + + "subl $2, %[run_depth_cells]\n\t" + "ja outerLoop2%=\n\t" + + "movl %[run_depth_cells], %%eax\n\t" + "decl %%eax\n\t" + "js finish%=\n\t" + + // Loop for K unrolled by 2 + "outerLoop1%=:\n\t" + + // RHS cell to xmm1 + "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" + + // LHS cell + "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "paddd %%xmm2, %%xmm4 \n\t" + "pshufd 
$0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm3, %%xmm5 \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "paddd %%xmm2, %%xmm6 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm3, %%xmm7 \n\t" + + "addl $0x08, %[lhs_ptr]\n\t" + "addl $0x08, %[rhs_ptr]\n\t" + + "decl %[run_depth_cells]\n\t" + "jnz outerLoop1%=\n\t" + + "finish%=:\n\t" + + "movl %[dst_col_stride], %%eax\n\t" + "shll $2, %%eax\n\t" + + "movl %[start_depth], %%ecx\n\t" + "test %%ecx, %%ecx\n\t" + "jz storeDst%=\n\t" + + "leal (%%eax,%%eax,0x2), %%ecx\n\t" + "paddd 0x00(%[dst_ptr]) , %%xmm4 \n\t" + "paddd 0x00(%[dst_ptr], %%eax, 1) , %%xmm5 \n\t" + "paddd 0x00(%[dst_ptr], %%eax, 2) , %%xmm6 \n\t" + "paddd 0x00(%[dst_ptr], %%ecx, 1) , %%xmm7 \n\t" + + "storeDst%=:\n\t" + + "leal (%%eax,%%eax,0x2), %%ecx\n\t" + "movdqu %%xmm4 , 0x00(%[dst_ptr]) \n\t" + "movdqu %%xmm5 , 0x00(%[dst_ptr], %%eax, 1)\n\t" + "movdqu %%xmm6 , 0x00(%[dst_ptr], %%eax, 2)\n\t" + "movdqu %%xmm7 , 0x00(%[dst_ptr], %%ecx, 1)\n\t" + + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr) + : // inputs + [start_depth] "g"(start_depth), [dst_col_stride] "g"(dst_col_stride), + [run_depth_cells] "g"(run_depth_cells) + : // clobbers + "cc", "memory", "%xmm0", "%xmm1", "%xmm3", "%xmm2", "%xmm4", "%xmm5", + "%xmm6", "%xmm7", "%eax", "%ecx"); + } +}; +#endif +#ifdef GEMMLOWP_SSE4_64 +struct SSE4_64_Kernel12x4Depth2 : KernelBase { + typedef KernelFormat< + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1> > + Format; + + const char* Name() const override { return "SSE, 12x4, depth 2"; } + + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + ScopedProfilingLabel label("optimized kernel"); + assert(dst_row_stride == 1); + const std::int64_t run_depth_cells = run_depth / Format::kDepth; + const std::int64_t dst_col_stride_q = dst_col_stride; + + /* Main loop */ + + // A 2x4 cell of Rhs is stored in 16bit in xmm1 . + // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in xmm0, replaced + // every Iteration. + // A 12x4 block of accumulators is stored in 32bit in xmm4--xmm15. 
+ // + // +-------+-------+-------+-------+ + // |xmm1[0]|xmm1[2]|xmm1[4]|xmm1[6]| + // Rhs +-------+---------------+-------+ + // |xmm1[1]|xmm1[3]|xmm1[5]|xmm1[7]| + // +-------+-------+-------+-------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +--+--+ - - - - +-------+-------+-------+-------+ + // |xmm0 | | xmm4 | xmm5 | xmm6 | xmm7 | + // |xmm0 | (Iter1) | xmm4 | xmm5 | xmm6 | xmm7 | + // |xmm0 | | xmm4 | xmm5 | xmm6 | xmm7 | + // |xmm0 | | xmm4 | xmm5 | xmm6 | xmm7 | + // +--+--+ - - - - +-------+-------+-------+-------+ + // |xmm0 | | xmm8 | xmm9 | xmm10 | xmm11 | + // |xmm0 | (Iter2) | xmm8 | xmm9 | xmm10 | xmm11 | + // |xmm0 | | xmm8 | xmm9 | xmm10 | xmm11 | + // |xmm0 | | xmm8 | xmm9 | xmm10 | xmm11 | + // +--+--+ - - - - +-------+-------+-------+-------+ + // |xmm0 | | xmm12 | xmm13 | xmm14 | xmm15 | + // |xmm0 | (Iter3) | xmm12 | xmm13 | xmm14 | xmm15 | + // |xmm0 | | xmm12 | xmm13 | xmm14 | xmm15 | + // |xmm0 | | xmm12 | xmm13 | xmm14 | xmm15 | + // +--+--+ - - - - +-------+-------+-------+-------+ + // + // Accumulator + + asm volatile( + + // Set registers for destination + "movq %[dst_col_stride_q], %%r12\n\t" + "shlq $2, %%r12\n\t" + "leaq (%%r12,%%r12,0x2), %%r13\n\t" + + // Set accumulators to zero. + "pxor %%xmm4 , %%xmm4 \n\t" + "pxor %%xmm5 , %%xmm5 \n\t" + "pxor %%xmm6 , %%xmm6 \n\t" + "pxor %%xmm7 , %%xmm7 \n\t" + "pxor %%xmm8 , %%xmm8 \n\t" + "pxor %%xmm9 , %%xmm9 \n\t" + "pxor %%xmm10 , %%xmm10\n\t" + "pxor %%xmm11 , %%xmm11\n\t" + "pxor %%xmm12 , %%xmm12\n\t" + "pxor %%xmm13 , %%xmm13\n\t" + "pxor %%xmm14 , %%xmm14\n\t" + "pxor %%xmm15 , %%xmm15\n\t" + + "movq %[run_depth_cells], %%r14\n\t" + "subq $2, %%r14\n\t" + "js outerLoop1%=\n\t" + + // Loop for K unrolled by 4 + "outerLoop2%=:\n\t" + + // K = 1,2 + // RHS cell to xmm1 + + "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" + + // LHS cell + "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm4 \n\t" + "paddd %%xmm3, %%xmm5 \n\t" + + "prefetcht0 0x80(%[lhs_ptr]) \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + + // next LHS cell + "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t" + + "paddd %%xmm2, %%xmm6 \n\t" + "paddd %%xmm3, %%xmm7 \n\t" + + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm8 \n\t" + "paddd %%xmm3, %%xmm9 \n\t" + + "prefetcht0 0x80(%[rhs_ptr]) \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm10 \n\t" + "paddd %%xmm3, %%xmm11 \n\t" + + // next LHS cell + "pmovzxbw 0x10(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm12 \n\t" + "paddd %%xmm3, %%xmm13 \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm14 \n\t" + "paddd %%xmm3, %%xmm15 \n\t" + + // K = 3,4 + // RHS cell to xmm1 + "pmovzxbw 0x08(%[rhs_ptr]), %%xmm1\n\t" + + // LHS cell + "pmovzxbw 0x18(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 
\n\t" + "paddd %%xmm2, %%xmm4 \n\t" + "paddd %%xmm3, %%xmm5 \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm6 \n\t" + "paddd %%xmm3, %%xmm7 \n\t" + + // next LHS cell + "pmovzxbw 0x20(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm8 \n\t" + "paddd %%xmm3, %%xmm9 \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm10 \n\t" + "paddd %%xmm3, %%xmm11 \n\t" + + // next LHS cell + "pmovzxbw 0x28(%[lhs_ptr]), %%xmm0\n\t" + + "addq $0x30, %[lhs_ptr] \n\t" + "addq $0x10, %[rhs_ptr] \n\t" + + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm12 \n\t" + "paddd %%xmm3, %%xmm13 \n\t" + + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm14 \n\t" + "paddd %%xmm3, %%xmm15 \n\t" + + "subq $2, %[run_depth_cells]\n\t" + "ja outerLoop2%=\n\t" + + "movq %[run_depth_cells], %%r14\n\t" + "decq %%r14\n\t" + "js finish%=\n\t" + + // Loop for K unrolled by 2 + "outerLoop1%=:\n\t" + + // RHS cell to xmm1 + "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" + + // LHS cell + "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm4 \n\t" + "paddd %%xmm3, %%xmm5 \n\t" + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm6 \n\t" + "paddd %%xmm3, %%xmm7 \n\t" + + // next LHS cell + "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t" + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm8 \n\t" + "paddd %%xmm3, %%xmm9 \n\t" + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm10 \n\t" + "paddd %%xmm3, %%xmm11 \n\t" + + // next LHS cell + "pmovzxbw 0x10(%[lhs_ptr]), %%xmm0\n\t" + + "addq $0x18, %[lhs_ptr] \n\t" + "addq $0x08, %[rhs_ptr] \n\t" + + "pshufd $0x00,%%xmm1,%%xmm2 \n\t" + "pshufd $0x55,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm12 \n\t" + "paddd %%xmm3, %%xmm13 \n\t" + "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" + "pshufd $0xff,%%xmm1,%%xmm3 \n\t" + "pmaddwd %%xmm0, %%xmm2 \n\t" + "pmaddwd %%xmm0, %%xmm3 \n\t" + "paddd %%xmm2, %%xmm14 \n\t" + "paddd %%xmm3, %%xmm15 \n\t" + + "decq %[run_depth_cells]\n\t" + "jnz outerLoop1%=\n\t" + + "finish%=:\n\t" + + "test %[start_depth], %[start_depth]\n\t" + "jz storeDst%=\n\t" + + "paddd 0x00(%[dst_ptr]) , %%xmm4 \n\t" + "paddd 0x10(%[dst_ptr]) , %%xmm8 \n\t" + "paddd 0x20(%[dst_ptr]) , %%xmm12\n\t" + "paddd 0x00(%[dst_ptr], %%r12, 1) , %%xmm5 \n\t" + "paddd 0x10(%[dst_ptr], %%r12, 1) , %%xmm9 \n\t" + "paddd 0x20(%[dst_ptr], %%r12, 1) , %%xmm13\n\t" + "paddd 0x00(%[dst_ptr], %%r12, 2) , %%xmm6 \n\t" + "paddd 0x10(%[dst_ptr], %%r12, 2) , %%xmm10\n\t" + "paddd 0x20(%[dst_ptr], %%r12, 2) , %%xmm14\n\t" + "paddd 0x00(%[dst_ptr], %%r13, 
1) , %%xmm7 \n\t" + "paddd 0x10(%[dst_ptr], %%r13, 1) , %%xmm11\n\t" + "paddd 0x20(%[dst_ptr], %%r13, 1) , %%xmm15\n\t" + + "storeDst%=:\n\t" + + "movdqu %%xmm4 , 0x00(%[dst_ptr]) \n\t" + "movdqu %%xmm8 , 0x10(%[dst_ptr]) \n\t" + "movdqu %%xmm12 , 0x20(%[dst_ptr]) \n\t" + "movdqu %%xmm5 , 0x00(%[dst_ptr], %%r12, 1)\n\t" + "movdqu %%xmm9 , 0x10(%[dst_ptr], %%r12, 1)\n\t" + "movdqu %%xmm13 , 0x20(%[dst_ptr], %%r12, 1)\n\t" + "movdqu %%xmm6 , 0x00(%[dst_ptr], %%r12, 2)\n\t" + "movdqu %%xmm10 , 0x10(%[dst_ptr], %%r12, 2)\n\t" + "movdqu %%xmm14 , 0x20(%[dst_ptr], %%r12, 2)\n\t" + "movdqu %%xmm7 , 0x00(%[dst_ptr], %%r13, 1)\n\t" + "movdqu %%xmm11 , 0x10(%[dst_ptr], %%r13, 1)\n\t" + "movdqu %%xmm15 , 0x20(%[dst_ptr], %%r13, 1)\n\t" + + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr) + : // inputs + [start_depth] "r"(start_depth), + [dst_col_stride_q] "r"(dst_col_stride_q), + [run_depth_cells] "r"(run_depth_cells) + : // clobbers + "cc", "memory", "%xmm0", "%xmm1", "%xmm3", "%xmm2", "%xmm4", "%xmm5", + "%xmm6", "%xmm7", "%xmm8", "%xmm9", "%xmm10", "%r12", "%r13", "%r14", + "%xmm11", "%xmm12", "%xmm13", "%xmm14", "%xmm15"); + } +}; +#endif + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_KERNEL_SSE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/multi_thread_gemm.h b/runtimes/nn/depend/external/gemmlowp/internal/multi_thread_gemm.h new file mode 100644 index 000000000..0234b26e9 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/multi_thread_gemm.h @@ -0,0 +1,701 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// multi_thread_gemm.h: Multi-threaded GEMM entry point. +// Readers note: To understand this file, it is useful to first +// read and understand the much simpler single_thread_gemm.h. + +#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_ +#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_ + +#include <pthread.h> +#include <unistd.h> +#include <vector> + +#include "single_thread_gemm.h" + +namespace gemmlowp { + +// On X86 and ARM platforms we enable a busy-wait spinlock before waiting on a +// pthread conditional variable. In order to implement that correctly we need +// to put some explicit memory load/store barriers. 
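+//
+// Roughly, the pattern used below is (a sketch only; see the actual code of
+// BlockingCounter::DecrementCount() and WaitForVariableChange() for details):
+//
+//   // publishing side
+//   count_--;                      // store the new value
+//   WriteBarrier();                // make the store visible before signaling
+//   pthread_cond_signal(&cond_);
+//
+//   // observing side
+//   while (*var == initial_value) { /* spin on NOPs */ }
+//   ReadBarrier();                 // order subsequent reads after the spin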
+ +#if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \ + (defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86)) + +#define GEMMLOWP_USE_BUSYWAIT + +const int kMaxBusyWaitNOPs = 32 * 1000 * 1000; + +#define GEMMLOWP_NOP "nop\n" + +#define GEMMLOWP_STRING_CONCAT_4(X) X X X X +#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP) +#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4) +#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16) + +inline int Do256NOPs() { + asm volatile(GEMMLOWP_NOP64); + return 64; +} + +#undef GEMMLOWP_STRING_CONCAT_4 +#undef GEMMLOWP_NOP256 +#undef GEMMLOWP_NOP64 +#undef GEMMLOWP_NOP16 +#undef GEMMLOWP_NOP4 +#undef GEMMLOWP_NOP + +inline void WriteBarrier() { +#ifdef GEMMLOWP_ARM_32 + MemoryBarrier(); +#elif defined(GEMMLOWP_ARM_64) + asm volatile("dmb ishst" ::: "memory"); +#elif defined(GEMMLOWP_X86) + asm volatile("sfence" ::: "memory"); +#else +#error "Unsupported architecture for WriteBarrier." +#endif +} + +inline void ReadBarrier() { +#ifdef GEMMLOWP_ARM_32 + MemoryBarrier(); +#elif defined(GEMMLOWP_ARM_64) + asm volatile("dmb ishld" ::: "memory"); +#elif defined(GEMMLOWP_X86) + asm volatile("lfence" ::: "memory"); +#else +#error "Unsupported architecture for ReadBarrier." +#endif +} + +#endif + +// Waits until *var != initial_value. +// +// Returns the new value of *var. The guarantee here is that +// the return value is different from initial_value, and that that +// new value has been taken by *var at some point during the +// execution of this function. There is no guarantee that this is +// still the value of *var when this function returns, since *var is +// not assumed to be guarded by any lock. +// +// First does some busy-waiting for a fixed number of no-op cycles, +// then falls back to passive waiting for the given condvar, guarded +// by the given mutex. +// +// The idea of doing some initial busy-waiting is to help get +// better and more consistent multithreading benefits for small GEMM sizes. +// Busy-waiting help ensuring that if we need to wake up soon after having +// started waiting, then we can wake up quickly (as opposed to, say, +// having to wait to be scheduled again by the OS). On the other hand, +// we must still eventually revert to passive waiting for longer waits +// (e.g. worker threads having finished a GEMM and waiting until the next GEMM) +// so as to avoid permanently spinning. +// +template <typename T> +T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond, + pthread_mutex_t* mutex) { +#ifdef GEMMLOWP_USE_BUSYWAIT + // If we are on a platform that supports it, spin for some time. + { + int nops = 0; + // First, trivial case where the variable already changed value. + T new_value = *var; + if (new_value != initial_value) { + ReadBarrier(); + return new_value; + } + // Then try busy-waiting. + while (nops < kMaxBusyWaitNOPs) { + nops += Do256NOPs(); + new_value = *var; + if (new_value != initial_value) { + ReadBarrier(); + return new_value; + } + } + } +#endif + + // Finally, do real passive waiting. + pthread_mutex_lock(mutex); + T new_value = *var; + if (new_value == initial_value) { + pthread_cond_wait(cond, mutex); + new_value = *var; + assert(new_value != initial_value); + } + pthread_mutex_unlock(mutex); + return new_value; +} + +// A BlockingCounter lets one thread to wait for N events to occur. +// This is how the master thread waits for all the worker threads +// to have finished working. 
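+//
+// Typical usage, as in WorkersPool further below:
+//
+//   counter.Reset(workers_count);   // master: expect workers_count events
+//   ...                             // each worker eventually calls
+//                                   // counter.DecrementCount() exactly once
+//   counter.Wait();                 // master: blocks until the count is 0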
+class BlockingCounter { + public: + BlockingCounter() + : cond_(PTHREAD_COND_INITIALIZER), + mutex_(PTHREAD_MUTEX_INITIALIZER), + count_(0), + initial_count_(0) {} + + // Sets/resets the counter; initial_count is the number of + // decrementing events that the Wait() call will be waiting for. + void Reset(std::size_t initial_count) { + pthread_mutex_lock(&mutex_); + assert(count_ == 0); + initial_count_ = initial_count; + count_ = initial_count_; + pthread_mutex_unlock(&mutex_); + } + + // Decrements the counter; if the counter hits zero, signals + // the thread that was waiting for that, and returns true. + // Otherwise (if the decremented count is still nonzero), + // returns false. + bool DecrementCount() { + pthread_mutex_lock(&mutex_); + assert(count_ > 0); + count_--; +#ifdef GEMMLOWP_USE_BUSYWAIT + WriteBarrier(); +#endif + if (count_ == 0) { + pthread_cond_signal(&cond_); + } + bool retval = count_ == 0; + pthread_mutex_unlock(&mutex_); + return retval; + } + + // Waits for the N other threads (N having been set by Reset()) + // to hit the BlockingCounter. + void Wait() { + ScopedProfilingLabel label("BlockingCounter::Wait"); + while (count_) { + MemoryBarrier(); + const std::size_t count_value = count_; + if (count_value) { + WaitForVariableChange(&count_, count_value, &cond_, &mutex_); + } + } + } + + private: + pthread_cond_t cond_; + pthread_mutex_t mutex_; + std::size_t count_; + std::size_t initial_count_; +}; + +// A workload for a worker. +struct Task { + Task() : local_allocator(nullptr) {} + virtual ~Task() {} + virtual void Run() = 0; + Allocator* local_allocator; +}; + +// A worker thread. +class Worker { + public: + enum class State { + ThreadStartup, // The initial state before the thread main loop runs. + Ready, // Is not working, has not yet received new work to do. + HasWork, // Has work to do. + ExitAsSoonAsPossible // Should exit at earliest convenience. + }; + + explicit Worker(BlockingCounter* counter_to_decrement_when_ready) + : task_(nullptr), + state_cond_(PTHREAD_COND_INITIALIZER), + state_mutex_(PTHREAD_MUTEX_INITIALIZER), + state_(State::ThreadStartup), + counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { + pthread_create(&thread_, nullptr, ThreadFunc, this); + } + + ~Worker() { + ChangeState(State::ExitAsSoonAsPossible); + pthread_join(thread_, nullptr); + } + + // Changes State; may be called from either the worker thread + // or the master thread; however, not all state transitions are legal, + // which is guarded by assertions. + void ChangeState(State new_state) { + ScopedProfilingLabel label("Worker::ChangeState"); + pthread_mutex_lock(&state_mutex_); + assert(new_state != state_); + switch (state_) { + case State::ThreadStartup: + assert(new_state == State::Ready); + break; + case State::Ready: + assert(new_state == State::HasWork || + new_state == State::ExitAsSoonAsPossible); + break; + case State::HasWork: + assert(new_state == State::Ready || + new_state == State::ExitAsSoonAsPossible); + break; + default: + abort(); + } + state_ = new_state; + pthread_cond_signal(&state_cond_); + if (state_ == State::Ready) { + counter_to_decrement_when_ready_->DecrementCount(); + } + pthread_mutex_unlock(&state_mutex_); + } + + // Thread entry point. 
+  void ThreadFunc() {
+    ScopedProfilingLabel label("Worker::ThreadFunc");
+    RegisterCurrentThreadForProfiling();
+
+    ChangeState(State::Ready);
+
+    // Thread main loop
+    while (true) {
+      // Get a state to act on
+      // In the 'Ready' state, we have nothing to do but to wait until
+      // we switch to another state.
+      State state_to_act_upon = WaitForVariableChange(
+          &state_, State::Ready, &state_cond_, &state_mutex_);
+
+      // We now have a state to act on, so act.
+      switch (state_to_act_upon) {
+        case State::HasWork:
+          // Got work to do! So do it, and then revert to 'Ready' state.
+          assert(task_);
+          task_->Run();
+          task_ = nullptr;
+          ChangeState(State::Ready);
+          break;
+        case State::ExitAsSoonAsPossible:
+          return;
+        default:
+          abort();
+      }
+    }
+  }
+
+  static void* ThreadFunc(void* arg) {
+    static_cast<Worker*>(arg)->ThreadFunc();
+    return nullptr;
+  }
+
+  // Called by the master thread to give this worker work to do.
+  // It is only legal to call this if the worker is currently in the
+  // Ready state.
+  void StartWork(Task* task) {
+    assert(!task_);
+    task->local_allocator = &local_allocator_;
+    task_ = task;
+#ifdef GEMMLOWP_USE_BUSYWAIT
+    WriteBarrier();
+#endif
+    assert(state_ == State::Ready);
+    ChangeState(State::HasWork);
+  }
+
+ private:
+  // The underlying thread.
+  pthread_t thread_;
+
+  // The task to be worked on.
+  Task* task_;
+
+  // The condition variable and mutex guarding state changes.
+  pthread_cond_t state_cond_;
+  pthread_mutex_t state_mutex_;
+
+  // The state enum tells if we're currently working, waiting for work, etc.
+  State state_;
+
+  // Each thread has a local allocator so it can allocate temporary
+  // buffers without blocking the other threads.
+  Allocator local_allocator_;
+
+  // Pointer to the master thread's BlockingCounter object, to notify the
+  // master thread of when this worker switches to the 'Ready' state.
+  BlockingCounter* const counter_to_decrement_when_ready_;
+};
+
+// A very simple pool of workers, that only allows the very
+// specific parallelization pattern that we use here:
+// a fixed number of workers can be given work, and one then
+// waits for all of them to finish.
+//
+// See MultiThreadGemmContextBase for how other WorkersPool implementations can
+// be used. Note that in those implementations, StartWorker can be free to
+// ignore the <index> value; that is, the caller of WorkersPool does not rely on
+// <index> to order tasks with equal <index>.
+class WorkersPool {
+ public:
+  WorkersPool() {}
+
+  ~WorkersPool() {
+    for (auto w : workers_) {
+      delete w;
+    }
+  }
+
+  void Execute(const std::vector<Task*>& tasks) {
+    assert(tasks.size() >= 1);
+    // One of the tasks will be run on the current thread.
+    int workers_count = tasks.size() - 1;
+    CreateWorkers(workers_count);
+    assert(workers_count <= workers_.size());
+    counter_to_decrement_when_ready_.Reset(workers_count);
+    int n = 0;
+    std::for_each(tasks.begin(), --tasks.end(), [this, &n](Task *task) {
+      workers_[n++]->StartWork(task);
+    });
+    // Execute the remaining workload immediately on the current thread.
+    Task* task = tasks.back();
+    task->local_allocator = &main_thread_task_allocator_;
+    task->Run();
+    // Wait for the workers submitted above to finish.
+    counter_to_decrement_when_ready_.Wait();
+    // Cleanup tasks (best to do this from the same thread that allocated
+    // the memory).
+    std::for_each(tasks.begin(), tasks.end(), [](Task *task) {
+      delete task;
+    });
+  }
+
+ private:
+  // Ensures that the pool has at least the given count of workers.
+ // If any new worker has to be created, this function waits for it to + // be ready. + void CreateWorkers(std::size_t workers_count) { + if (workers_.size() >= workers_count) { + return; + } + counter_to_decrement_when_ready_.Reset(workers_count - workers_.size()); + while (workers_.size() < workers_count) { + workers_.push_back(new Worker(&counter_to_decrement_when_ready_)); + } + counter_to_decrement_when_ready_.Wait(); + } + + // copy construction disallowed + WorkersPool(const WorkersPool&) = delete; + + // The workers in this pool. They are owned by the pool: + // the pool creates workers and destroys them in its destructor. + std::vector<Worker*> workers_; + + // The BlockingCounter used to wait for the workers. + BlockingCounter counter_to_decrement_when_ready_; + + // For N-threaded operations, we will use only N-1 worker threads + // while the last task will be run directly on the main thread. + // It will then use this main_thread_task_allocator_; having a + // dedicated allocator for that (separate from the base allocator_) + // allows to use the same code for all tasks regardless of which + // thread they run on. + Allocator main_thread_task_allocator_; +}; + +// The task we use to implement a multi-threaded Gemm: a block of the +// RHS has been packed by the master thread; each worker thread +// then has to pack a block of the LHS and accumulate the Gemm of these +// packed LHS and RHS blocks. +template <typename KernelFormat, typename InputScalar, typename OutputScalar, + typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder, + MapOrder ResultOrder, typename LhsOffset, typename RhsOffset, + typename OutputPipelineType, typename GemmContextType> +struct GemmWithPackedRhsTask : Task { + typedef PackedSideBlock<typename KernelFormat::Lhs> PackedLhs; + typedef PackedSideBlock<typename KernelFormat::Rhs> PackedRhs; + GemmWithPackedRhsTask(GemmContextType* _context, + const KernelBase& _kernel, + const MatrixMap<const InputScalar, LhsOrder>& _lhs, + const PackedRhs& _packed_rhs, + MatrixMap<OutputScalar, ResultOrder>* _result, + const MatrixBlockBounds& _result_block, + const LhsOffset& _lhs_offset, + const RhsOffset& _rhs_offset, + const OutputPipelineType& _output_pipeline) + : context(_context), + kernel(_kernel), + lhs(_lhs), + packed_rhs(_packed_rhs), + result(*_result), + result_block(_result_block), + lhs_offset(_lhs_offset), + rhs_offset(_rhs_offset), + output_pipeline(_output_pipeline) {} + + void Run() override { + ScopedProfilingLabel label("GemmWithPackedRhsTask"); + + const int rows = result_block.rows; + const int cols = result_block.cols; + const int depth = lhs.cols(); + + BlockParams block_params; + block_params.Init<KernelFormat>(rows, cols, depth, 1, + context->l1_bytes_to_use(), + context->l2_bytes_to_use(), + context->l2_rhs_factor()); + + PackedLhs packed_lhs(Side::Lhs, local_allocator, block_params); + + PackedResult packed_result(local_allocator, block_params); + + local_allocator->Commit(); + + for (int c = 0; c < cols; c += block_params.l2_cols) { + int cs = std::min(block_params.l2_cols, cols - c); + + for (int r = 0; r < rows; r += block_params.l2_rows) { + int rs = std::min(block_params.l2_rows, rows - r); + + PackLhs(&packed_lhs, lhs.block(r, 0, rs, depth)); + + Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs, + depth); + + auto curr_result_block = MatrixBlockBounds( + result_block.start_row + r, result_block.start_col + c, rs, cs); + UnpackResult<KernelFormat>( + &result, curr_result_block, packed_result, depth, + 
packed_lhs.sums_of_each_slice(), packed_rhs.sums_of_each_slice(), + lhs_offset.block(curr_result_block.start_row, rs), + rhs_offset.block(curr_result_block.start_col, cs), output_pipeline); + } + } + + local_allocator->Decommit(); + } + + const GemmContextType* context; + const KernelBase& kernel; + const MatrixMap<const InputScalar, LhsOrder> lhs; + const PackedRhs packed_rhs; + MatrixMap<OutputScalar, ResultOrder> result; + const MatrixBlockBounds result_block; + const LhsOffset& lhs_offset; + const RhsOffset& rhs_offset; + const OutputPipelineType& output_pipeline; +}; + +// This base class for multi-threading allows subclasses to implement their own +// workers_pool() method. See MultiThreadGemmContext below for an example; +// any other implementation of workers_pool() must return an object with the +// same public methods as WorkersPool. +class MultiThreadGemmContextBase : public SingleThreadGemmContext { + public: + void set_max_num_threads(int n) { max_num_threads_ = n; } + + int max_num_threads() const { return max_num_threads_; } + + protected: + // The maximum number of worker threads to use (including + // the master thread). + // The default value 1 means single-threading. That is the default + // because gemmlowp's primary target is mobile hardware, where thermal + // constraints usually mean that it may not be realistic to use more + // than 1 CPU core even if multiple cores are present. + // The special value 0 means try to detect the number of hardware threads. + // Note: this assumes that all CPU cores are equivalent. That assumption + // is defeated on big.LITTLE ARM devices, where we have no API to query + // the number of big cores (which is typically what we would want to use, + // leaving aside above-mentioned thermal issues). That is the other reason + // why the best compromise here is to let max_num_threads_ default to 1, + // so users who want multi-threading have to make the decision of how many + // threads to use by themselves. + int max_num_threads_ = 1; +}; + +class MultiThreadGemmContext : public MultiThreadGemmContextBase { + public: + WorkersPool* workers_pool() { return &workers_pool_; } + + private: + // The workers pool used by MultiThreadGemm. Making + // this part of the context allows it to be persistent, + // avoiding recreating threads on every Gemm. + WorkersPool workers_pool_; +}; + +// Needed by chrome native builds +#ifndef _SC_NPROCESSORS_CONF +#define _SC_NPROCESSORS_CONF _SC_NPROCESSORS_ONLN +#endif + +// Determines how many threads should be used for a given Gemm +// operation. +template <int KernelRows> +inline int HowManyThreads(int max_num_threads, int rows, int cols, int depth) { + // Early-exit in the default case where multi-threading is disabled. + if (max_num_threads == 1) { + return 1; + } + + // Determine the maximum number of threads. + int max_count = max_num_threads; + // The special value 0 means try to determine the total number of cores. + if (max_count == 0) { + // No user-set maximum number of threads, so we need to + // do some hardware detection. + // This is expensive to query so we do it only once. + // Too bad for dynamicness. Also, we dont use the c++11 standard getter + // because Google's coding style currently bans #include <thread_>. + static const int hardware_threads_count = + static_cast<int>(sysconf(_SC_NPROCESSORS_CONF)); + + max_count = hardware_threads_count; + } + + // Basic calculation: take into account max pool size, and + // how many rows we have to feed our kernel. 
+ // The motivation for an absolute minimum number of rows per thread, + // potentially higher than KernelRows, is that very thin thread workload + // currently defeat assumptions of the AddMod generator, resulting + // in substantial bias in TestWithRealData on 24 threads. + // Ideally, the AddMod generator should be aware of global (r,c) coordinates + // so as to be independent of the number of threads. + static const int AbsoluteMinRowsPerThread = 16; + static const int MinRowsPerThread = KernelRows > AbsoluteMinRowsPerThread + ? KernelRows + : AbsoluteMinRowsPerThread; + int thread_count = std::min(max_count, CeilQuotient(rows, MinRowsPerThread)); + + // At this point for small products we already have thread_count==1 so + // we can avoid doing more work; otherwise, we still want to check + // that the cubic size (rows*cols*depth) is big enough to keep + // workers_ busy. + if (thread_count > 1) { + // Empirically determined value. + static const std::uint64_t min_cubic_size_per_thread = 64 * 1024; + + // We can only multiply two out of three sizes without risking overflow + const std::uint64_t cubic_size = + std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth); + + thread_count = + std::min(thread_count, int(cubic_size / min_cubic_size_per_thread)); + + if (thread_count < 1) { + thread_count = 1; + } + } + + assert(thread_count > 0 && thread_count <= max_count); + return thread_count; +} + +// The main multi-threaded Gemm function. +// To understand it, first read the code of SingleThreadGemm(). +// The parallelization scheme used here is to have this master function +// pack a block of RHS and then start worker threads to pack a block of LHS +// each, and accumulate the corresponding products. +template <typename KernelFormat, typename InputScalar, typename OutputScalar, + typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder, + MapOrder ResultOrder, typename LhsOffset, typename RhsOffset, + typename OutputPipelineType, typename GemmContextType> +void MultiThreadGemm(GemmContextType* context, const KernelBase& kernel, + const MatrixMap<const InputScalar, LhsOrder>& lhs, + const MatrixMap<const InputScalar, RhsOrder>& rhs, + MatrixMap<OutputScalar, ResultOrder>* result, + const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, + const OutputPipelineType& output_pipeline) { + ScopedProfilingLabel label("gemmlowp::MultiThreadGemm"); + + assert(lhs.cols() == rhs.rows()); + + int rows = result->rows(); + int cols = result->cols(); + int depth = lhs.cols(); + + // zero sizes should have been caught earlier and early-returned. + assert(rows > 0); + assert(cols > 0); + assert(depth > 0); + + // The case of rows<cols should have been caught earlier and transposed. + assert(rows >= cols); + + const int thread_count = HowManyThreads<KernelFormat::kRows>( + context->max_num_threads(), rows, cols, depth); + if (thread_count == 1) { + return SingleThreadGemm<KernelFormat, InputScalar, OutputScalar, + BitDepthParams>(context, kernel, lhs, rhs, result, + lhs_offset, rhs_offset, + output_pipeline); + } + assert(thread_count > 1); + + // Simple 1:1 mapping of tasks to physical cores, which is very important + // to getting good multithreaded performance, specially for not-very-large + // GEMMs, and especially on Android. 
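  // [Editorial illustration, not part of the original gemmlowp source: a
  //  worked example of the per-task row partition performed by the loop
  //  further below. With rows = 100, task_count = 3 and
  //  KernelFormat::kRows = 4, the raw boundaries rows * (n + 1) / task_count
  //  are 33, 66 and 100; RoundUp<4> turns them into 36, 68 and 100, so the
  //  three tasks receive 36, 32 and 32 rows respectively, and every task
  //  starts at a row index that is a multiple of the kernel's row count.]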
+ const int task_count = thread_count; + + Allocator* allocator = context->allocator(); + auto* workers_pool = context->workers_pool(); + + BlockParams block_params; + block_params.Init<KernelFormat>(rows, cols, depth, task_count, + context->l1_bytes_to_use(), + context->l2_bytes_to_use(), + context->l2_rhs_factor()); + + PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(Side::Rhs, allocator, + block_params); + allocator->Commit(); + + // We loop over large blocks of the RHS. + for (int c = 0; c < cols; c += block_params.l2_cols) { + int cs = std::min(block_params.l2_cols, cols - c); + + // Pack a large block of the RHS. + PackRhs(&packed_rhs, rhs.block(0, c, depth, cs)); + + // Give work to each worker. + std::vector<Task*> tasks; + int next_start_row = 0; + for (int n = 0; n < task_count; ++n) { + int start_row = next_start_row; + next_start_row = std::min(rows, RoundUp<KernelFormat::kRows>( + rows * (n + 1) / task_count)); + + int block_rows = next_start_row - start_row; + auto lhs_block = lhs.block(start_row, 0, block_rows, depth); + typedef GemmWithPackedRhsTask< + KernelFormat, InputScalar, OutputScalar, BitDepthParams, LhsOrder, + RhsOrder, ResultOrder, LhsOffset, RhsOffset, OutputPipelineType, + GemmContextType> + TaskType; + tasks.push_back(new TaskType(context, kernel, lhs_block, packed_rhs, result, + MatrixBlockBounds(start_row, c, block_rows, cs), + lhs_offset, rhs_offset, output_pipeline)); + } + // Execute the work on the workers (and partially on this thread). + workers_pool->Execute(tasks); + } + + allocator->Decommit(); +} + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/output.h b/runtimes/nn/depend/external/gemmlowp/internal/output.h new file mode 100644 index 000000000..8ccb8ee1f --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/output.h @@ -0,0 +1,435 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// output.h: processing the 32-bit accumulators output by the unpack +// stage, obtaining the final result matrix entries and storing them into +// the destination matrix. + +#ifndef GEMMLOWP_INTERNAL_OUTPUT_H_ +#define GEMMLOWP_INTERNAL_OUTPUT_H_ + +#include <cmath> +#include <tuple> +#include <type_traits> + +#include "../fixedpoint/fixedpoint.h" +#include "../public/output_stages.h" +#include "simd_wrappers.h" + +namespace gemmlowp { + +template <typename OutputStage, typename InputBufferType> +struct OutputStageEvalBufferImpl { + // This generic template body should never be hit. + static_assert( + std::is_same<InputBufferType, void>::value, + "Unimplemented: missing implementation of this output pipeline stage " + "for this data type. 
This would happen if some architecture-specific " + "SIMD back-end (output_$arch.h) were incomplete."); +}; + +template <typename OutputStage, typename InputType> +struct OutputStageEvalImpl { + static constexpr int kRows = InputType::kRows; + static constexpr int kCols = InputType::kCols; + using InputBufferType = typename InputType::BufferType; + using BufferEvalImplType = + OutputStageEvalBufferImpl<OutputStage, InputBufferType>; + using OutputBufferType = typename BufferEvalImplType::OutputType; + using OutputScalarType = typename OutputBufferType::ScalarType; + using OutputType = RegisterBlock<OutputScalarType, kRows, kCols>; + + OutputStageEvalImpl(const OutputStage& s) : buffer_eval_impl(s) {} + + OutputType Eval(InputType input, int, int) const { + OutputType output; + output.buf = buffer_eval_impl.Eval(input.buf); + return output; + } + + const BufferEvalImplType buffer_eval_impl; +}; + +template <int Size> +struct OutputStageEvalBufferImpl<OutputStageQuantizeDownInt32ToUint8Scale, + RegisterBuffer<std::int32_t, Size>> { + using InputType = RegisterBuffer<std::int32_t, Size>; + using OutputType = RegisterBuffer<std::int32_t, Size>; + + typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage; + + OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} + + OutputType Eval(InputType input) const { + const int result_shift = output_stage.result_shift; + const std::int32_t result_mult_int = output_stage.result_mult_int; + using RegisterType = typename InputType::RegisterType; + const RegisterType result_offset = + Dup<RegisterType>(output_stage.result_offset); + OutputType output; + for (int i = 0; i < InputType::kRegisterCount; i++) { + output.reg[i] = RoundingDivideByPOT( + Mul(Add(input.reg[i], result_offset), result_mult_int), result_shift); + } + return output; + } + + const OutputStage& output_stage; +}; + +template <int Rows, int Cols, VectorShape Shape> +struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>, + RegisterBlock<std::int32_t, Rows, Cols>> { + typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; + typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; + typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> OutputStage; + + OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} + + OutputType Eval(InputType input, int row, int col) const { + OutputType output; + const int result_shift = output_stage.result_shift; + const int pos = Shape == VectorShape::Col ? 
row : col; + const auto result_mult_int = + LoadForBroadcasting<InputType>(output_stage.result_mult_int, pos); + const auto result_offset = + LoadForBroadcasting<InputType>(output_stage.result_offset, pos); + const auto dividend = BroadcastMul<InputType>( + BroadcastAdd<InputType>(input, result_offset), result_mult_int); + for (int i = 0; i < InputType::kRegisterCount; i++) { + output.buf.reg[i] = + RoundingDivideByPOT(dividend.buf.reg[i], result_shift); + } + return output; + } + + const OutputStage& output_stage; +}; + +template <int Size> +struct OutputStageEvalBufferImpl< + OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + RegisterBuffer<std::int32_t, Size>> { + typedef RegisterBuffer<std::int32_t, Size> InputType; + typedef RegisterBuffer<std::int32_t, Size> OutputType; + + typedef OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint OutputStage; + + OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} + + OutputType Eval(InputType input) const { + OutputType output; + using RegisterType = typename InputType::RegisterType; + const RegisterType result_offset_after_shift = + Dup<RegisterType>(output_stage.result_offset_after_shift); + for (int i = 0; i < InputType::kRegisterCount; i++) { + const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul( + input.reg[i], output_stage.result_fixedpoint_multiplier); + output.reg[i] = + Add(RoundingDivideByPOT(mulhigh_val, output_stage.result_shift), + result_offset_after_shift); + } + return output; + } + + const OutputStage& output_stage; +}; + +// Implementation of OutputStageSaturatingCastToUint8 for scalar data +template <int Size> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegisterBuffer<std::int32_t, Size>> { + typedef RegisterBuffer<std::int32_t, Size> InputType; + typedef RegisterBuffer<std::uint8_t, Size> OutputType; + static_assert(InputType::kRegisterLanes == 1, + "This path is only for scalar values"); + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + for (int i = 0; i < InputType::kRegisterCount; i++) { + std::int32_t data = input.reg[i]; + output.reg[i] = data > 255 ? 255 : data < 0 ? 0 : data; + } + return output; + } +}; + +template <int Rows, int Cols, typename VectorType> +struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>, + RegisterBlock<std::int32_t, Rows, Cols>> { + typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; + typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; + typedef OutputStageBiasAddition<VectorType> OutputStage; + + OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} + + OutputType Eval(InputType input, int row, int col) const { + const int pos = VectorType::kShape == VectorShape::Row ? 
col : row; + return BroadcastAdd<InputType>( + input, LoadForBroadcasting<InputType>(output_stage.bias_vector, pos)); + } + + const OutputStage& output_stage; +}; + +template <int Size> +struct OutputStageEvalBufferImpl<OutputStageClamp, + RegisterBuffer<std::int32_t, Size>> { + typedef RegisterBuffer<std::int32_t, Size> InputType; + typedef RegisterBuffer<std::int32_t, Size> OutputType; + + typedef OutputStageClamp OutputStage; + + OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} + + OutputType Eval(InputType input) const { + using RegisterType = typename InputType::RegisterType; + const RegisterType min = Dup<RegisterType>(output_stage.min); + const RegisterType max = Dup<RegisterType>(output_stage.max); + OutputType output; + for (int i = 0; i < InputType::kRegisterCount; i++) { + output.reg[i] = Min(Max(input.reg[i], min), max); + } + return output; + } + + const OutputStage& output_stage; +}; + +template <int Size> +struct OutputStageEvalBufferImpl<OutputStageTanh, + RegisterBuffer<std::int32_t, Size>> { + typedef RegisterBuffer<std::int32_t, Size> InputType; + typedef RegisterBuffer<std::int32_t, Size> OutputType; + using RegisterType = typename InputType::RegisterType; + typedef RegisterType DataType; + typedef OutputStageTanh OutputStage; + + OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) { + const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; + const std::int32_t real_amplitude_as_int32 = + output_stage.real_amplitude_as_int32; + + input_cutoff_min = real_zero_as_int32 - 8 * real_amplitude_as_int32; + input_cutoff_max = real_zero_as_int32 + 8 * real_amplitude_as_int32; + output_min = real_zero_as_int32 - real_amplitude_as_int32; + output_max = real_zero_as_int32 + real_amplitude_as_int32; + + double inverse_amplitude_normalized_double = 1.0 / real_amplitude_as_int32; + inverse_amplitude_neg_exponent = 0; + while (inverse_amplitude_normalized_double < 0.5) { + inverse_amplitude_normalized_double *= 2; + inverse_amplitude_neg_exponent++; + } + inverse_amplitude_normalized = FixedPoint<DataType, 0>::FromDouble( + inverse_amplitude_normalized_double); + + double amplitude_normalized_double = real_amplitude_as_int32; + amplitude_exponent = 0; + while (amplitude_normalized_double >= 1.0) { + amplitude_normalized_double *= 0.5; + amplitude_exponent++; + } + amplitude_normalized = + FixedPoint<DataType, 0>::FromDouble(amplitude_normalized_double); + } + + OutputType Eval(InputType input) const { + const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; + + typedef FixedPoint<DataType, 3> F3; + typedef FixedPoint<DataType, 0> F0; + + OutputType output; + + for (int i = 0; i < OutputType::kRegisterCount; i++) { + // fixed-point affine transformation + DataType input_centered = + Sub(input.reg[i], Dup<DataType>(real_zero_as_int32)); + F3 fixedpoint_input = + F3::FromRaw(input_centered) * inverse_amplitude_normalized; + // left shift + fixedpoint_input.raw() = ShiftLeft(fixedpoint_input.raw(), + 28 - inverse_amplitude_neg_exponent); + // fixed-point tanh and multiplication + F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized; + // right shift + DataType int32_output = + Add(Dup<DataType>(real_zero_as_int32), + ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent)); + + DataType mask_if_below_cutoff_min = + MaskIfLessThanOrEqual(input.reg[i], Dup<DataType>(input_cutoff_min)); + DataType mask_if_above_cutoff_max = MaskIfGreaterThanOrEqual( + input.reg[i], Dup<DataType>(input_cutoff_max)); + + 
output.reg[i] = SelectUsingMask( + mask_if_below_cutoff_min, Dup<DataType>(output_min), + SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max), + int32_output)); + } + return output; + } + + const OutputStage& output_stage; + std::int32_t input_cutoff_min, input_cutoff_max; + std::int32_t output_min, output_max; + FixedPoint<DataType, 0> inverse_amplitude_normalized; + int inverse_amplitude_neg_exponent; + FixedPoint<DataType, 0> amplitude_normalized; + int amplitude_exponent; +}; + +// OutputPipelineOutputType is a helper to determine the output data type of a +// pipeline, for a +// given input data type. It is a recursive template; see the explanation on +// OutputPipelineEvalImpl below. +template <typename OutputPipelineType, int FirstStage, typename InputType, + bool StopRecursion = + FirstStage == std::tuple_size<OutputPipelineType>::value> +struct OutputPipelineOutputType { + typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type + FirstStageType; + typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType + FirstStageOutputType; + typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage + 1, + FirstStageOutputType>::Type Type; +}; + +template <typename OutputPipelineType, int FirstStage, typename InputType> +struct OutputPipelineOutputType<OutputPipelineType, FirstStage, InputType, + true> { + typedef InputType Type; +}; + +// OutputPipelineEvalImpl is a helper to implement the evaluation of +// the whole pipeline. It is a recursive template to implement compile-time +// unrolling of the loop over all pipeline stages. The 'FirstStage' parameter +// is how we implement recursion: each specialization implements only +// evaluation starting at 'FirstStage'. The StopRecursion parameter is just a +// helper to implement the termination of the recursion as a partial +// specialization below. +template <typename OutputPipelineType, int FirstStage, typename InputType, + bool StopRecursion = + FirstStage == std::tuple_size<OutputPipelineType>::value> +struct OutputPipelineEvalImpl { + typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type + FirstStageType; + typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType + FirstStageOutputType; + typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage, + InputType>::Type OutputType; + + OutputPipelineEvalImpl(const OutputPipelineType& output_pipeline) + : head_impl(std::get<FirstStage>(output_pipeline)), + tail_impl(output_pipeline) {} + + OutputType Eval(InputType input, int row, int col) const { + // Evaluate the first stage. + FirstStageOutputType first_stage_output = head_impl.Eval(input, row, col); + // Recurse into the remaining stages. + return tail_impl.Eval(first_stage_output, row, col); + } + + const OutputStageEvalImpl<FirstStageType, InputType> head_impl; + const OutputPipelineEvalImpl<OutputPipelineType, FirstStage + 1, + FirstStageOutputType> + tail_impl; +}; + +// Specialization on 'StopRecursion' for terminating the recursion. +template <typename OutputPipelineType, int FirstStage, typename InputType> +struct OutputPipelineEvalImpl<OutputPipelineType, FirstStage, InputType, true> { + OutputPipelineEvalImpl(const OutputPipelineType&) {} + + InputType Eval(InputType input, int, int) const { + // Terminating the recursion. 
+ return input; + } +}; + +template <typename RegisterBlockType, typename DstType> +struct StoreFinalOutputImpl { + static_assert(std::is_same<RegisterBlockType, void>::value, + "This generic impl should never be hit"); +}; + +template <typename ScalarType, int Rows, int Cols, typename DstType> +struct StoreFinalOutputImpl<RegisterBlock<ScalarType, Rows, Cols>, DstType> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + static void Run(const RegisterBlockType& src, DstType* dst, int row, + int col) { + for (int r = 0; r < Rows; r++) { + for (int c = 0; c < Cols; c++) { + *dst->data(row + r, col + c) = src.buf.reg[r + c * Rows]; + } + } + } +}; + +// StoreFinalOutput takes the final value at the end of the output pipeline and +// stores it into the destination matrix. It can be specialized for different +// data types; the generic implementation here is typically used only for plain +// old scalar (not SIMD) types. +template <typename RegisterBlockType, typename DstType> +void StoreFinalOutput(RegisterBlockType src, DstType* dst, int row, int col) { + StoreFinalOutputImpl<RegisterBlockType, DstType>::Run(src, dst, row, col); +} + +template <typename OutputPipelineType, typename InputType> +struct OutputPipelineExecutor { + OutputPipelineExecutor(const OutputPipelineType& output_pipeline) + : output_pipeline_eval_impl_(output_pipeline) {} + + // RunOutputPipeline is the entry point into the output pipeline evaluation + // code. It should be the only thing that unpack code calls. It takes the + // result + // of the unpack stage and stores it into the destination matrix. + template <typename DstType> + void Execute(InputType input, DstType* dst, int src_global_row, + int src_global_col, int dst_row, int dst_col) const { + // Statically assert that the output pipeline matches the given destination + // matrix's scalar type. + typedef typename OutputPipelineOutputType< + OutputPipelineType, 0, InputType>::Type::BufferType::ScalarType + + ScalarOutputType; + typedef typename DstType::Scalar ScalarDstType; + static_assert(std::is_same<ScalarOutputType, ScalarDstType>::value, + "mismatched destination scalar type and output pipeline"); + + // Evaluate the output pipeline. + auto output = + output_pipeline_eval_impl_.Eval(input, src_global_row, src_global_col); + // Store the result into the destination matrix. + StoreFinalOutput(output, dst, dst_row, dst_col); + } + + const OutputPipelineEvalImpl<OutputPipelineType, 0, InputType> + output_pipeline_eval_impl_; +}; + +} // namespace gemmlowp + +#ifdef GEMMLOWP_NEON +#include "output_neon.h" +#elif defined(GEMMLOWP_SSE4) +#include "output_sse.h" +#endif + +#endif // GEMMLOWP_INTERNAL_OUTPUT_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/output_neon.h b/runtimes/nn/depend/external/gemmlowp/internal/output_neon.h new file mode 100644 index 000000000..7e111e586 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/output_neon.h @@ -0,0 +1,432 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// output_neon.h: optimized NEON specializations of the templates in output.h. + +#ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_ +#define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_ + +#include "output.h" + +#include <arm_neon.h> + +namespace gemmlowp { + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<4>> { + typedef RegBufferInt32<4> InputType; + typedef RegBufferUint8<4> OutputType; + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x4_t res_16 = vqmovn_s32(input.reg[0]); + uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16)); + output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<8>> { + typedef RegBufferInt32<8> InputType; + typedef RegBufferUint8<8> OutputType; + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16 = + vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1])); + output.reg[0] = vqmovun_s16(res_16); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<16>> { + typedef RegBufferInt32<16> InputType; + typedef RegBufferUint8<16> OutputType; + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16_0 = + vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1])); + int16x8_t res_16_1 = + vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3])); + output.reg[0] = vqmovun_s16(res_16_0); + output.reg[1] = vqmovun_s16(res_16_1); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<32>> { + typedef RegBufferInt32<32> InputType; + typedef RegBufferUint8<32> OutputType; + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16[4]; + for (int i = 0; i < 4; i++) { + res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]), + vqmovn_s32(input.reg[2 * i + 1])); + } + for (int i = 0; i < 4; i++) { + output.reg[i] = vqmovun_s16(res_16[i]); + } + return output; + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> { + static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + StoreInt32x4(dst->data(row, col), src.buf.reg[0]); + StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]); + } else { + *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); + *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); + *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); + *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); + *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]); + *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]); + *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]); + *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]); + } + } +}; + +inline RegBlockInt32<4, 4> 
Transpose(const RegBlockInt32<4, 4>& src) { + const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]); + const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]); + RegBlockInt32<4, 4> result; + result.buf.reg[0] = + vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0])); + result.buf.reg[1] = + vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1])); + result.buf.reg[2] = + vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0])); + result.buf.reg[3] = + vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1])); + return result; +} + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> { + static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row, + int col) { + const auto& block = + DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src); + std::int32_t* dst_ptr = dst->data(row, col); + int stride = dst->stride(); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> { + static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row, + int col) { + std::int32_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + int col_stride = dst->cols_stride(); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]); + vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]); + } + } else { + int row_stride = dst->rows_stride(); + RegBlockInt32<4, 4> top; + top.buf.reg[0] = src.buf.reg[0]; + top.buf.reg[1] = src.buf.reg[2]; + top.buf.reg[2] = src.buf.reg[4]; + top.buf.reg[3] = src.buf.reg[6]; + const auto transpose_top = Transpose(top); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]); + } + RegBlockInt32<4, 4> bottom; + bottom.buf.reg[0] = src.buf.reg[1]; + bottom.buf.reg[1] = src.buf.reg[3]; + bottom.buf.reg[2] = src.buf.reg[5]; + bottom.buf.reg[3] = src.buf.reg[7]; + const auto transpose_bottom = Transpose(bottom); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]); + } + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> { + static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row, + int col) { + std::int32_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + int col_stride = dst->cols_stride(); + for (int i = 0; i < 8; i++) { + vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]); + vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]); + } + } else { + int row_stride = dst->rows_stride(); + RegBlockInt32<4, 4> top_left; + top_left.buf.reg[0] = src.buf.reg[0]; + top_left.buf.reg[1] = src.buf.reg[2]; + top_left.buf.reg[2] = src.buf.reg[4]; + top_left.buf.reg[3] = src.buf.reg[6]; + const auto transpose_top_left = Transpose(top_left); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]); + } + RegBlockInt32<4, 4> bottom_left; + bottom_left.buf.reg[0] = src.buf.reg[1]; + bottom_left.buf.reg[1] = src.buf.reg[3]; + bottom_left.buf.reg[2] = src.buf.reg[5]; + bottom_left.buf.reg[3] = src.buf.reg[7]; + const auto transpose_bottom_left = Transpose(bottom_left); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + (i + 4) * row_stride, + transpose_bottom_left.buf.reg[i]); + } + RegBlockInt32<4, 4> top_right; + top_right.buf.reg[0] = 
src.buf.reg[8]; + top_right.buf.reg[1] = src.buf.reg[10]; + top_right.buf.reg[2] = src.buf.reg[12]; + top_right.buf.reg[3] = src.buf.reg[14]; + const auto transpose_top_right = Transpose(top_right); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]); + } + RegBlockInt32<4, 4> bottom_right; + bottom_right.buf.reg[0] = src.buf.reg[9]; + bottom_right.buf.reg[1] = src.buf.reg[11]; + bottom_right.buf.reg[2] = src.buf.reg[13]; + bottom_right.buf.reg[3] = src.buf.reg[15]; + const auto transpose_bottom_right = Transpose(bottom_right); + for (int i = 0; i < 4; i++) { + vst1q_s32(dst_ptr + (i + 4) * row_stride + 4, + transpose_bottom_right.buf.reg[i]); + } + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> { + static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row, + int col) { + std::int32_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + vst1q_s32(dst_ptr, src.buf.reg[0]); + } else { + int row_stride = dst->rows_stride(); + vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0); + vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1); + vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2); + vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> { + static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row, + int col) { + std::int32_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::RowMajor) { + vst1q_s32(dst_ptr, src.buf.reg[0]); + } else { + int col_stride = dst->cols_stride(); + vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0); + vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1); + vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2); + vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> { + static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row, + int col) { + const std::uint32_t src_reg = src.buf.reg[0]; + for (int i = 0; i < 4; i++) { + *dst->data(row + i, col) = (src_reg >> (8 * i)); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> { + static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row, + int col) { + for (int i = 0; i < 4; i++) { + *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i)); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> { + static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row, + int col) { + std::uint8_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + vst1_u8(dst_ptr, src.buf.reg[0]); + } else { + const int row_stride = dst->rows_stride(); + vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0); + vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1); + vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2); + vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3); + vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4); + vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5); + vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6); + vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> { + static void Run(const 
RegBlockUint8<4, 4>& src, DstType* dst, int row, + int col) { + std::uint8_t* dst_ptr = dst->data(row, col); + const int row_stride = dst->rows_stride(); + const int col_stride = dst->cols_stride(); + for (int i = 0; i < 2; i++) { + vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 0); + vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 1); + vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 2); + vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 3); + vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 4); + vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 5); + vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 6); + vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 7); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { + static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row, + int col) { + std::uint8_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + int col_stride = dst->cols_stride(); + for (int i = 0; i < 4; i++) { + vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]); + } + } else { + for (int i = 0; i < 4; i++) { + int row_stride = dst->rows_stride(); + std::uint8_t* col_ptr = dst_ptr + i; + vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0); + vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1); + vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2); + vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3); + vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4); + vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5); + vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6); + vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7); + } + } + } +}; + +inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) { + uint8x8x2_t a[4]; + a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]); + a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]); + a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]); + a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]); + uint16x4x2_t b[4]; + b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]), + vreinterpret_u16_u8(a[1].val[0])); + b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]), + vreinterpret_u16_u8(a[1].val[1])); + b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]), + vreinterpret_u16_u8(a[3].val[0])); + b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]), + vreinterpret_u16_u8(a[3].val[1])); + uint32x2x2_t c[4]; + c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]), + vreinterpret_u32_u16(b[2].val[0])); + c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]), + vreinterpret_u32_u16(b[3].val[0])); + c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]), + vreinterpret_u32_u16(b[2].val[1])); + c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]), + vreinterpret_u32_u16(b[3].val[1])); + RegBlockUint8<8, 8> result; + result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]); + result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]); + result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]); + result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]); + result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]); + result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]); + result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]); + result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]); + return result; +} + +template 
<typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { + static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row, + int col) { + const auto& block = + DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src); + std::uint8_t* dst_ptr = dst->data(row, col); + int stride = dst->stride(); + for (int i = 0; i < 8; i++) { + vst1_u8(dst_ptr + i * stride, block.buf.reg[i]); + } + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/output_sse.h b/runtimes/nn/depend/external/gemmlowp/internal/output_sse.h new file mode 100644 index 000000000..5c0625398 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/output_sse.h @@ -0,0 +1,354 @@ +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// output_sse.h: optimized SSE4.2 specializations of the templates in output.h. + +#ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ +#define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ + +#include "output.h" + +#include <smmintrin.h> + +namespace gemmlowp { + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<4>> { + typedef RegBufferInt32<4> InputType; + typedef RegBufferUint8<4> OutputType; + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]); + __m128i res_8 = _mm_packus_epi16(res_16, res_16); + output.reg[0] = _mm_cvtsi128_si32(res_8); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<8>> { + typedef RegBufferInt32<8> InputType; + typedef RegBufferUint8<8> OutputType; + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]); + __m128i res_8 = _mm_packus_epi16(res_16, res_16); + output.reg[0] = _mm_extract_epi32(res_8, 0); + output.reg[1] = _mm_extract_epi32(res_8, 1); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<16>> { + typedef RegBufferInt32<16> InputType; + typedef RegBufferUint8<16> OutputType; + + typedef OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); + __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); + output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, + RegBufferInt32<32>> { + typedef RegBufferInt32<32> InputType; + typedef RegBufferUint8<32> OutputType; + + typedef 
OutputStageSaturatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); + __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); + output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); + __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]); + __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]); + output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3); + return output; + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> { + static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + StoreInt32x4(dst->data(row, col), src.buf.reg[0]); + } else { + *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); + *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); + *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); + *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> { + static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + StoreInt32x4(dst->data(row, col), src.buf.reg[0]); + StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]); + } else { + *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); + *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); + *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); + *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); + *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]); + *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]); + *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]); + *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]); + } + } +}; + +inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) { + __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]); + __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]); + __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]); + __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]); + + RegBlockInt32<4, 4> result; + result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1); + result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1); + result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3); + result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3); + return result; +} + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> { + static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]); + } + } else { + const auto transpose = Transpose(src); + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]); + } + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> { + static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); + StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); + } + } else { + RegBlockInt32<4, 4> top; + top.buf.reg[0] = src.buf.reg[0]; + top.buf.reg[1] = src.buf.reg[2]; + top.buf.reg[2] = src.buf.reg[4]; + top.buf.reg[3] = src.buf.reg[6]; + 
const auto transpose_top = Transpose(top); + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]); + } + RegBlockInt32<4, 4> bottom; + bottom.buf.reg[0] = src.buf.reg[1]; + bottom.buf.reg[1] = src.buf.reg[3]; + bottom.buf.reg[2] = src.buf.reg[5]; + bottom.buf.reg[3] = src.buf.reg[7]; + const auto transpose_bottom = Transpose(bottom); + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]); + } + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> { + static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + for (int i = 0; i < 8; i++) { + StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); + StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); + } + } else { + RegBlockInt32<4, 4> top_left; + top_left.buf.reg[0] = src.buf.reg[0]; + top_left.buf.reg[1] = src.buf.reg[2]; + top_left.buf.reg[2] = src.buf.reg[4]; + top_left.buf.reg[3] = src.buf.reg[6]; + const auto transpose_top_left = Transpose(top_left); + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]); + } + RegBlockInt32<4, 4> bottom_left; + bottom_left.buf.reg[0] = src.buf.reg[1]; + bottom_left.buf.reg[1] = src.buf.reg[3]; + bottom_left.buf.reg[2] = src.buf.reg[5]; + bottom_left.buf.reg[3] = src.buf.reg[7]; + const auto transpose_bottom_left = Transpose(bottom_left); + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row + 4 + i, col), + transpose_bottom_left.buf.reg[i]); + } + RegBlockInt32<4, 4> top_right; + top_right.buf.reg[0] = src.buf.reg[8]; + top_right.buf.reg[1] = src.buf.reg[10]; + top_right.buf.reg[2] = src.buf.reg[12]; + top_right.buf.reg[3] = src.buf.reg[14]; + const auto transpose_top_right = Transpose(top_right); + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row + i, col + 4), + transpose_top_right.buf.reg[i]); + } + RegBlockInt32<4, 4> bottom_right; + bottom_right.buf.reg[0] = src.buf.reg[9]; + bottom_right.buf.reg[1] = src.buf.reg[11]; + bottom_right.buf.reg[2] = src.buf.reg[13]; + bottom_right.buf.reg[3] = src.buf.reg[15]; + const auto transpose_bottom_right = Transpose(bottom_right); + for (int i = 0; i < 4; i++) { + StoreInt32x4(dst->data(row + 4 + i, col + 4), + transpose_bottom_right.buf.reg[i]); + } + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> { + static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]); + *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]); + *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]); + *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]); + } else { + StoreInt32x4(dst->data(row, col), src.buf.reg[0]); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> { + static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row, + int col) { + const std::uint32_t src_reg = src.buf.reg[0]; + for (int i = 0; i < 4; i++) { + *dst->data(row + i, col) = (src_reg >> (8 * i)); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> { + static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row, + int col) { + for (int i = 0; i < 4; i++) { + *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i)); + } + for 
(int i = 0; i < 4; i++) { + *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i)); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> { + static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row, + int col) { + for (int i = 0; i < 4; i++) { + *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i)); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> { + static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row, + int col) { + std::uint8_t buf[16]; + StoreUint8x16(buf, src.buf.reg[0]); + for (int c = 0; c < 4; c++) { + for (int r = 0; r < 4; r++) { + *dst->data(row + r, col + c) = buf[r + 4 * c]; + } + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { + static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row, + int col) { + std::uint8_t buf[32]; + StoreUint8x16(buf, src.buf.reg[0]); + StoreUint8x16(buf + 16, src.buf.reg[1]); + for (int c = 0; c < 4; c++) { + for (int r = 0; r < 8; r++) { + *dst->data(row + r, col + c) = buf[r + 8 * c]; + } + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { + static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row, + int col) { + std::uint8_t buf[64]; + StoreUint8x16(buf, src.buf.reg[0]); + StoreUint8x16(buf + 16, src.buf.reg[1]); + StoreUint8x16(buf + 32, src.buf.reg[2]); + StoreUint8x16(buf + 48, src.buf.reg[3]); + for (int c = 0; c < 8; c++) { + for (int r = 0; r < 8; r++) { + *dst->data(row + r, col + c) = buf[r + 8 * c]; + } + } + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/pack.h b/runtimes/nn/depend/external/gemmlowp/internal/pack.h new file mode 100644 index 000000000..339539602 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/pack.h @@ -0,0 +1,435 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// pack.h: packing blocks of the LHS and RHS into the data layout +// that is expected by compute.h and eventually by kernels. +// Because this data layout depends on the kernel format, code here +// is templated in KernelLhsFormat/KernelRhsFormat. +// +// Readers note: an important theme around here is that we try hard +// to handle both Lhs and Rhs with a single piece of code. We indifferently +// refer to the Lhs and Rhs as a 'Side'. Instead of addressing matrices +// by (row, column) indices, we address them by (width, depth), as explained +// in kernel.h. This allows us to handle both Lhs and Rhs on an equal footing, +// at once. 
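The (width, depth) vocabulary used throughout this file is easiest to see on a concrete GEMM shape. The sketch below is an editorial illustration only (SideDims, LhsDims and RhsDims are made-up names, not gemmlowp API): for a product C = A (M x K) times B (K x N), both sides share the contracted K dimension as their 'depth', while 'width' is each side's own free dimension.

// Editorial sketch, not part of gemmlowp.
struct SideDims {
  int width;  // the side's free dimension (rows of the LHS, cols of the RHS)
  int depth;  // the contracted dimension, common to both sides
};

inline SideDims LhsDims(int M, int K) { return {M, K}; }  // width = rows, depth = cols
inline SideDims RhsDims(int K, int N) { return {N, K}; }  // width = cols, depth = rows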
+ +#ifndef GEMMLOWP_INTERNAL_PACK_H_ +#define GEMMLOWP_INTERNAL_PACK_H_ + +#include <cstring> + +#include "allocator.h" +#include "block_params.h" +#include "common.h" +#include "kernel.h" + +namespace gemmlowp { + +// A PackedSideBlock instance is a packed block of either the LHS or RHS +// (whence the generic 'Side' name). +// +// 'Packed' means that it is laid out in the storage order that +// is expected by the specified kernel format. From a block of the input +// LHS or RHS matrix, one obtains a PackedSideBlock by calling PackLhs() +// or PackRhs(). +template <typename tKernelSideFormat> +class PackedSideBlock { + public: + typedef tKernelSideFormat KernelSideFormat; + + PackedSideBlock(Side side, Allocator* allocator, + const BlockParams& block_params) + : allocator_(allocator), pos_(0) { + GetSideBlockParams(side, ¶ms_, block_params); + data_handle_ = + allocator_->Reserve<std::uint8_t>(params_.l2_width * params_.l2_depth); + sums_of_each_slice_handle_ = + allocator_->Reserve<std::int32_t>(params_.l2_width); + } + + ~PackedSideBlock() {} + + void seek_run(int start_width, int start_depth) const { + int kernel_run_depth = + std::min<int>(params_.l1_depth, params_.l2_depth - start_depth); + pos_ = params_.l2_width * start_depth + start_width * kernel_run_depth; + } + + void seek_next_cell() const { pos_ += KernelSideFormat::Cell::kSize; } + + void seek_forward_n_cells(int n) const { + pos_ += n * KernelSideFormat::Cell::kSize; + } + + const std::uint8_t* current_data() const { + return allocator_->GetPointer<std::uint8_t>(data_handle_) + pos_; + } + + std::uint8_t* current_data() { + return allocator_->GetPointer<std::uint8_t>(data_handle_) + pos_; + } + + std::int32_t* sums_of_each_slice() { + return allocator_->GetPointer<std::int32_t>(sums_of_each_slice_handle_); + } + + const std::int32_t* sums_of_each_slice() const { + return allocator_->GetPointer<const std::int32_t>( + sums_of_each_slice_handle_); + } + + const SideBlockParams& params() const { return params_; } + + private: + // The block size parameters that this PackedSizeBlock follows. + // The L2 parameters determine its overall size, while the L1 parameters, + // together with the kernel format template parameter, determine + // the fine details of the storage/traversal order. + SideBlockParams params_; + + // Pointer to the allocator provided by the caller. Not owned. + // The Allocator is assumed to outlive the PackedSideBlock. + Allocator* const allocator_; + + // Handle on the buffer backing this packed block. Owned. + Allocator::Handle data_handle_; + + // Handle on the additional buffer backing the vector of sums of slices + // associated with this block. Owned. + Allocator::Handle sums_of_each_slice_handle_; + + // pos_ is the current position in the buffer, which we access + // sequentially, like a file. + // The idea is that we pack data in the same order as it is + // going to be traversed during the computation, which for + // cache-friendliness reasons is complicated to random-access, + // as the offsets calculations would be intricate. So we + // give up random-access addressing, and instead content ourselves + // with sequential access. + // + // pos_ is mutable because during the computation we will want to + // be able to iterate on the data in a const PackedSideBlock. + mutable int pos_; +}; + +// WidthMajor and DepthMajor are custom phrases modelled after the +// standard terminology 'row-major' and 'column-major'. 
Their meaning +// should be transparent once one has read the explanation in kernel.h: +// for example, in the Lhs, the 'width' dimension is the rows dimension, +// so there WidthMajor means RowMajor, while in the Rhs it is the opposite. +// Another way to put it: WidthMajor means that contiguous storage is used +// for entries having the same 'width' index. +enum class SideMapOrder { WidthMajor, DepthMajor }; + +// Similar to MatrixMap from map.h, but in terms of width/depth instead of +// rows/columns. Used to address blocks of the input LHS/RHS matrices when +// packing them. +template <typename tScalar, SideMapOrder tOrder> +class SideMap { + public: + typedef tScalar Scalar; + static const SideMapOrder kOrder = tOrder; + + SideMap(Scalar* data, int width, int depth, int stride) + : data_(data), width_(width), depth_(depth), stride_(stride) {} + + SideMap(Scalar* data, int width, int depth) + : data_(data), width_(width), depth_(depth) { + stride_ = kOrder == SideMapOrder::WidthMajor ? depth_ : width_; + } + + SideMap(const SideMap& other) + : data_(other.data_), + width_(other.width_), + depth_(other.depth_), + stride_(other.stride_) {} + + int width() const { return width_; } + int depth() const { return depth_; } + int stride() const { return stride_; } + int width_stride() const { + return kOrder == SideMapOrder::DepthMajor ? 1 : stride_; + } + int depth_stride() const { + return kOrder == SideMapOrder::WidthMajor ? 1 : stride_; + } + Scalar* data() const { return data_; } + Scalar* data(int w, int d) const { + return data_ + w * width_stride() + d * depth_stride(); + } + Scalar operator()(int w, int d) const { return *data(w, d); } + Scalar& operator()(int w, int d) { return *data(w, d); } + + SideMap block(int start_width, int start_depth, int block_width, + int block_depth) const { + assert(start_width >= 0); + assert(start_width + block_width <= width_); + assert(start_depth >= 0); + assert(start_depth + block_depth <= depth_); + + return SideMap(data(start_width, start_depth), block_width, block_depth, + stride_); + } + + private: + Scalar* data_; // not owned. + int width_, depth_, stride_; +}; + +// A PackingRegisterBlock is a small fixed-size block of a matrix being +// packed. This class is the generic non-optimized implementation, +// it is inherited by the generic implementation of PackingRegisterBlock, +// which may be overriden by template specialization. Overriding it is how +// one may provide optimized packing code paths. +// +// The packing of a block proceeds in two steps: +// 1. Ensuring that we have a complete block of source data, i.e. a block of +// the compile-time prescribed size. This is where we handle unaligned +// boundaries: if we don't have a complete block of source data, then +// we copy and zero-extend it into a local temporary (complete_src_), +// see MakeCompleteSrc. In the generic case, we do have a complete block, +// so we just use it in-place, see UseCompleteSrcInPlace. +// 2. Packing a complete block into the destination, see Pack. This is the +// most critical part, so it's convenient that unaligned boundaries have +// already been handled in step 1. 
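A condensed editorial sketch of how the two steps above are driven by the caller; it mirrors PackRun() further down in this file and is not part of the patch.

// Sketch only: 'b', 'src_map_', 'packed_side_block_' and the constants come
// from PackSideBlockImpl later in this file; width/depth/start_width/
// start_depth are its PackRun() parameters.
PackingRegisterBlockType b;
if (width == kKernelWidth && depth == kRegisterSize) {
  // Step 1, fast path: the source block is already complete, use it in place.
  b.UseCompleteSrcInPlace(
      src_map_.block(start_width, start_depth, width, depth));
} else {
  // Step 1, slow path: copy and zero-extend into a complete local block.
  b.MakeCompleteSrc(
      src_map_.block(start_width, start_depth, width, depth));
}
// Step 2: pack the complete block into the destination.
b.Pack(packed_side_block_, start_width);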
+template <typename SrcMapType, typename PackedSideBlock> +class PackingRegisterBlockBase { + public: + typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + typedef typename KernelSideFormat::Scalar KernelScalar; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + static const SideMapOrder kSrcOrder = SrcMapType::kOrder; + static const int kZeroPointInputValue = + ZeroPointInputValue<KernelScalar>::kValue; + + PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {} + + protected: + // The source data that's ready for packing. May point to + // in-place actual source data if it's already a complete block, + // (see UseCompleteSrcInPlace) + // or to the local buf_ below into which we copy incomplete blocks + // (see MakeCompleteSrc) + SrcMapType complete_src_; + + // Temporary buffer for loading incomplete blocks to, + // in the source storage order + std::uint8_t buf_[kKernelWidth * kRegisterSize]; + + public: + // Selects a block if in-place source data that's already a complete block + void UseCompleteSrcInPlace(const SrcMapType& src) { complete_src_ = src; } + // Copies an incomplete block of source data into a local temporary + // complete block by zero-extending it. + void MakeCompleteSrc(const SrcMapType& src) { + memset(buf_, kZeroPointInputValue, kKernelWidth * kRegisterSize); + if (kSrcOrder == SideMapOrder::WidthMajor) { + for (int w = 0; w < src.width(); w++) { + memcpy(buf_ + w * kRegisterSize, src.data(w, 0), src.depth()); + } + } else { + assert(kSrcOrder == SideMapOrder::DepthMajor); + for (int d = 0; d < src.depth(); d++) { + memcpy(buf_ + d * kKernelWidth, src.data(0, d), src.width()); + } + } + complete_src_ = SrcMapType(buf_, kKernelWidth, kRegisterSize); + } + // Packs a complete block into the destination. This is the most + // critical part and the part that we most typically want to + // override in architecture-specific optimized specializations. + void Pack(PackedSideBlock* dst, int start_width) { + std::uint8_t* dst_ptr = dst->current_data(); + for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; + cell_start_depth += kCellDepth) { + for (int cell_start_width = 0; cell_start_width < kKernelWidth; + cell_start_width += kCellWidth) { + std::int32_t* cell_sums_of_each_slice_ptr = + dst->sums_of_each_slice() + start_width + cell_start_width; + const SideMap<const std::uint8_t, kSrcOrder> src_cell_map( + complete_src_.block(cell_start_width, cell_start_depth, kCellWidth, + kCellDepth)); + for (int w = 0; w < kCellWidth; w++) { + std::int32_t sum = 0; + for (int d = 0; d < kCellDepth; d++) { + const std::uint8_t src_val = src_cell_map(w, d); + const std::int16_t kernel_val_unwrapped = + src_val - kZeroPointInputValue; + const std::uint8_t kernel_val_uint8 = kernel_val_unwrapped; + dst_ptr[OffsetIntoCell<CellFormat>(w, d)] = kernel_val_uint8; + sum += kernel_val_unwrapped; + } + cell_sums_of_each_slice_ptr[w] += sum; + } + dst_ptr += kCellSize; + } + } + dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); + } +}; + +template <typename SrcMapType, typename PackedSideBlock> +class PackingRegisterBlock + : public PackingRegisterBlockBase<SrcMapType, PackedSideBlock> {}; + +// Large-scale implementation of packing. 
+template <typename SrcMapType, typename PackedSideBlock> +class PackSideBlockImpl { + public: + typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + + typedef PackingRegisterBlock<SrcMapType, PackedSideBlock> + PackingRegisterBlockType; + + PackSideBlockImpl(PackedSideBlock* packed_side_block, + const SrcMapType& src_map) + : packed_side_block_(packed_side_block), src_map_(src_map) {} + + PackedSideBlock* packed_side_block() const { return packed_side_block_; } + + const SrcMapType& src_map() const { return src_map_; } + + // The public entry point to pack a block. + void PackL2() { + memset(packed_side_block_->sums_of_each_slice(), 0, + sizeof(std::int32_t) * packed_side_block_->params().l2_width); + for (int d = 0; d < src_map_.depth(); + d += packed_side_block_->params().l1_depth) { + int ds = std::min<int>(packed_side_block_->params().l1_depth, + src_map_.depth() - d); + + for (int w = 0; w < src_map_.width(); + w += packed_side_block_->params().l1_width) { + int ws = std::min<int>(packed_side_block_->params().l1_width, + src_map_.width() - w); + + PrefetchL1(w, ws, d, ds); + PackL1(w, ws, d, ds); + } + } + } + + protected: + // The intermediate-level loops, between PackL2 and PackRun. + void PackL1(int start_width, int width, int start_depth, int depth) { + for (int w = 0; w < width; w += kKernelWidth) { + int ws = std::min(+kKernelWidth, width - w); + packed_side_block_->seek_run(start_width + w, start_depth); + PackRun(start_width + w, ws, start_depth, depth); + } + } + + // Prefetches the data that will be read by PackL1 + void PrefetchL1(int start_width, int width, int start_depth, int depth) { + if (SrcMapType::kOrder == SideMapOrder::WidthMajor) { + for (int d = 0; d < depth; d += kDefaultCacheLineSize) { + for (int w = 0; w < width; w += 1) { + Prefetch(src_map_.data(start_width + w, start_depth + d)); + } + } + } else { + for (int d = 0; d < depth; d++) { + for (int w = 0; w < width; w += kDefaultCacheLineSize) { + Prefetch(src_map_.data(start_width + w, start_depth + d)); + } + } + } + } + + // PackRun packs only a run i.e. is the inner loop in the depth dimension. + void PackRun(int start_width, int width, int start_depth, int depth) { + PackingRegisterBlockType b; + if (width == kKernelWidth) { + const int register_aligned_depth = RoundDown<kRegisterSize>(depth); + if (register_aligned_depth) { + for (int d = 0; d < register_aligned_depth; d += kRegisterSize) { + b.UseCompleteSrcInPlace(src_map_.block(start_width, start_depth + d, + width, kRegisterSize)); + b.Pack(packed_side_block_, start_width); + } + } + if (register_aligned_depth < depth) { + b.MakeCompleteSrc( + src_map_.block(start_width, start_depth + register_aligned_depth, + width, depth - register_aligned_depth)); + b.Pack(packed_side_block_, start_width); + } + } else { + assert(width < kKernelWidth); + for (int d = 0; d < depth; d += kRegisterSize) { + const int ds = std::min(+kRegisterSize, depth - d); + b.MakeCompleteSrc( + src_map_.block(start_width, start_depth + d, width, ds)); + b.Pack(packed_side_block_, start_width); + } + } + } + + // The PackedSideBlock being packed, i.e. the 'destination'. + PackedSideBlock* const packed_side_block_; + + // A map on the block of the original matrix block being packed, + // i.e. 
the 'source'. + const SrcMapType& src_map_; +}; + +// Packs a block of the input LHS matrix, into a PackedSideBlock +template <typename PackedSideBlock, typename MatrixMapType> +void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) { + ScopedProfilingLabel label("pack LHS"); + static const SideMapOrder kSideMapOrder = + MatrixMapType::kOrder == MapOrder::RowMajor ? SideMapOrder::WidthMajor + : SideMapOrder::DepthMajor; + typedef typename MatrixMapType::Scalar Scalar; + typedef SideMap<Scalar, kSideMapOrder> SideMapType; + SideMapType src_side_map(src.data(), src.rows(), src.cols(), src.stride()); + typedef PackSideBlockImpl<SideMapType, PackedSideBlock> ImplType; + ImplType impl(dst, src_side_map); + impl.PackL2(); +} + +// Packs a block of the input RHS matrix, into a PackedSideBlock +template <typename PackedSideBlock, typename MatrixMapType> +void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) { + ScopedProfilingLabel label("pack RHS"); + static const SideMapOrder kSideMapOrder = + MatrixMapType::kOrder == MapOrder::ColMajor ? SideMapOrder::WidthMajor + : SideMapOrder::DepthMajor; + typedef typename MatrixMapType::Scalar Scalar; + typedef SideMap<Scalar, kSideMapOrder> SideMapType; + SideMapType src_side_map(src.data(), src.cols(), src.rows(), src.stride()); + typedef PackSideBlockImpl<SideMapType, PackedSideBlock> ImplType; + ImplType impl(dst, src_side_map); + impl.PackL2(); +} + +} // namespace gemmlowp + +#ifdef GEMMLOWP_NEON +#include "pack_neon.h" +#elif defined(GEMMLOWP_SSE4) +#include "pack_sse.h" +#endif + +#endif // GEMMLOWP_INTERNAL_PACK_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/pack_neon.h b/runtimes/nn/depend/external/gemmlowp/internal/pack_neon.h new file mode 100644 index 000000000..e212d0756 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/pack_neon.h @@ -0,0 +1,320 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// pack_neon.h: optimized NEON specializations of the templates in pack.h. 
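For reference while reading the NEON specializations below (an editorial note, not part of the patch): they all key on WidthMajor source maps, and whether a side is WidthMajor is decided by PackLhs/PackRhs in pack.h above.

// Editorial sketch: the storage-order correspondence encoded there.
//   Lhs, MapOrder::RowMajor -> SideMapOrder::WidthMajor (width == rows)
//   Lhs, MapOrder::ColMajor -> SideMapOrder::DepthMajor
//   Rhs, MapOrder::ColMajor -> SideMapOrder::WidthMajor (width == cols)
//   Rhs, MapOrder::RowMajor -> SideMapOrder::DepthMajor
// e.g. a column-major Rhs of shape (depth x cols) is viewed as:
//   SideMap<const std::uint8_t, SideMapOrder::WidthMajor> rhs_side(
//       rhs.data(), rhs.cols(), rhs.rows(), rhs.stride());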
+ +#ifndef GEMMLOWP_INTERNAL_PACK_NEON_H_ +#define GEMMLOWP_INTERNAL_PACK_NEON_H_ + +#include "pack.h" + +#include <arm_neon.h> + +namespace gemmlowp { + +typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor> + WidthMajorUint8SideMap; + +template <int Cells> +using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>; + +template <int Cells> +class PackingRegisterBlock< + WidthMajorUint8SideMap, + PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> + : public PackingRegisterBlockBase< + WidthMajorUint8SideMap, + PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> { + public: + typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { + std::uint8_t* dst_ptr = dst->current_data(); + const std::uint8_t* const src_ptr = this->complete_src_.data(); + const int stride = this->complete_src_.stride(); + // Load source WidthMajor data + uint8x16_t src_lines[4 * kCells]; + for (int i = 0; i < 4 * kCells; i++) { + src_lines[i] = vld1q_u8(src_ptr + i * stride); + } + // Reorder the data within registers to make DepthMajor 4x2 cells + uint8x16x2_t src_lines_intertwined_2x[2 * kCells]; + for (int i = 0; i < kCells; i++) { + src_lines_intertwined_2x[2 * i] = + vzipq_u8(src_lines[4 * i], src_lines[4 * i + 2]); + src_lines_intertwined_2x[2 * i + 1] = + vzipq_u8(src_lines[4 * i + 1], src_lines[4 * i + 3]); + } + uint8x16x2_t src_lines_intertwined_4x[2 * kCells]; + for (int i = 0; i < kCells; i++) { + src_lines_intertwined_4x[2 * i] = + vzipq_u8(src_lines_intertwined_2x[2 * i].val[0], + src_lines_intertwined_2x[2 * i + 1].val[0]); + src_lines_intertwined_4x[2 * i + 1] = + vzipq_u8(src_lines_intertwined_2x[2 * i].val[1], + src_lines_intertwined_2x[2 * i + 1].val[1]); + } + // Store the resulting DepthMajor 4x2 cells in the destination packed block + for (int outer = 0; outer < 2; outer++) { + for (int inner = 0; inner < 2; inner++) { + for (int cell = 0; cell < kCells; cell++) { + uint8x8_t value = vget_low_u8( + src_lines_intertwined_4x[2 * cell + outer].val[inner]); + vst1_u8(dst_ptr, value); + dst_ptr += 8; + } + for (int cell = 0; cell < kCells; cell++) { + uint8x8_t value = vget_high_u8( + src_lines_intertwined_4x[2 * cell + outer].val[inner]); + vst1_u8(dst_ptr, value); + dst_ptr += 8; + } + } + } + // Compute sums across the depth dimension + uint16x8_t sums_of_2_cells[kCells][4]; + for (int outer = 0; outer < 2; outer++) { + for (int inner = 0; inner < 2; inner++) { + int i = 2 * outer + inner; + for (int cell = 0; cell < kCells; cell++) { + sums_of_2_cells[cell][i] = vaddl_u8( + vget_low_u8( + src_lines_intertwined_4x[2 * cell + outer].val[inner]), + vget_high_u8( + src_lines_intertwined_4x[2 * cell + outer].val[inner])); + } + } + } + int32x4_t sums_of_4_cells[kCells][4]; + for (int i = 0; i < 4; i++) { + for (int cell = 0; cell < kCells; cell++) { + sums_of_4_cells[cell][i] = vreinterpretq_s32_u32( + vaddl_u16(vget_low_u16(sums_of_2_cells[cell][i]), + vget_high_u16(sums_of_2_cells[cell][i]))); + } + } + // Update the sums_of_each_slice vector + for (int cell = 0; cell < kCells; cell++) { + int32x4_t s01 = + vaddq_s32(sums_of_4_cells[cell][0], 
sums_of_4_cells[cell][1]); + int32x4_t s23 = + vaddq_s32(sums_of_4_cells[cell][2], sums_of_4_cells[cell][3]); + int32x4_t s = vaddq_s32(s01, s23); + std::int32_t* sums_of_each_slice_ptr = + dst->sums_of_each_slice() + start_width + 4 * cell; + vst1q_s32(sums_of_each_slice_ptr, + vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr))); + } + dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); + } +}; + +template <int Cells> +using WidthMajorSideFormatNCells4x2 = + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>; + +template <int Cells> +class PackingRegisterBlock< + WidthMajorUint8SideMap, + PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> + : public PackingRegisterBlockBase< + WidthMajorUint8SideMap, + PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> { + public: + typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { + std::uint8_t* dst_ptr = dst->current_data(); + const std::uint8_t* src_ptr = this->complete_src_.data(); + const int stride = this->complete_src_.stride(); + // Load source WidthMajor data + uint16x8_t src_lines[kCells * 4]; + for (int i = 0; i < kCells; i++) { +// This packing path is used with our current +// less-than-8-bit kernel, and the partial unrolling of this loop +// results in substantially faster code (thanks to better +// register allocation) on Nexus 5. + +#define GEMMLOWP_UNROLLED_LOOP_ITER(k) \ + src_lines[4 * i + k] = vreinterpretq_u16_u8(vld1q_u8(src_ptr)); \ + src_ptr += stride; + + GEMMLOWP_UNROLLED_LOOP_ITER(0) + GEMMLOWP_UNROLLED_LOOP_ITER(1) + GEMMLOWP_UNROLLED_LOOP_ITER(2) + GEMMLOWP_UNROLLED_LOOP_ITER(3) + +#undef GEMMLOWP_UNROLLED_LOOP_ITER + } + // Reorder the data within registers to make WidthMajor 4x2 cells + uint16x8x2_t src_lines_intertwined_2x[2 * kCells]; + for (int i = 0; i < kCells; i++) { + src_lines_intertwined_2x[2 * i] = + vzipq_u16(src_lines[4 * i], src_lines[4 * i + 2]); + src_lines_intertwined_2x[2 * i + 1] = + vzipq_u16(src_lines[4 * i + 1], src_lines[4 * i + 3]); + } + uint16x8x2_t src_lines_intertwined_4x[2 * kCells]; + for (int i = 0; i < kCells; i++) { + src_lines_intertwined_4x[2 * i] = + vzipq_u16(src_lines_intertwined_2x[2 * i].val[0], + src_lines_intertwined_2x[2 * i + 1].val[0]); + src_lines_intertwined_4x[2 * i + 1] = + vzipq_u16(src_lines_intertwined_2x[2 * i].val[1], + src_lines_intertwined_2x[2 * i + 1].val[1]); + } + // Store the resulting WidthMajor 4x2 cells in the destination packed block + for (int outer = 0; outer < 2; outer++) { + for (int inner = 0; inner < 2; inner++) { + for (int cell = 0; cell < kCells; cell++) { + uint8x8_t value = vreinterpret_u8_u16(vget_low_u16( + src_lines_intertwined_4x[2 * cell + outer].val[inner])); + vst1_u8(dst_ptr, value); + dst_ptr += 8; + } + for (int cell = 0; cell < kCells; cell++) { + uint8x8_t value = vreinterpret_u8_u16(vget_high_u16( + src_lines_intertwined_4x[2 * cell + outer].val[inner])); + vst1_u8(dst_ptr, value); + dst_ptr += 8; + } + } + } + // Compute sums across the depth dimension + uint16x8_t sums_of_2[kCells][4]; + for (int outer = 0; outer < 2; outer++) { + for (int inner = 0; inner < 2; inner++) { + int i = 2 * outer + inner; + 
for (int cell = 0; cell < kCells; cell++) { + sums_of_2[cell][i] = vpaddlq_u8(vreinterpretq_u8_u16( + src_lines_intertwined_4x[2 * cell + outer].val[inner])); + } + } + } + uint16x8_t sums_of_4[kCells][2]; + for (int i = 0; i < 2; i++) { + for (int cell = 0; cell < kCells; cell++) { + sums_of_4[cell][i] = + vaddq_u16(sums_of_2[cell][2 * i], sums_of_2[cell][2 * i + 1]); + } + } + uint16x8_t sums_of_8[kCells]; + for (int cell = 0; cell < kCells; cell++) { + sums_of_8[cell] = vaddq_u16(sums_of_4[cell][0], sums_of_4[cell][1]); + } + + uint16x4_t sums_of_16[kCells]; + for (int cell = 0; cell < kCells; cell++) { + sums_of_16[cell] = vadd_u16(vget_low_u16(sums_of_8[cell]), + vget_high_u16(sums_of_8[cell])); + } + // Update the sums_of_each_slice vector + for (int cell = 0; cell < kCells; cell++) { + int32x4_t s = vreinterpretq_s32_u32(vmovl_u16(sums_of_16[cell])); + std::int32_t* sums_of_each_slice_ptr = + dst->sums_of_each_slice() + start_width + 4 * cell; + vst1q_s32(sums_of_each_slice_ptr, + vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr))); + } + dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); + } +}; + +#ifdef GEMMLOWP_NEON_32 +inline int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + const int16x4_t c = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + const int16x4_t d = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(c, d); +} +#endif + +template <int Width> +using Int8FastKernelFormat = + KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>; + +template <int Width> +class PackingRegisterBlock<WidthMajorUint8SideMap, + PackedSideBlock<Int8FastKernelFormat<Width>>> + : public PackingRegisterBlockBase< + WidthMajorUint8SideMap, + PackedSideBlock<Int8FastKernelFormat<Width>>> { + public: + static_assert(Width == 2 || Width == 4, ""); + typedef Int8FastKernelFormat<Width> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { + std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width; + std::uint8_t* dst_ptr = dst->current_data(); + const std::uint8_t* const src_ptr = this->complete_src_.data(); + const int stride = this->complete_src_.stride(); + // Load source WidthMajor data + uint8x16_t src_lines[Width]; + for (int i = 0; i < Width; i++) { + src_lines[i] = vld1q_u8(src_ptr + i * stride); + } + const uint8x16_t sign_bit_dup = vdupq_n_u8(0x80); + for (int i = 0; i < Width; i++) { + src_lines[i] = veorq_u8(src_lines[i], sign_bit_dup); + } + for (int i = 0; i < Width; i++) { + vst1q_u8(dst_ptr + 16 * i, src_lines[i]); + } + int16x8_t sums2[Width]; + for (int i = 0; i < Width; i++) { + const int8x8_t lo = vreinterpret_s8_u8(vget_low_u8(src_lines[i])); + const int8x8_t hi = vreinterpret_s8_u8(vget_high_u8(src_lines[i])); + sums2[i] = vaddl_s8(lo, hi); + } + int16x8_t sums4[Width / 2]; + for (int i = 0; i < Width / 2; i++) { + sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]); + } + if (Width == 4) { + int32x4_t sum = vld1q_s32(sums_ptr); + int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]); + sum = vpadalq_s16(sum, sums8); + vst1q_s32(sums_ptr, sum); + } else { + assert(Width == 2); + int32x2_t sum = vld1_s32(sums_ptr); + int16x4_t sums8 = + vpadd_s16(vget_low_s16(sums4[0]), 
vget_high_s16(sums4[0])); + sum = vpadal_s16(sum, sums8); + vst1_s32(sums_ptr, sum); + } + dst->seek_forward_n_cells(1); + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_PACK_NEON_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/pack_sse.h b/runtimes/nn/depend/external/gemmlowp/internal/pack_sse.h new file mode 100644 index 000000000..52163c4e5 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/pack_sse.h @@ -0,0 +1,128 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// pack_SSE.h: optimized SSE specializations of the templates in pack.h. + +#ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_ +#define GEMMLOWP_INTERNAL_PACK_SSE_H_ + +#include <smmintrin.h> +#include "pack.h" + +namespace gemmlowp { + +// TODO: Add DepthMajorUint8SideMap + +typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor> + WidthMajorUint8SideMap; + +template <int Cells> +using WidthMajorSideFormatNCells4x2 = + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>; + +template <int Cells> +class PackingRegisterBlock< + WidthMajorUint8SideMap, + PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > + : public PackingRegisterBlockBase< + WidthMajorUint8SideMap, + PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > { + public: + typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { + std::uint8_t* dst_ptr = dst->current_data(); + const int width_stride = this->complete_src_.width_stride(); + int depth_step = 8; + + __m128i one = _mm_set1_epi16(1); + for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; + cell_start_depth += depth_step) { + for (int cell_start_width = 0; cell_start_width < kKernelWidth; + cell_start_width += kCellWidth) { + std::int32_t* cell_sums_of_each_slice_ptr = + dst->sums_of_each_slice() + start_width + cell_start_width; + const std::uint8_t* src_data = + this->complete_src_.data(cell_start_width, cell_start_depth); + + __m128i xmm1 = + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0])); + __m128i xmm2 = _mm_loadl_epi64( + reinterpret_cast<const __m128i*>(&src_data[1 * width_stride])); + __m128i xmm3 = _mm_loadl_epi64( + reinterpret_cast<const __m128i*>(&src_data[2 * width_stride])); + __m128i xmm4 = _mm_loadl_epi64( + reinterpret_cast<const __m128i*>(&src_data[3 * width_stride])); + + __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2); + __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31); + + __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4); + __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80); + + __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc); + __m128i xmm10 
= _mm_blend_epi16(xmm8, xmm6, 0xcc); + + _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10); + + __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee); + __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee); + + _mm_storel_epi64( + reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]), + xmm11); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]), + xmm12); + + xmm1 = _mm_cvtepu8_epi16(xmm9); + xmm2 = _mm_madd_epi16(xmm1, one); + __m128i sums_of_each_slice_xmm = _mm_loadu_si128( + reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0])); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + xmm1 = _mm_cvtepu8_epi16(xmm10); + xmm2 = _mm_madd_epi16(xmm1, one); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + xmm1 = _mm_cvtepu8_epi16(xmm11); + xmm2 = _mm_madd_epi16(xmm1, one); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + xmm1 = _mm_cvtepu8_epi16(xmm12); + xmm2 = _mm_madd_epi16(xmm1, one); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + _mm_storeu_si128( + reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]), + sums_of_each_slice_xmm); + dst_ptr += kCellSize; + } + dst_ptr += 3 * kCellSize * kCells; + } + dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_PACK_SSE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h new file mode 100644 index 000000000..e39eaf89f --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h @@ -0,0 +1,508 @@ +// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// simd_wrappers.h: some inline functions wrapping SIMD intrinsics, +// extending the set of such functions from fixedpoint.h. 
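A small editorial sketch, not part of the patch, of what the generic non-SIMD types defined below boil down to: with a plain scalar type, RegisterType<std::int32_t, N>::Type is just std::int32_t, so a RegisterBlock stores one scalar per "register", laid out column-major in buf.reg[].

// Sketch, generic scalar path only.
RegisterBlock<std::int32_t, 4, 4> block;   // kRegisterCount == 16 here
for (int c = 0; c < 4; c++) {
  for (int r = 0; r < 4; r++) {
    // entry (r, c) of the block lives at index r + c * kRows
    block.buf.reg[r + c * 4] = r + c * 4;
  }
}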
+ +#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ +#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ + +#include <algorithm> +#include <type_traits> +#include "../fixedpoint/fixedpoint.h" + +namespace gemmlowp { + +template <typename ScalarType, int ScalarCount> +struct RegisterType { + using Type = ScalarType; +}; + +inline std::int32_t Min(std::int32_t a, std::int32_t b) { + return std::min(a, b); +} + +inline std::int32_t Max(std::int32_t a, std::int32_t b) { + return std::max(a, b); +} + +inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) { + *acc += lhs * rhs; +} + +template <typename tScalarType, int tScalarCount> +struct RegisterBuffer { + using ScalarType = tScalarType; + static constexpr int kScalarCount = tScalarCount; + using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type; + static_assert((kScalarCount & (kScalarCount - 1)) == 0, + "kScalarCount must be a power of two"); + static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, ""); + static constexpr int kRegisterLanes = + sizeof(RegisterType) / sizeof(ScalarType); + static constexpr int kRegisterCount = + (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) / + sizeof(RegisterType); + + RegisterType reg[kRegisterCount]; +}; + +template <typename tScalarType, int tRows, int tCols> +struct RegisterBlock { + using ScalarType = tScalarType; + static constexpr int kRows = tRows; + static constexpr int kCols = tCols; + static constexpr int kScalarCount = kRows * kCols; + using BufferType = RegisterBuffer<ScalarType, kScalarCount>; + using RegisterType = typename BufferType::RegisterType; + static constexpr int kRegisterCount = BufferType::kRegisterCount; + static constexpr int kRegisterLanes = BufferType::kRegisterLanes; + + BufferType buf; +}; + +template <typename RegisterBlockType> +struct RegisterBlockAddImpl { + static RegisterBlockType Run(const RegisterBlockType& lhs, + const RegisterBlockType& rhs) { + RegisterBlockType result; + for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { + result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +template <typename RegisterBlockType> +RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs, + const RegisterBlockType& rhs) { + return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs); +} + +template <typename LhsType, typename RhsType> +struct ShouldFlipLhsRhs { + static constexpr bool kValue = + (LhsType::kScalarCount < RhsType::kScalarCount) || + (LhsType::kScalarCount == RhsType::kScalarCount && + (LhsType::kRows < RhsType::kRows)); +}; + +template <typename LhsType, typename RhsType, + bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue> +struct FlipLhsRhs { + using FlippedLhsType = LhsType; + using FlippedRhsType = RhsType; + static const FlippedLhsType& FlippedLhs(const LhsType& lhs, + const RhsType& rhs) { + return lhs; + } + static const FlippedRhsType& FlippedRhs(const LhsType& lhs, + const RhsType& rhs) { + return rhs; + } +}; + +template <typename LhsType, typename RhsType> +struct FlipLhsRhs<LhsType, RhsType, true> { + using FlippedLhsType = RhsType; + using FlippedRhsType = LhsType; + static const FlippedLhsType& FlippedLhs(const LhsType& lhs, + const RhsType& rhs) { + return rhs; + } + static const FlippedRhsType& FlippedRhs(const LhsType& lhs, + const RhsType& rhs) { + return lhs; + } +}; + +template <typename Lhs, typename Rhs> +struct BroadcastBinaryOpShape { + static constexpr int kRows = + Lhs::kRows > Rhs::kRows ? 
Lhs::kRows : Rhs::kRows; + static constexpr int kCols = + Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols; +}; + +template <typename Lhs, typename Rhs> +struct BroadcastBinaryOpRegisterBlock { + using Shape = BroadcastBinaryOpShape<Lhs, Rhs>; + using ScalarType = typename Lhs::ScalarType; + using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; +}; + +template <typename Lhs, typename Rhs> +struct BroadcastAddImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + result.buf.reg[r + c * Rows] = + Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template <typename Lhs, typename Rhs> +typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd( + const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + return BroadcastAddImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template <typename Lhs, typename Rhs> +struct BroadcastMulImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? 
r : 0; + result.buf.reg[r + c * Rows] = + Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template <typename Lhs, typename Rhs> +typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul( + const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + return BroadcastMulImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template <typename Lhs, typename Rhs, typename Acc> +struct BroadcastMulAddImpl { + static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) { + static constexpr int Rows = Acc::kRows; + static constexpr int Cols = Acc::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + static_assert(Acc::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows], + &acc->buf.reg[r + c * Rows]); + } + } + } +}; + +template <typename Lhs, typename Rhs, typename Acc> +void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + BroadcastMulAddImpl<typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType, + Acc>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs), acc); +} + +template <typename RegisterBlockType, typename SrcObjectType> +struct LoadImpl { + static_assert(std::is_same<SrcObjectType, void>::value, + "This generic impl should never be hit"); +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType> +struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, + MatrixMap<SrcScalarType, MapOrder::ColMajor>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>; + static RegisterBlockType Run(const SrcObjectType& src, int row, int col) { + RegisterBlockType result; + int i = 0; + for (int c = 0; c < Cols; c++) { + const ScalarType* src_ptr = src.data(row, col + c); + for (int r = 0; r < Rows; r++) { + result.buf.reg[i++] = *src_ptr++; + } + } + return result; + } +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorMap<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorMap<SrcScalarType, Shape>; + static RegisterBlockType Run(const SrcObjectType& src, int pos) { + static_assert(Shape == VectorShape::Col || Rows == 1, ""); + static_assert(Shape == VectorShape::Row || Cols == 1, ""); + RegisterBlockType result; + for (int i = 0; i < Rows * Cols; i++) { + result.buf.reg[i] = src(pos + i); + } + return result; + } +}; 
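A brief editorial sketch, not part of the patch, of the broadcasting semantics implemented above on the generic scalar path: a 4x1 operand added to a 4x4 operand is reused for every column of the result.

// Sketch only; RegisterBlock/BroadcastAdd are the templates defined above.
RegisterBlock<std::int32_t, 4, 4> acc;   // entries at buf.reg[r + c * 4]
RegisterBlock<std::int32_t, 4, 1> bias;  // one entry per row
// ... fill acc and bias ...
auto sum = BroadcastAdd(acc, bias);      // a RegisterBlock<std::int32_t, 4, 4>
// On the scalar path: sum.buf.reg[r + c * 4]
//   == acc.buf.reg[r + c * 4] + bias.buf.reg[r]   for all r, c.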
+ +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorDup<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorDup<SrcScalarType, Shape>; + static RegisterBlockType Run(const SrcObjectType& src, int) { + static_assert(Shape == VectorShape::Col || Rows == 1, ""); + static_assert(Shape == VectorShape::Row || Cols == 1, ""); + RegisterBlockType result; + for (int i = 0; i < Rows * Cols; i++) { + result.buf.reg[i] = src(0); + } + return result; + } +}; + +template <typename RegisterBlockType, typename SrcObjectType> +RegisterBlockType Load(const SrcObjectType& src, int row, int col) { + return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col); +} + +template <typename RegisterBlockType, typename SrcObjectType> +RegisterBlockType Load(const SrcObjectType& src, int pos) { + return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos); +} + +template <typename RegisterBlockType> +struct LoadContiguousImpl { + using ScalarType = typename RegisterBlockType::ScalarType; + static_assert(RegisterBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static RegisterBlockType Run(const ScalarType* src) { + RegisterBlockType result; + for (int i = 0; i < RegisterBlockType::kScalarCount; i++) { + result.buf.reg[i] = src[i]; + } + return result; + } +}; + +template <typename RegisterBlockType> +RegisterBlockType LoadContiguous( + const typename RegisterBlockType::ScalarType* src) { + return LoadContiguousImpl<RegisterBlockType>::Run(src); +} + +template <int BroadcastRows, int BroadcastCols, typename SrcObjectType> +struct LoadForBroadcastingShape {}; + +template <int BroadcastRows, int BroadcastCols, typename ScalarType, + VectorShape Shape> +struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, + VectorMap<ScalarType, Shape>> { + static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1; + static constexpr int kCols = Shape == VectorShape::Row ? 
BroadcastCols : 1; +}; + +template <int BroadcastRows, int BroadcastCols, typename ScalarType, + VectorShape Shape> +struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, + VectorDup<ScalarType, Shape>> { + static constexpr int kRows = 1; + static constexpr int kCols = 1; +}; + +template <typename RegisterBlockType, typename SrcObjectType> +struct LoadForBroadcastingRegisterBlock { + using Shape = + LoadForBroadcastingShape<RegisterBlockType::kRows, + RegisterBlockType::kCols, SrcObjectType>; + using ScalarType = typename RegisterBlockType::ScalarType; + using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; +}; + +template <typename RegisterBlockType, typename SrcObjectType> +struct LoadForBroadcastingImpl { + static_assert(std::is_same<SrcObjectType, void>::value, + "This generic impl should never be hit"); +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorMap<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorMap<SrcScalarType, Shape>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static ResultBlockType Run(const SrcObjectType& src, int pos) { + ResultBlockType result; + for (int c = 0; c < ResultBlockType::kCols; c++) { + for (int r = 0; r < ResultBlockType::kRows; r++) { + const int i = Shape == VectorShape::Col ? r : c; + result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i); + } + } + return result; + } +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorDup<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorDup<SrcScalarType, Shape>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static ResultBlockType Run(const SrcObjectType& src, int) { + ResultBlockType result; + for (int c = 0; c < ResultBlockType::kCols; c++) { + for (int r = 0; r < ResultBlockType::kRows; r++) { + result.buf.reg[r + c * ResultBlockType::kRows] = src(0); + } + } + return result; + } +}; + +template <typename RegisterBlockType, typename SrcObjectType> +typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type +LoadForBroadcasting(const SrcObjectType& src, int row, int col) { + return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run( + src, row, col); +} + +template <typename RegisterBlockType, typename SrcObjectType> +typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type +LoadForBroadcasting(const SrcObjectType& src, int pos) { + return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src, + pos); +} + +template <int ConstantValue, typename RegisterBlockType> +struct AddConstantImpl { + static void Run(RegisterBlockType* block) { + using RegisterType = typename RegisterBlockType::RegisterType; + const RegisterType dup = Dup<RegisterType>(ConstantValue); + for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { + block->buf.reg[i] = Add(block->buf.reg[i], dup); + } + } +}; + 
+template <typename RegisterBlockType> +struct AddConstantImpl<0, RegisterBlockType> { + static void Run(RegisterBlockType*) { + // This is a no-op. + } +}; + +template <int ConstantValue, typename RegisterBlockType> +void AddConstant(RegisterBlockType* block) { + AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block); +} + +template <int N> +using RegBufferInt32 = RegisterBuffer<std::int32_t, N>; +template <int N> +using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>; +template <int R, int C> +using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>; +template <int R, int C> +using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>; + +} // end namespace gemmlowp + +#if defined GEMMLOWP_NEON +#include "simd_wrappers_neon.h" +#elif defined GEMMLOWP_SSE4 +#include "simd_wrappers_sse.h" +#endif + +#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_common_neon_sse.h b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_common_neon_sse.h new file mode 100644 index 000000000..3830eb169 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_common_neon_sse.h @@ -0,0 +1,646 @@ +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code + +#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ +#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ + +#include "simd_wrappers.h" + +namespace gemmlowp { + +template <typename SrcScalarType, int N> +struct LoadImpl<RegBlockInt32<4, N>, + MatrixMap<SrcScalarType, MapOrder::ColMajor>> { + static RegBlockInt32<4, N> Run( + const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, + int col) { + RegBlockInt32<4, N> result; + for (int i = 0; i < N; i++) { + result.buf.reg[i] = LoadInt32x4(src.data(row, col + i)); + } + return result; + } +}; + +template <typename SrcScalarType, int N> +struct LoadImpl<RegBlockInt32<8, N>, + MatrixMap<SrcScalarType, MapOrder::ColMajor>> { + static RegBlockInt32<8, N> Run( + const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, + int col) { + RegBlockInt32<8, N> result; + for (int i = 0; i < N; i++) { + result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i)); + result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i)); + } + return result; + } +}; + +template <typename SrcScalarType> +struct LoadImpl<RegBlockInt32<1, 4>, + MatrixMap<SrcScalarType, MapOrder::ColMajor>> { + static RegBlockInt32<1, 4> Run( + const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, + int col) { + RegBlockInt32<1, 4> result; + std::int32_t buf[4]; + for (int i = 0; i < 4; i++) { + buf[i] = src(row, col + i); + } + result.buf.reg[0] = LoadInt32x4(buf); + return result; + } +}; + +template <typename SrcScalarType> +struct LoadImpl<RegBlockInt32<1, 8>, + MatrixMap<SrcScalarType, MapOrder::ColMajor>> { + static RegBlockInt32<1, 8> Run( + const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, + int col) { + RegBlockInt32<1, 8> result; + std::int32_t buf[8]; + for (int i = 0; i < 8; i++) { + buf[i] = src(row, col + i); + } + result.buf.reg[0] = LoadInt32x4(buf); + result.buf.reg[1] = LoadInt32x4(buf + 4); + return result; + } +}; + +template <typename SrcScalarType> +struct LoadImpl<RegBlockInt32<4, 1>, + VectorMap<SrcScalarType, VectorShape::Col>> { + static RegBlockInt32<4, 1> Run( + const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = LoadInt32x4(src.data(pos)); + return result; + } +}; + +template <typename SrcScalarType> +struct LoadImpl<RegBlockInt32<4, 1>, + VectorDup<SrcScalarType, VectorShape::Col>> { + static RegBlockInt32<4, 1> Run( + const VectorDup<SrcScalarType, VectorShape::Col>& src, int) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = LoadInt32x4(src(0)); + return result; + } +}; + +template <typename SrcScalarType, int N> +struct LoadForBroadcastingImpl<RegBlockInt32<4, N>, + VectorMap<SrcScalarType, VectorShape::Col>> { + using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; + using RegisterBlockType = RegBlockInt32<4, N>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + + static ResultBlockType Run(const SrcObjectType& src, int pos) { + ResultBlockType result; + static_assert(ResultBlockType::kRegisterCount == 1, ""); + result.buf.reg[0] = LoadInt32x4(src.data(pos)); + return result; + } +}; + +template <typename SrcScalarType, int N> +struct LoadForBroadcastingImpl<RegBlockInt32<8, N>, + VectorMap<SrcScalarType, VectorShape::Col>> { + using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; + using RegisterBlockType = RegBlockInt32<8, N>; + using 
ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + + static ResultBlockType Run(const SrcObjectType& src, int pos) { + ResultBlockType result; + static_assert(ResultBlockType::kRegisterCount == 2, ""); + result.buf.reg[0] = LoadInt32x4(src.data(pos)); + result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); + return result; + } +}; + +template <typename SrcScalarType> +struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>, + VectorMap<SrcScalarType, VectorShape::Row>> { + using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; + using RegisterBlockType = RegBlockInt32<4, 1>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + + static ResultBlockType Run(const SrcObjectType& src, int pos) { + ResultBlockType result; + result.buf.reg[0] = src(pos); + return result; + } +}; + +template <typename SrcScalarType, int N> +struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>, + VectorMap<SrcScalarType, VectorShape::Row>> { + using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; + using RegisterBlockType = RegBlockInt32<N, 4>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + + static ResultBlockType Run(const SrcObjectType& src, int pos) { + ResultBlockType result; + static_assert(ResultBlockType::kRegisterCount == 1, ""); + result.buf.reg[0] = LoadInt32x4(src.data(pos)); + return result; + } +}; + +template <typename SrcScalarType, int N> +struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>, + VectorMap<SrcScalarType, VectorShape::Row>> { + using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; + using RegisterBlockType = RegBlockInt32<N, 8>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + + static ResultBlockType Run(const SrcObjectType& src, int pos) { + ResultBlockType result; + static_assert(ResultBlockType::kRegisterCount == 2, ""); + result.buf.reg[0] = LoadInt32x4(src.data(pos)); + result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); + return result; + } +}; + +// 4x1 := 4x1 + 1x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 1x4 := 1x4 + 1x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 + 4x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 + 1x4 +template <> +struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := 4x4 + 1x4 +template <> +struct BroadcastAddImpl<RegBlockInt32<4, 
4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x4 := 4x4 + 4x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]); + result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]); + return result; + } +}; + +// 8x1 := 8x1 + 1x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = Add(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := 8x1 + 8x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +// 8x4 := 8x4 + 1x4 +template <> +struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); + result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 8x4 := 8x4 + 8x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); + result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]); + result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]); + result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]); + result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]); + result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x8 +template <> +struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 8>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], 
rhs.buf.reg[0]); + result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x1 +template <> +struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 * 1x1 +template <> +struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 * 4x1 +template <> +struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 * 1x4 +template <> +struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 * 1x1 +template <> +struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := 4x4 * 1x4 +template <> +struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + const Int32x4 p = rhs.buf.reg[0]; + result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p); + result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p); + result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p); + result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p); + return result; + } +}; + +// 4x4 := 4x4 * 4x1 +template <> +struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + const Int32x4 p = rhs.buf.reg[0]; + result.buf.reg[0] = Mul(lhs.buf.reg[0], p); + result.buf.reg[1] = Mul(lhs.buf.reg[1], p); + result.buf.reg[2] = Mul(lhs.buf.reg[2], p); + result.buf.reg[3] = Mul(lhs.buf.reg[3], p); + return result; + } +}; + +// 8x1 := 8x1 * 1x1 +template <> +struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const std::int32_t p = rhs.buf.reg[0]; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = Mul(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := 8x1 * 8x1 +template <> +struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + 
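For reference, the BroadcastAdd/BroadcastMul specializations above all implement one broadcasting rule, merely unrolled per block shape: a 1xC row (or Rx1 column) operand is repeated across the missing dimension of the RxC accumulator tile. A minimal scalar sketch of the row case follows, assuming the column-major register-block layout implied by these specializations (an editorial illustration only, not part of the patch):

#include <cstdint>

// dst[r][c] = lhs[r][c] + rhs_row[c]   : the "RxC := RxC + 1xC" case.
inline void BroadcastAddRowScalar(const std::int32_t* lhs,
                                  const std::int32_t* rhs_row,
                                  std::int32_t* dst, int rows, int cols) {
  // Tiles are assumed column-major: element (r, c) lives at index r + c * rows.
  for (int c = 0; c < cols; c++) {
    for (int r = 0; r < rows; r++) {
      dst[r + c * rows] = lhs[r + c * rows] + rhs_row[c];
    }
  }
}

The SIMD specializations reach the same result by holding each column (or pair of registers per column for 8-row blocks) in Int32x4 registers and using DupLane<c> to splat rhs_row[c] across all four lanes.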
+// 8x4 := 8x4 * 1x4 +template <> +struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + const Int32x4 p = rhs.buf.reg[0]; + for (int i = 0; i < 2; i++) { + result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p); + result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p); + result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p); + result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p); + } + return result; + } +}; + +// 8x4 := 8x4 * 8x1 +template <> +struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]}; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 2; j++) { + const int k = j + 2 * i; + result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]); + } + } + return result; + } +}; + +// Rx1 += Rx1 * 1x1 +template <int Rows> +struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, + RegBlockInt32<Rows, 1>> { + static void Run(const RegBlockInt32<Rows, 1>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) { + const std::int32_t p = rhs.buf.reg[0]; + for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) { + MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); + } + } +}; + +// RxC += Rx1 * 1x1 +template <int Rows, int Cols> +struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, + RegBlockInt32<Rows, Cols>> { + static void Run(const RegBlockInt32<Rows, 1>& lhs, + const RegBlockInt32<1, 1>& rhs, + RegBlockInt32<Rows, Cols>* acc) { + const std::int32_t p = rhs.buf.reg[0]; + static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; + for (int i = 0; i < kRegsPerCol; i++) { + const Int32x4 q = Mul(lhs.buf.reg[i], p); + for (int j = 0; j < Cols; j++) { + acc->buf.reg[i + j * kRegsPerCol] = + Add(acc->buf.reg[i + j * kRegsPerCol], q); + } + } + } +}; + +// 1xC += 1xC * 1x1 +template <int Cols> +struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>, + RegBlockInt32<1, Cols>> { + static void Run(const RegBlockInt32<1, Cols>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { + const std::int32_t p = rhs.buf.reg[0]; + for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { + MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); + } + } +}; + +// RxC += 1x1 * 1x1 +template <int Rows, int Cols> +struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, + RegBlockInt32<Rows, Cols>> { + static void Run(const RegBlockInt32<1, 1>& lhs, + const RegBlockInt32<1, 1>& rhs, + RegBlockInt32<Rows, Cols>* acc) { + const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); + for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) { + acc->buf.reg[i] = Add(acc->buf.reg[i], p); + } + } +}; + +// 1x1 += 1x1 * 1x1 +template <> +struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, + RegBlockInt32<1, 1>> { + static void Run(const RegBlockInt32<1, 1>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) { + MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]); + } +}; + +// Rx4 += Rx1 * 1x4 +template <int Rows> +struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>, + RegBlockInt32<Rows, 4>> { + static void Run(const RegBlockInt32<Rows, 1>& lhs, + const RegBlockInt32<1, 4>& 
rhs, RegBlockInt32<Rows, 4>* acc) { + const Int32x4 p = rhs.buf.reg[0]; + static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; + for (int i = 0; i < kRegsPerCol; i++) { + MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]); + MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]); + MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]); + MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]); + } + } +}; + +// Rx4 += 1x4 * 1x1 +template <int Rows> +struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, + RegBlockInt32<Rows, 4>> { + static void Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) { + const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); + Int32x4 q[4]; + q[0] = DupLane<0>(p); + q[1] = DupLane<1>(p); + q[2] = DupLane<2>(p); + q[3] = DupLane<3>(p); + static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; + for (int i = 0; i < kRegsPerCol; i++) { + for (int j = 0; j < 4; j++) { + acc->buf.reg[i + j * kRegsPerCol] = + Add(q[j], acc->buf.reg[i + j * kRegsPerCol]); + } + } + } +}; + +// 1xC += 1x1 * 1x1 +template <int Cols> +struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, + RegBlockInt32<1, Cols>> { + static void Run(const RegBlockInt32<1, 1>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { + const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); + for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { + acc->buf.reg[i] = Add(acc->buf.reg[i], p); + } + } +}; + +// 1x4 += 1x4 * 1x1 +template <> +struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, + RegBlockInt32<1, 4>> { + static void Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) { + const std::int32_t p = rhs.buf.reg[0]; + MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); + } +}; + +// 4xC += 4x1 * 1x1 +template <int Cols> +struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, + RegBlockInt32<4, Cols>> { + static void Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) { + const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); + for (int i = 0; i < Cols; i++) { + acc->buf.reg[i] = Add(p, acc->buf.reg[i]); + } + } +}; + +// 4x1 += 4x1 * 1x1 +template <> +struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, + RegBlockInt32<4, 1>> { + static void Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) { + const std::int32_t p = rhs.buf.reg[0]; + MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_neon.h b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_neon.h new file mode 100644 index 000000000..c992b1597 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_neon.h @@ -0,0 +1,150 @@ +// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// simd_wrappers_neon.h: NEON specialization of simd_wrappers.h + +#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_ +#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_ + +#include <arm_neon.h> + +namespace gemmlowp { + +using Int32x4 = int32x4_t; +using Uint8x8 = uint8x8_t; + +template <int ScalarCount> +struct RegisterType<std::int32_t, ScalarCount> { + using Type = + typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type; +}; + +template <int ScalarCount> +struct RegisterType<std::uint8_t, ScalarCount> { + using Type = typename std::conditional< + ScalarCount >= 8, Uint8x8, + typename std::conditional<ScalarCount >= 4, std::uint32_t, + std::uint8_t>::type>::type; +}; + +inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); } + +inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) { + vst1q_s32(dst, value); +} + +template <int Lane> +std::int32_t GetLane(Int32x4 value) { + return vgetq_lane_s32(value, Lane); +} + +template <int Lane> +Int32x4 DupLane(Int32x4 value) { + switch (Lane) { + case 0: + return vdupq_lane_s32(vget_low_s32(value), 0); + case 1: + return vdupq_lane_s32(vget_low_s32(value), 1); + case 2: + return vdupq_lane_s32(vget_high_s32(value), 0); + case 3: + return vdupq_lane_s32(vget_high_s32(value), 1); + default: + static_assert(Lane >= 0 && Lane <= 3, ""); + return vdupq_n_s32(0); + } +} + +inline Int32x4 Mul(Int32x4 a, std::int32_t b) { return vmulq_n_s32(a, b); } + +inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); } + +inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); } + +inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) { + return vqrdmulhq_n_s32(a, b); +} + +template <int Lane> +Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) { + switch (Lane) { + case 0: + return vmulq_lane_s32(a, vget_low_s32(b), 0); + case 1: + return vmulq_lane_s32(a, vget_low_s32(b), 1); + case 2: + return vmulq_lane_s32(a, vget_high_s32(b), 0); + case 3: + return vmulq_lane_s32(a, vget_high_s32(b), 1); + default: + static_assert(Lane >= 0 && Lane <= 3, ""); + return vdupq_n_s32(0); + } +} + +inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) { + *acc = vmlaq_s32(*acc, lhs, rhs); +} + +inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) { + *acc = vmlaq_n_s32(*acc, lhs, rhs); +} + +template <int Lane> +inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) { + switch (Lane) { + case 0: + *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 0); + break; + case 1: + *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 1); + break; + case 2: + *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 0); + break; + case 3: + *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 1); + break; + default: + static_assert(Lane >= 0 && Lane <= 3, ""); + } +} + +template <> +struct LoadContiguousImpl<RegBlockUint8<8, 8>> { + static RegBlockUint8<8, 8> Run(const std::uint8_t* src) { + RegBlockUint8<8, 8> result; + for (int i = 0; i < 8; i++) { + result.buf.reg[i] = vld1_u8(src + 8 * i); + } + return result; + } +}; + +template <> +struct 
LoadContiguousImpl<RegBlockInt32<8, 8>> { + static RegBlockInt32<8, 8> Run(const std::int32_t* src) { + RegBlockInt32<8, 8> result; + for (int i = 0; i < 16; i++) { + result.buf.reg[i] = vld1q_s32(src + 4 * i); + } + return result; + } +}; + +} // end namespace gemmlowp + +#include "simd_wrappers_common_neon_sse.h" + +#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_sse.h b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_sse.h new file mode 100644 index 000000000..6480b6690 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers_sse.h @@ -0,0 +1,123 @@ +// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// simd_wrappers_neon.h: SSE SIMD wrappers + +#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_SSE_H_ +#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_SSE_H_ + +#include <smmintrin.h> + +namespace gemmlowp { + +using Int32x4 = __m128i; +using Uint8x16 = __m128i; + +template <int ScalarCount> +struct RegisterType<std::int32_t, ScalarCount> { + using Type = + typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type; +}; + +template <int ScalarCount> +struct RegisterType<std::uint8_t, ScalarCount> { + using Type = typename std::conditional< + ScalarCount >= 16, Uint8x16, + typename std::conditional<ScalarCount >= 4, std::uint32_t, + std::uint8_t>::type>::type; +}; + +inline Int32x4 LoadInt32x4(const std::int32_t* src) { + return _mm_loadu_si128(reinterpret_cast<const Int32x4*>(src)); +} + +inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value); +} + +inline Uint8x16 LoadUint8x16(const std::uint8_t* src) { + return _mm_loadu_si128(reinterpret_cast<const Uint8x16*>(src)); +} + +inline void StoreUint8x16(std::uint8_t* dst, Uint8x16 value) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value); +} + +template <int Lane> +std::int32_t GetLane(Int32x4 value) { + return _mm_extract_epi32(value, Lane); +} + +template <int Lane> +Int32x4 DupLane(Int32x4 value) { + return _mm_shuffle_epi32(value, _MM_SHUFFLE(Lane, Lane, Lane, Lane)); +} + +inline Int32x4 Mul(Int32x4 a, std::int32_t b) { + return Mul(a, Dup<Int32x4>(b)); +} + +inline Int32x4 Min(Int32x4 a, Int32x4 b) { return _mm_min_epi32(a, b); } + +inline Int32x4 Max(Int32x4 a, Int32x4 b) { return _mm_max_epi32(a, b); } + +inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) { + return SaturatingRoundingDoublingHighMul(a, Dup<Int32x4>(b)); +} + +template <int Lane> +Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) { + return Mul(a, DupLane<Lane>(b)); +} + +inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) { + *acc = Add(*acc, Mul(lhs, rhs)); +} + +inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) { + *acc = Add(*acc, Mul(lhs, rhs)); +} + +template <int Lane> +inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) { + *acc = Add(*acc, 
MulByRhsLane<Lane>(lhs, rhs)); +} + +template <> +struct LoadContiguousImpl<RegBlockUint8<8, 8>> { + static RegBlockUint8<8, 8> Run(const std::uint8_t* src) { + RegBlockUint8<8, 8> result; + for (int i = 0; i < 4; i++) { + result.buf.reg[i] = LoadUint8x16(src + 16 * i); + } + return result; + } +}; + +template <> +struct LoadContiguousImpl<RegBlockInt32<8, 8>> { + static RegBlockInt32<8, 8> Run(const std::int32_t* src) { + RegBlockInt32<8, 8> result; + for (int i = 0; i < 16; i++) { + result.buf.reg[i] = LoadInt32x4(src + 4 * i); + } + return result; + } +}; + +} // end namespace gemmlowp + +#include "simd_wrappers_common_neon_sse.h" + +#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_SSE_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/single_thread_gemm.h b/runtimes/nn/depend/external/gemmlowp/internal/single_thread_gemm.h new file mode 100644 index 000000000..3d430c5d4 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/single_thread_gemm.h @@ -0,0 +1,158 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// single_thread_gemm.h: Single-threaded GEMM implementation. +// This is a good place to start reading code, as it shows the overall +// structure of a GEMM and is much simpler than multi_thread_gemm.h. + +#ifndef GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_ +#define GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_ + +#include <cassert> + +#include "../public/map.h" +#include "allocator.h" +#include "compute.h" +#include "kernel.h" +#include "pack.h" +#include "unpack.h" + +#ifdef GEMMLOWP_PROFILING_SIZES +#ifndef GEMMLOWP_PROFILING +#error GEMMLOWP_PROFILING_SIZES without GEMMLOWP_PROFILING +#endif +#include <string> +#include <unordered_map> +#endif + +namespace gemmlowp { + +class SingleThreadGemmContext { + public: + Allocator* allocator() { return &allocator_; } + + void set_l1_bytes_to_use(int n) { l1_bytes_to_use_ = n; } + void set_l2_bytes_to_use(int n) { l2_bytes_to_use_ = n; } + void set_l2_rhs_factor(float n) { l2_rhs_factor_ = n; } + + int l1_bytes_to_use() const { return l1_bytes_to_use_; } + int l2_bytes_to_use() const { return l2_bytes_to_use_; } + float l2_rhs_factor() const { return l2_rhs_factor_; } + + protected: + Allocator allocator_; + + // The cache configurationt to use. 
+ int l1_bytes_to_use_ = kDefaultL1CacheSize; + int l2_bytes_to_use_ = kDefaultL2CacheSize; + float l2_rhs_factor_ = kDefaultL2RhsFactor; +}; + +template <typename KernelFormat, typename InputScalar, typename OutputScalar, + typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder, + MapOrder ResultOrder, typename LhsOffset, typename RhsOffset, + typename OutputPipelineType> +void SingleThreadGemm(SingleThreadGemmContext* context, + const KernelBase& kernel, + const MatrixMap<const InputScalar, LhsOrder>& lhs, + const MatrixMap<const InputScalar, RhsOrder>& rhs, + MatrixMap<OutputScalar, ResultOrder>* result, + const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, + const OutputPipelineType& output_pipeline) { + ScopedProfilingLabel label("gemmlowp::SingleThreadGemm"); + + assert(lhs.cols() == rhs.rows()); + + int rows = result->rows(); + int cols = result->cols(); + int depth = lhs.cols(); + + // zero sizes should have been caught earlier and early-returned. + assert(rows > 0); + assert(cols > 0); + assert(depth > 0); + + // The case of rows<cols should have been caught earlier and transposed. + assert(rows >= cols); + + Allocator* allocator = context->allocator(); + + BlockParams block_params; + block_params.Init<KernelFormat>(rows, cols, depth, 1, + context->l1_bytes_to_use(), + context->l2_bytes_to_use(), + context->l2_rhs_factor()); + +#ifdef GEMMLOWP_PROFILING_SIZES + // Using a static map of label strings. Not reentrant at all! + static std::unordered_map<std::uint64_t, std::string> labels_map; + std::uint64_t sizes_hash = static_cast<std::uint64_t>(rows) ^ + (static_cast<std::uint64_t>(depth) << 16) ^ + (static_cast<std::uint64_t>(cols) << 32); + if (!labels_map.count(sizes_hash)) { + char label[256]; + snprintf(label, sizeof(label), + "(rows = %d, depth = %d, cols = %d, l2_rows = %d, l2_depth = %d, " + "l2_cols = %d, l1_rows = %d, l1_depth = %d, l1_cols = %d)", + rows, depth, cols, block_params.l2_rows, block_params.l2_depth, + block_params.l2_cols, block_params.l1_rows, block_params.l1_depth, + block_params.l1_cols); + labels_map[sizes_hash] = label; + } + ScopedProfilingLabel size_label(labels_map[sizes_hash].c_str()); +#endif + + PackedSideBlock<typename KernelFormat::Lhs> packed_lhs(Side::Lhs, allocator, + block_params); + PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(Side::Rhs, allocator, + block_params); + + PackedResult packed_result(allocator, block_params); + + allocator->Commit(); + + const bool pack_rhs_once = block_params.l2_cols >= cols; + + if (pack_rhs_once) { + PackRhs(&packed_rhs, rhs); + } + + for (int r = 0; r < rows; r += block_params.l2_rows) { + int rs = std::min(block_params.l2_rows, rows - r); + + PackLhs(&packed_lhs, lhs.block(r, 0, rs, depth)); + + for (int c = 0; c < cols; c += block_params.l2_cols) { + int cs = std::min(block_params.l2_cols, cols - c); + + if (!pack_rhs_once) { + PackRhs(&packed_rhs, rhs.block(0, c, depth, cs)); + } + + Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs, + depth); + + UnpackResult<KernelFormat>( + result, MatrixBlockBounds(r, c, rs, cs), packed_result, depth, + packed_lhs.sums_of_each_slice(), packed_rhs.sums_of_each_slice(), + lhs_offset.block(r, rs), rhs_offset.block(c, cs), output_pipeline); + } + } + + allocator->Decommit(); +} + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/internal/unpack.h b/runtimes/nn/depend/external/gemmlowp/internal/unpack.h new file mode 100644 index 000000000..33aee13b8 --- 
/dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/unpack.h @@ -0,0 +1,278 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// unpack.h: unpacking the result blocks computed by compute.h, +// storing them into the destination matrix. + +#ifndef GEMMLOWP_INTERNAL_UNPACK_H_ +#define GEMMLOWP_INTERNAL_UNPACK_H_ + +#include "allocator.h" +#include "block_params.h" +#include "output.h" +#include "pack.h" + +#include <cmath> + +namespace gemmlowp { + +class PackedResult { + public: + PackedResult(Allocator* _allocator, const BlockParams& _block_params) + : allocator_(_allocator), block_params_(_block_params) { + matrix_handle_ = allocator_->Reserve<std::int32_t>(block_params_.l2_rows * + block_params_.l2_cols); + } + + ~PackedResult() {} + + MatrixMap<std::int32_t, MapOrder::ColMajor> Map() { + return MatrixMap<std::int32_t, MapOrder::ColMajor>( + allocator_->GetPointer<std::int32_t>(matrix_handle_), + block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows); + } + + MatrixMap<const std::int32_t, MapOrder::ColMajor> Map() const { + return MatrixMap<const std::int32_t, MapOrder::ColMajor>( + allocator_->GetPointer<const std::int32_t>(matrix_handle_), + block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows); + } + + private: + Allocator* allocator_; + Allocator::Handle matrix_handle_; + const BlockParams& block_params_; +}; + +struct MatrixBlockBounds { + int start_row; + int start_col; + int rows; + int cols; + + MatrixBlockBounds(int start_row_, int start_col_, int rows_, int cols_) + : start_row(start_row_), + start_col(start_col_), + rows(rows_), + cols(cols_) {} +}; + +template <int Rows, int Cols, typename SrcMapType> +void PrefetchResultBlock(const SrcMapType& src, + const VectorMap<const std::int32_t, VectorShape::Col>& + lhs_sums_of_each_slice, + int src_row, int src_col) { + const std::int32_t* src_data = src.data(src_row, src_col); + const int src_stride = src.stride(); + const std::int32_t* lhs_sums_data = lhs_sums_of_each_slice.data(src_row); + for (int r = 0; r < Rows; r += 4) { + Prefetch(lhs_sums_data + r); + } + for (int c = 0; c < Cols; c++) { + for (int r = 0; r < Rows; r += 4) { + Prefetch(src_data + r + c * src_stride); + } + } +} + +template <typename KernelFormat, typename RegisterBlockType, + typename SrcMapType, typename LhsOffset, typename RhsOffset, + typename OutputPipelineExecutorType, typename DstType> +void UnpackResultBlock(const SrcMapType& src, + const OutputPipelineExecutorType& executor, DstType* dst, + const VectorMap<const std::int32_t, VectorShape::Col>& + lhs_sums_of_each_slice, + const VectorMap<const std::int32_t, VectorShape::Row>& + rhs_sums_of_each_slice, + const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, + int depth, int src_row, int src_col, int src_global_row, + int src_global_col, int dst_row, int dst_col) { + using KernelLhsScalar = typename KernelFormat::Lhs::Scalar; + using KernelRhsScalar = typename KernelFormat::Rhs::Scalar; 
+ static constexpr int KernelLhsZeroPointInput = + ZeroPointInputValue<KernelLhsScalar>::kValue; + static constexpr int KernelRhsZeroPointInput = + ZeroPointInputValue<KernelRhsScalar>::kValue; + auto acc = Load<RegisterBlockType>(src, src_row, src_col); + const auto& lhs_sums_of_each_slice_block = + LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row); + const auto& rhs_sums_of_each_slice_block = + LoadForBroadcasting<RegisterBlockType>(rhs_sums_of_each_slice, src_col); + auto lhs_offset_block = + LoadForBroadcasting<RegisterBlockType>(lhs_offset, src_row); + auto rhs_offset_block = + LoadForBroadcasting<RegisterBlockType>(rhs_offset, src_col); + AddConstant<KernelLhsZeroPointInput>(&lhs_offset_block); + AddConstant<KernelRhsZeroPointInput>(&rhs_offset_block); + BroadcastMulAdd(lhs_sums_of_each_slice_block, rhs_offset_block, &acc); + for (int i = 0; i < decltype(rhs_offset_block)::kRegisterCount; i++) { + rhs_offset_block.buf.reg[i] = Mul(rhs_offset_block.buf.reg[i], depth); + } + BroadcastMulAdd(BroadcastAdd(rhs_sums_of_each_slice_block, rhs_offset_block), + lhs_offset_block, &acc); + executor.Execute(acc, dst, src_global_row, src_global_col, dst_row, dst_col); +} + +template <typename KernelFormat, typename ResultBlockType, + typename PackedResultType, typename LhsOffset, typename RhsOffset, + typename OutputPipelineType> +void UnpackResult(ResultBlockType* dst, const MatrixBlockBounds& dst_block, + const PackedResultType& src, int depth, + const std::int32_t* lhs_sums_of_each_slice_ptr, + const std::int32_t* rhs_sums_of_each_slice_ptr, + const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, + const OutputPipelineType& output_pipeline) { + ScopedProfilingLabel label(ResultBlockType::kOrder == MapOrder::ColMajor + ? "unpack to column-major" + : "unpack to row-major"); + assert(dst_block.start_row >= 0); + assert(dst_block.start_row + dst_block.rows <= dst->rows()); + assert(dst_block.start_col >= 0); + assert(dst_block.start_col + dst_block.cols <= dst->cols()); + const auto src_map = src.Map(); + const VectorMap<const std::int32_t, VectorShape::Col> lhs_sums_of_each_slice( + lhs_sums_of_each_slice_ptr, dst_block.rows); + const VectorMap<const std::int32_t, VectorShape::Row> rhs_sums_of_each_slice( + rhs_sums_of_each_slice_ptr, dst_block.cols); + using Int32x1x1 = RegisterBlock<std::int32_t, 1, 1>; + using Int32x4x1 = RegisterBlock<std::int32_t, 4, 1>; + using Int32x8x1 = RegisterBlock<std::int32_t, 8, 1>; + using Int32x1x4 = RegisterBlock<std::int32_t, 1, 4>; + using Int32x4x4 = RegisterBlock<std::int32_t, 4, 4>; + using Int32x8x4 = RegisterBlock<std::int32_t, 8, 4>; + + using DstScalarType = typename ResultBlockType::Scalar; + using DstScalarx8x8 = RegisterBlock<DstScalarType, 8, 8>; + + OutputPipelineExecutor<OutputPipelineType, Int32x1x1> + output_pipeline_executor_1x1(output_pipeline); + OutputPipelineExecutor<OutputPipelineType, Int32x4x1> + output_pipeline_executor_4x1(output_pipeline); + OutputPipelineExecutor<OutputPipelineType, Int32x8x1> + output_pipeline_executor_8x1(output_pipeline); + OutputPipelineExecutor<OutputPipelineType, Int32x1x4> + output_pipeline_executor_1x4(output_pipeline); + OutputPipelineExecutor<OutputPipelineType, Int32x4x4> + output_pipeline_executor_4x4(output_pipeline); + OutputPipelineExecutor<OutputPipelineType, Int32x8x4> + output_pipeline_executor_8x4(output_pipeline); + + int c8 = 0; + if (ResultBlockType::kOrder == MapOrder::RowMajor) { + for (; c8 <= dst_block.cols - 8; c8 += 8) { + PrefetchResultBlock<8, 8>(src_map, 
lhs_sums_of_each_slice, 0, c8); + int r = 0; + for (; r <= dst_block.rows - 8; r += 8) { + const int global_row = r + dst_block.start_row; + PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, r + 8, c8); + DstScalarType dst_colmajor_buf[64]; + MatrixMap<DstScalarType, MapOrder::ColMajor> dst_colmajor_map( + dst_colmajor_buf, 8, 8); + for (int cx = 0; cx < 8; cx += 4) { + const int c = c8 + cx; + const int global_col = c + dst_block.start_col; + UnpackResultBlock<KernelFormat, Int32x8x4>( + src_map, output_pipeline_executor_8x4, &dst_colmajor_map, + lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset, + rhs_offset, depth, r, c, global_row, global_col, 0, cx); + } + StoreFinalOutput(LoadContiguous<DstScalarx8x8>(dst_colmajor_buf), dst, + r + dst_block.start_row, c8 + dst_block.start_col); + } + for (; r <= dst_block.rows - 4; r += 4) { + const int global_row = r + dst_block.start_row; + for (int cx = 0; cx < 8; cx += 4) { + const int c = c8 + cx; + const int global_col = c + dst_block.start_col; + UnpackResultBlock<KernelFormat, Int32x4x4>( + src_map, output_pipeline_executor_4x4, dst, + lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset, + rhs_offset, depth, r, c, global_row, global_col, global_row, + global_col); + } + } + for (; r < dst_block.rows; r++) { + const int global_row = r + dst_block.start_row; + for (int cx = 0; cx < 8; cx += 4) { + const int c = c8 + cx; + const int global_col = c + dst_block.start_col; + UnpackResultBlock<KernelFormat, Int32x1x4>( + src_map, output_pipeline_executor_1x4, dst, + lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset, + rhs_offset, depth, r, c, global_row, global_col, global_row, + global_col); + } + } + } + } + int c = c8; + for (; c <= dst_block.cols - 4; c += 4) { + const int global_col = c + dst_block.start_col; + PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, 0, c); + int r = 0; + for (; r <= dst_block.rows - 8; r += 8) { + const int global_row = r + dst_block.start_row; + PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, r + 8, c); + UnpackResultBlock<KernelFormat, Int32x8x4>( + src_map, output_pipeline_executor_8x4, dst, lhs_sums_of_each_slice, + rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c, + global_row, global_col, global_row, global_col); + } + for (; r <= dst_block.rows - 4; r += 4) { + const int global_row = r + dst_block.start_row; + UnpackResultBlock<KernelFormat, Int32x4x4>( + src_map, output_pipeline_executor_4x4, dst, lhs_sums_of_each_slice, + rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c, + global_row, global_col, global_row, global_col); + } + for (; r < dst_block.rows; r++) { + const int global_row = r + dst_block.start_row; + UnpackResultBlock<KernelFormat, Int32x1x4>( + src_map, output_pipeline_executor_1x4, dst, lhs_sums_of_each_slice, + rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c, + global_row, global_col, global_row, global_col); + } + } + for (; c < dst_block.cols; c++) { + const int global_col = c + dst_block.start_col; + PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, 0, c); + int r = 0; + for (; r <= dst_block.rows - 8; r += 8) { + const int global_row = r + dst_block.start_row; + PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, r + 8, c); + UnpackResultBlock<KernelFormat, Int32x8x1>( + src_map, output_pipeline_executor_8x1, dst, lhs_sums_of_each_slice, + rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c, + global_row, global_col, global_row, global_col); + } + for (; r <= dst_block.rows - 4; r 
+= 4) { + const int global_row = r + dst_block.start_row; + UnpackResultBlock<KernelFormat, Int32x4x1>( + src_map, output_pipeline_executor_4x1, dst, lhs_sums_of_each_slice, + rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c, + global_row, global_col, global_row, global_col); + } + for (; r < dst_block.rows; r++) { + const int global_row = r + dst_block.start_row; + UnpackResultBlock<KernelFormat, Int32x1x1>( + src_map, output_pipeline_executor_1x1, dst, lhs_sums_of_each_slice, + rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c, + global_row, global_col, global_row, global_col); + } + } +} + +} // end namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_UNPACK_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/profiling/instrumentation.h b/runtimes/nn/depend/external/gemmlowp/profiling/instrumentation.h new file mode 100644 index 000000000..51b652590 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/profiling/instrumentation.h @@ -0,0 +1,244 @@ +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// instrumentation.h: contains the definitions needed to +// instrument code for profiling: +// ScopedProfilingLabel, RegisterCurrentThreadForProfiling. +// +// profiler.h is only needed to drive the profiler: +// StartProfiling, FinishProfiling. +// +// See the usage example in profiler.h. + +#ifndef GEMMLOWP_PROFILING_INSTRUMENTATION_H_ +#define GEMMLOWP_PROFILING_INSTRUMENTATION_H_ + +#include <pthread.h> +#include <cstdio> + +#ifndef GEMMLOWP_USE_STLPORT +#include <cstdint> +#else +#include <stdint.h> +namespace std { +using ::uint8_t; +using ::uint16_t; +using ::uint32_t; +using ::int8_t; +using ::int16_t; +using ::int32_t; +using ::size_t; +using ::uintptr_t; +} +#endif + +#include <algorithm> +#include <cassert> +#include <cstdlib> + +#ifdef GEMMLOWP_PROFILING +#include <cstring> +#include <set> +#endif + +// We should always use C++11 thread_local; unfortunately that +// isn't fully supported on Apple yet. +#ifdef __APPLE__ +#define GEMMLOWP_THREAD_LOCAL static __thread +#define GEMMLOWP_USING_OLD_THREAD_LOCAL +#else +#define GEMMLOWP_THREAD_LOCAL thread_local +#endif + +namespace gemmlowp { + +inline void ReleaseBuildAssertion(bool condition, const char* msg) { + if (!condition) { + fprintf(stderr, "gemmlowp error: %s\n", msg); + abort(); + } +} + +// To be used as template parameter for GlobalLock. +// GlobalLock<ProfilerLockId> is the profiler global lock: +// registering threads, starting profiling, finishing profiling, and +// the profiler itself as it samples threads, all need to lock it. +struct ProfilerLockId; + +// A very plain global lock. Templated in LockId so we can have multiple +// locks, one for each LockId type. 
+template <typename LockId> +class GlobalLock { + static pthread_mutex_t* Mutex() { + static pthread_mutex_t m = PTHREAD_MUTEX_INITIALIZER; + return &m; + } + + public: + static void Lock() { pthread_mutex_lock(Mutex()); } + static void Unlock() { pthread_mutex_unlock(Mutex()); } +}; + +// A very simple RAII helper to lock and unlock a GlobalLock +template <typename LockId> +struct AutoGlobalLock { + AutoGlobalLock() { GlobalLock<LockId>::Lock(); } + ~AutoGlobalLock() { GlobalLock<LockId>::Unlock(); } +}; + +// MemoryBarrier is purely a compile-time thing; it tells two things +// to the compiler: +// 1) It prevents reordering code across it +// (thanks to the 'volatile' after 'asm') +// 2) It requires the compiler to assume that any value previously +// read from memory, may have changed. Thus it offers an alternative +// to using 'volatile' variables. +inline void MemoryBarrier() { asm volatile("" ::: "memory"); } + +// Profiling definitions. Two paths: when profiling is enabled, +// and when profiling is disabled. +#ifdef GEMMLOWP_PROFILING +// This code path is when profiling is enabled. + +// A pseudo-call-stack. Contrary to a real call-stack, this only +// contains pointers to literal strings that were manually entered +// in the instrumented code (see ScopedProfilingLabel). +struct ProfilingStack { + static const std::size_t kMaxSize = 15; + typedef const char* LabelsArrayType[kMaxSize]; + LabelsArrayType labels; + std::size_t size; + + ProfilingStack() { memset(this, 0, sizeof(ProfilingStack)); } + + void Push(const char* label) { + MemoryBarrier(); + ReleaseBuildAssertion(size < kMaxSize, "ProfilingStack overflow"); + labels[size] = label; + MemoryBarrier(); + size++; + MemoryBarrier(); + } + + void Pop() { + MemoryBarrier(); + ReleaseBuildAssertion(size > 0, "ProfilingStack underflow"); + size--; + MemoryBarrier(); + } + + void UpdateTop(const char* new_label) { + MemoryBarrier(); + assert(size); + labels[size - 1] = new_label; + MemoryBarrier(); + } + + ProfilingStack& operator=(const ProfilingStack& other) { + memcpy(this, &other, sizeof(ProfilingStack)); + return *this; + } + + bool operator==(const ProfilingStack& other) const { + return !memcmp(this, &other, sizeof(ProfilingStack)); + } +}; + +static_assert( + !(sizeof(ProfilingStack) & (sizeof(ProfilingStack) - 1)), + "ProfilingStack should have power-of-two size to fit in cache lines"); + +struct ThreadInfo; + +// The global set of threads being profiled. +inline std::set<ThreadInfo*>& ThreadsUnderProfiling() { + static std::set<ThreadInfo*> v; + return v; +} + +struct ThreadInfo { + pthread_key_t key; // used only to get a callback at thread exit. + ProfilingStack stack; + + ThreadInfo() { + pthread_key_create(&key, ThreadExitCallback); + pthread_setspecific(key, this); + } + + static void ThreadExitCallback(void* ptr) { + AutoGlobalLock<ProfilerLockId> lock; + ThreadInfo* self = static_cast<ThreadInfo*>(ptr); + ThreadsUnderProfiling().erase(self); + pthread_key_delete(self->key); + } +}; + +inline ThreadInfo& ThreadLocalThreadInfo() { +#ifdef GEMMLOWP_USING_OLD_THREAD_LOCAL + // We're leaking this ThreadInfo structure, because Apple doesn't support + // non-trivial constructors or destructors for their __thread type modifier. + GEMMLOWP_THREAD_LOCAL ThreadInfo* i = nullptr; + if (i == nullptr) { + i = new ThreadInfo(); + } + return *i; +#else + GEMMLOWP_THREAD_LOCAL ThreadInfo i; + return i; +#endif +} + +// ScopedProfilingLabel is how one instruments code for profiling +// with this profiler. 
Construct local ScopedProfilingLabel variables, +// passing a literal string describing the local code. Profile +// samples will then be annotated with this label, while it is in scope +// (whence the name --- also known as RAII). +// See the example in profiler.h. +class ScopedProfilingLabel { + ProfilingStack* profiling_stack_; + + public: + explicit ScopedProfilingLabel(const char* label) + : profiling_stack_(&ThreadLocalThreadInfo().stack) { + profiling_stack_->Push(label); + } + + ~ScopedProfilingLabel() { profiling_stack_->Pop(); } + + void Update(const char* new_label) { profiling_stack_->UpdateTop(new_label); } +}; + +// To be called once on each thread to be profiled. +inline void RegisterCurrentThreadForProfiling() { + AutoGlobalLock<ProfilerLockId> lock; + ThreadsUnderProfiling().insert(&ThreadLocalThreadInfo()); +} + +#else // not GEMMLOWP_PROFILING +// This code path is when profiling is disabled. + +// This empty definition of ScopedProfilingLabel ensures that +// it has zero runtime overhead when profiling is disabled. +struct ScopedProfilingLabel { + explicit ScopedProfilingLabel(const char*) {} + void Update(const char*) {} +}; + +inline void RegisterCurrentThreadForProfiling() {} + +#endif + +} // end namespace gemmlowp + +#endif // GEMMLOWP_PROFILING_INSTRUMENTATION_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/profiling/profiler.h b/runtimes/nn/depend/external/gemmlowp/profiling/profiler.h new file mode 100644 index 000000000..a18c036c8 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/profiling/profiler.h @@ -0,0 +1,373 @@ +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// profiler.h: a simple sampling profiler that's always just one #include away! +// +// Overview +// ======== +// +// This profiler only samples a pseudo-stack, not the actual call stack. +// The code to be profiled needs to be instrumented with +// pseudo-stack "labels", see ScopedProfilingLabel. +// Using pseudo-stacks allows this profiler to be very simple, low-overhead, +// portable, and independent of compilation details such as function inlining +// and frame pointers. The granularity of instrumentation can be freely chosen, +// and it is possible to get some annotate-like detail, i.e. detail within one +// function without splitting it into multiple functions. +// +// This profiler should remain small and simple; its key feature is to fit in +// a single header file so that there should never be a reason to refrain +// from profiling. More complex and feature-rich alternatives are +// readily available. This one offers a strict superset of its +// functionality: https://github.com/bgirard/GeckoProfiler, including +// intertwining pseudostacks with real call stacks, more annotation options, +// and advanced visualization. +// +// Usage +// ===== +// +// 0. Enable profiling by defining GEMMLOWP_PROFILING. 
When profiling is +// not enabled, profiling instrumentation from instrumentation.h +// (ScopedProfilingLabel, RegisterCurrentThreadForProfiling) +// is still defined but does nothing. On the other hand, +// when profiling is not enabled, it is an error to #include the +// present file. +// +// 1. Each thread can opt in to profiling by calling +// RegisterCurrentThreadForProfiling() defined in instrumentation.h. +// This can be done at any time, before or during profiling. +// No sample will be collected from a thread until +// it has called RegisterCurrentThreadForProfiling(). +// +// 2. Instrument your code to be profiled with ScopedProfilingLabel, +// which is a RAII helper defined in instrumentation.h. The identifier +// names (some_label, etc) do not matter; what will show up +// in the profile is the string passed to the constructor, which +// must be a literal string. See the full example below. +// +// Note: the overhead of ScopedProfilingLabel is zero when not +// enabling profiling (when not defining GEMMLOWP_PROFILING). +// +// 3. Use the profiler.h interface to control profiling. There are two +// functions: StartProfiling() and FinishProfiling(). They must be +// called on the same thread. FinishProfiling() prints the profile +// on stdout. +// +// Full example +// ============ +/* + #define GEMMLOWP_PROFILING + #include "profiling/instrumentation.h" + using namespace gemmlowp; + + const int iters = 100000000; + volatile int i; + + void Bar() { + ScopedProfilingLabel label("Bar"); + for (i = 0; i < iters; i++) {} + } + + void Foo() { + ScopedProfilingLabel label("Foo"); + for (i = 0; i < iters; i++) {} + Bar(); + } + + void Init() { + RegisterCurrentThreadForProfiling(); + } + + #include "profiling/profiler.h" + + int main() { + Init(); + StartProfiling(); + Foo(); + FinishProfiling(); + } +* +* Output: +* + gemmlowp profile (1 threads, 304 samples) + 100.00% Foo + 51.32% other + 48.68% Bar + 0.00% other (outside of any label) +*/ +// +// Interpreting results +// ==================== +// +// Each node shows the absolute percentage, among all the samples, +// of the number of samples that recorded the given pseudo-stack. +// The percentages are *NOT* relative to the parent node. In addition +// to your own labels, you will also see 'other' nodes that collect +// the remainder of samples under the parent node that didn't fall into +// any of the labelled child nodes. Example: +// +// 20% Foo +// 12% Bar +// 6% Xyz +// 2% other +// +// This means that 20% of all labels were under Foo, of which 12%/20%==60% +// were under Bar, 6%/20%==30% were under Xyz, and 2%/20%==10% were not +// under either Bar or Xyz. +// +// Typically, one wants to keep adding ScopedProfilingLabel's until +// the 'other' nodes show low percentages. +// +// Interpreting results with multiple threads +// ========================================== +// +// At each sample, each thread registered for profiling gets sampled once. +// So if there is one "main thread" spending its time in MainFunc() and +// 4 "worker threads" spending time in WorkerFunc(), then 80% (=4/5) of the +// samples will be in WorkerFunc, so the profile will look like this: +// +// 80% WorkerFunc +// 20% MainFunc + +#ifndef GEMMLOWP_PROFILING_PROFILER_H_ +#define GEMMLOWP_PROFILING_PROFILER_H_ + +#ifndef GEMMLOWP_PROFILING +#error Profiling is not enabled! +#endif + +#include <vector> + +#include "instrumentation.h" + +namespace gemmlowp { + +// A tree view of a profile. 
+class ProfileTreeView { + struct Node { + std::vector<Node*> children; + const char* label; + std::size_t weight; + Node() : label(nullptr), weight(0) {} + ~Node() { + for (auto child : children) { + delete child; + } + } + }; + + static bool CompareNodes(Node* n1, Node* n2) { + return n1->weight > n2->weight; + } + + Node root_; + + void PrintNode(const Node* node, int level) const { + if (level) { + for (int i = 1; i < level; i++) { + printf(" "); + } + printf("%.2f%% %s\n", 100.0f * node->weight / root_.weight, node->label); + } + for (auto child : node->children) { + PrintNode(child, level + 1); + } + } + + static void AddStackToNode(const ProfilingStack& stack, Node* node, + std::size_t level) { + node->weight++; + if (stack.size == level) { + return; + } + Node* child_to_add_to = nullptr; + for (auto child : node->children) { + if (child->label == stack.labels[level]) { + child_to_add_to = child; + break; + } + } + if (!child_to_add_to) { + child_to_add_to = new Node; + child_to_add_to->label = stack.labels[level]; + node->children.push_back(child_to_add_to); + } + AddStackToNode(stack, child_to_add_to, level + 1); + return; + } + + void AddStack(const ProfilingStack& stack) { + AddStackToNode(stack, &root_, 0); + } + + void AddOtherChildrenToNode(Node* node) { + std::size_t top_level_children_weight = 0; + for (auto c : node->children) { + AddOtherChildrenToNode(c); + top_level_children_weight += c->weight; + } + if (top_level_children_weight) { + Node* other_child = new Node; + other_child->label = + node == &root_ ? "other (outside of any label)" : "other"; + other_child->weight = node->weight - top_level_children_weight; + node->children.push_back(other_child); + } + } + + void AddOtherNodes() { AddOtherChildrenToNode(&root_); } + + void SortNode(Node* node) { + std::sort(node->children.begin(), node->children.end(), CompareNodes); + for (auto child : node->children) { + SortNode(child); + } + } + + void Sort() { SortNode(&root_); } + + public: + explicit ProfileTreeView(const std::vector<ProfilingStack>& stacks) { + for (auto stack : stacks) { + AddStack(stack); + } + AddOtherNodes(); + Sort(); + } + + void Print() const { + printf("\n"); + printf("gemmlowp profile (%d threads, %d samples)\n", + static_cast<int>(ThreadsUnderProfiling().size()), + static_cast<int>(root_.weight)); + PrintNode(&root_, 0); + printf("\n"); + } +}; + +// This function is the only place that determines our sampling frequency. +inline void WaitOneProfilerTick() { + static const int millisecond = 1000000; + +#if defined __arm__ || defined __aarch64__ + // Reduced sampling frequency on mobile devices helps limit time and memory + // overhead there. + static const int interval = 10 * millisecond; +#else + static const int interval = 1 * millisecond; +#endif + + timespec ts; + ts.tv_sec = 0; + ts.tv_nsec = interval; + nanosleep(&ts, nullptr); +} + +// This is how we track whether we've already started profiling, +// to guard against misuse of the API. +inline bool& IsProfiling() { + static bool b; + return b; +} + +// This is how we tell the profiler thread to finish. +inline bool& ProfilerThreadShouldFinish() { + static bool b; + return b; +} + +// The profiler thread. See ProfilerThreadFunc. +inline pthread_t& ProfilerThread() { + static pthread_t t; + return t; +} + +// Records a stack from a running thread. +// The tricky part is that we're not interrupting the thread. 
+// This is OK because we're looking at a pseudo-stack of labels, +// not at the real thread stack, and if the pseudo-stack changes +// while we're recording it, we are OK with getting either the +// old or the new stack. Note that ProfilingStack::Pop +// only decrements the size, and doesn't null the popped label, +// so if we're concurrently recording it, it shouldn't change +// under our feet until another label is pushed, at which point +// we are OK with getting either this new label or the old one. +// In the end, the key atomicity property that we are relying on +// here is that pointers are changed atomically, and the labels +// are pointers (to literal strings). +inline void RecordStack(const ThreadInfo* thread, ProfilingStack* dst) { + assert(!dst->size); + while (dst->size < thread->stack.size) { + dst->labels[dst->size] = thread->stack.labels[dst->size]; + dst->size++; + MemoryBarrier(); // thread->stack can change at any time + } +} + +// The profiler thread's entry point. +// Note that a separate thread is to be started each time we call +// StartProfiling(), and finishes when we call FinishProfiling(). +// So here we only need to handle the recording and reporting of +// a single profile. +inline void* ProfilerThreadFunc(void*) { + assert(ProfilerThread() == pthread_self()); + + // Since we only handle one profile per profiler thread, the + // profile data (the array of recorded stacks) can be a local variable here. + std::vector<ProfilingStack> stacks; + + while (!ProfilerThreadShouldFinish()) { + WaitOneProfilerTick(); + { + AutoGlobalLock<ProfilerLockId> lock; + for (auto t : ThreadsUnderProfiling()) { + ProfilingStack s; + RecordStack(t, &s); + stacks.push_back(s); + } + } + } + + // Profiling is finished and we now report the results. + ProfileTreeView(stacks).Print(); + + return nullptr; +} + +// Starts recording samples. +inline void StartProfiling() { + AutoGlobalLock<ProfilerLockId> lock; + ReleaseBuildAssertion(!IsProfiling(), "We're already profiling!"); + IsProfiling() = true; + ProfilerThreadShouldFinish() = false; + pthread_create(&ProfilerThread(), nullptr, ProfilerThreadFunc, nullptr); +} + +// Stops recording samples, and prints a profile tree-view on stdout. +inline void FinishProfiling() { + { + AutoGlobalLock<ProfilerLockId> lock; + ReleaseBuildAssertion(IsProfiling(), "We weren't profiling!"); + // The ProfilerThreadShouldFinish() mechanism here is really naive and bad, + // as the scary comments below should make clear. + // Should we use a condition variable? + ProfilerThreadShouldFinish() = true; + } // must release the lock here to avoid deadlock with profiler thread. + pthread_join(ProfilerThread(), nullptr); + IsProfiling() = false; // yikes, this should be guarded by the lock! +} + +} // namespace gemmlowp + +#endif // GEMMLOWP_PROFILING_PROFILER_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/public/bit_depth.h b/runtimes/nn/depend/external/gemmlowp/public/bit_depth.h new file mode 100644 index 000000000..6cb4ecf0d --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/public/bit_depth.h @@ -0,0 +1,62 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// bit_depth.h: defines the settins controlling LHS/RHS bit depth + +#ifndef GEMMLOWP_PUBLIC_BIT_DEPTH_H_ +#define GEMMLOWP_PUBLIC_BIT_DEPTH_H_ + +namespace gemmlowp { + +// The range of allowed values for an operand. +template <int tMinValue, int tMaxValue> +struct OperandRange { + static const int kMinValue = tMinValue; + static const int kMaxValue = tMaxValue; + static_assert(0 <= kMinValue, ""); + static_assert(kMinValue < kMaxValue, ""); + static_assert(kMaxValue <= 255, ""); +}; + +using Uint8Range = OperandRange<0, 255>; +using Uint8RangeExcludingZero = OperandRange<1, 255>; + +template <typename tLhsRange, typename tRhsRange> +struct BitDepthParams { + using LhsRange = tLhsRange; + using RhsRange = tRhsRange; +}; + +// Default: LHS and RHS are 8bit. +using DefaultL8R8BitDepthParams = BitDepthParams<Uint8Range, Uint8Range>; + +// Variant: LHS may not take the value 0. This allows using +// faster kernels using signed arithmetic, see +// NEON_64bit_GEMM_Int8Operands_Int32Accumulators_AccumTwoWithin16Bits +using L8R8WithLhsNonzeroBitDepthParams = + BitDepthParams<Uint8RangeExcludingZero, Uint8Range>; + +// Deprecated: when gemmlowp used to allow requantizing 8bit +// inputs to less-than-8-bit depths, the public setting allowing +// that was DefaultL7R5BitDepthParams. That requantization +// feature has been removed, but as the whole point of that +// requantization was to make less-than-8-bit an internal +// optimization without any impact on the API (other than lowering +// accuracy), we can temporarily support users who were using it +// by mapping it to the default 8bit behavior. +using DefaultL7R5BitDepthParams = DefaultL8R8BitDepthParams; + +} // namespace gemmlowp + +#endif // GEMMLOWP_PUBLIC_BIT_DEPTH_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/public/gemmlowp.h b/runtimes/nn/depend/external/gemmlowp/public/gemmlowp.h new file mode 100644 index 000000000..05b0f4714 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/public/gemmlowp.h @@ -0,0 +1,87 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// gemmlowp.h: the main public interface header of gemmlowp. + +#ifndef GEMMLOWP_PUBLIC_GEMMLOWP_H_ +#define GEMMLOWP_PUBLIC_GEMMLOWP_H_ +#include "../internal/dispatch_gemm_shape.h" +#include "bit_depth.h" +#include "map.h" +#include "output_stages.h" + +namespace gemmlowp { + +class GemmContext : public MultiThreadGemmContext {}; + +// Computes a general matrix product ("GEMM"). +// This is a version that supports per channel quantization. 
+template <typename InputScalar, typename OutputScalar, typename BitDepthParams, + MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder, + typename LhsOffset, typename RhsOffset, typename OutputPipelineType, + typename GemmContextType> +void GemmWithOutputPipelinePC(GemmContextType* context, + const MatrixMap<const InputScalar, LhsOrder>& lhs, + const MatrixMap<const InputScalar, RhsOrder>& rhs, + MatrixMap<OutputScalar, ResultOrder>* result, + const LhsOffset& lhs_offset, + const RhsOffset& rhs_offset, + const OutputPipelineType& output_pipeline) { + DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>( + context, lhs, rhs, result, lhs_offset, rhs_offset, output_pipeline); +} + +// Computes a general matrix product ("GEMM"). +// This is the legacy version that does not support per channel quantization. +// The meaning of the offsets, result_mult_int and result_shift +// parameters is the same as in the standard EightBitIntGemm interface +// (which is also implemented in the eight_bit_int_gemm directory). +template <typename InputScalar, typename OutputScalar, typename BitDepthParams, + MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder, + typename OutputPipelineType, typename GemmContextType> +void GemmWithOutputPipeline(GemmContextType* context, + const MatrixMap<const InputScalar, LhsOrder>& lhs, + const MatrixMap<const InputScalar, RhsOrder>& rhs, + MatrixMap<OutputScalar, ResultOrder>* result, + int lhs_offset, int rhs_offset, + const OutputPipelineType& output_pipeline) { + typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup; + typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup; + const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows()); + const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols()); + DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>( + context, lhs, rhs, result, lhs_offset_vector, rhs_offset_vector, + output_pipeline); +} + +// Computes a general matrix product ("GEMM"). +// The meaning of the offsets, result_mult_int and result_shift +// parameters is the same as in the standard EightBitIntGemm interface +// (which is also implemented in the eight_bit_int_gemm directory). +template <typename Scalar, typename BitDepthParams, MapOrder LhsOrder, + MapOrder RhsOrder, MapOrder ResultOrder, typename GemmContextType> +void Gemm(GemmContextType* context, + const MatrixMap<const Scalar, LhsOrder>& lhs, + const MatrixMap<const Scalar, RhsOrder>& rhs, + MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, + int rhs_offset, int result_offset, int result_mult_int, + int result_shift) { + GemmWithOutputPipeline<Scalar, Scalar, BitDepthParams>( + context, lhs, rhs, result, lhs_offset, rhs_offset, + MakeStandardOutputPipeline(result_offset, result_mult_int, result_shift)); +} + +} // namespace gemmlowp + +#endif // GEMMLOWP_PUBLIC_GEMMLOWP_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/public/map.h b/runtimes/nn/depend/external/gemmlowp/public/map.h new file mode 100644 index 000000000..3073e05f5 --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/public/map.h @@ -0,0 +1,140 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
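The three entry points above differ only in how the operand offsets and the output pipeline are supplied. Below is a minimal sketch of driving the legacy Gemm entry point; all buffer names, dimensions, and quantization parameters are illustrative, the include path is assumed for this vendored tree, and MatrixMap comes from map.h later in this diff.

#include <cstdint>
#include "public/gemmlowp.h"  // assumed include path within this tree

void ExampleGemm(const std::uint8_t* lhs_data, const std::uint8_t* rhs_data,
                 std::uint8_t* result_data, int m, int k, int n) {
  using namespace gemmlowp;
  // MatrixMap only views the caller's buffers; it never owns or copies them.
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(lhs_data, m, k);
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(rhs_data, k, n);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(result_data, m, n);

  GemmContext context;
  const int lhs_offset = -128, rhs_offset = -128;  // illustrative
  const int result_offset = 128, result_mult_int = 1 << 10, result_shift = 16;

  // Scalar and BitDepthParams must be spelled out; the storage orders and
  // the context type are deduced from the arguments.
  Gemm<std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset,
      result_mult_int, result_shift);
}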
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// map.h: a minimalist view-existing-buffer-as-a-matrix class, +// which is how gemmlowp interfaces with external matrix data. + +#ifndef GEMMLOWP_PUBLIC_MAP_H_ +#define GEMMLOWP_PUBLIC_MAP_H_ + +#include "../internal/common.h" + +namespace gemmlowp { + +// The two storage orders allowed to map buffers as matrices: ColMajor +// means column-major, RowMajor means row-major. +enum class MapOrder { ColMajor, RowMajor }; + +// A MatrixMap is a view of an existing buffer as a matrix. It does not own +// the buffer. +template <typename tScalar, MapOrder tOrder> +class MatrixMap { + public: + typedef tScalar Scalar; + static const MapOrder kOrder = tOrder; + + protected: + Scalar* data_; // not owned. + int rows_, cols_, stride_; + + public: + MatrixMap() : data_(nullptr), rows_(0), cols_(0), stride_(0) {} + MatrixMap(Scalar* data, int rows, int cols) + : data_(data), + rows_(rows), + cols_(cols), + stride_(kOrder == MapOrder::ColMajor ? rows : cols) {} + MatrixMap(Scalar* data, int rows, int cols, int stride) + : data_(data), rows_(rows), cols_(cols), stride_(stride) {} + MatrixMap(const MatrixMap& other) + : data_(other.data_), + rows_(other.rows_), + cols_(other.cols_), + stride_(other.stride_) {} + + int rows() const { return rows_; } + int cols() const { return cols_; } + int stride() const { return stride_; } + int rows_stride() const { return kOrder == MapOrder::ColMajor ? 1 : stride_; } + int cols_stride() const { return kOrder == MapOrder::RowMajor ? 1 : stride_; } + Scalar* data() const { return data_; } + Scalar* data(int row, int col) const { + return data_ + row * rows_stride() + col * cols_stride(); + } + Scalar& operator()(int row, int col) const { return *data(row, col); } + + MatrixMap block(int start_row, int start_col, int block_rows, + int block_cols) const { + assert(start_row >= 0); + assert(start_row + block_rows <= rows_); + assert(start_col >= 0); + assert(start_col + block_cols <= cols_); + + return MatrixMap(data(start_row, start_col), block_rows, block_cols, + stride_); + } +}; + +enum class VectorShape { Col, Row }; + +// A VectorMap is a view of an existing buffer as a vector. It does not own +// the buffer. +template <typename tScalar, VectorShape tShape> +class VectorMap { + public: + typedef tScalar Scalar; + static const VectorShape kShape = tShape; + + protected: + Scalar* data_; // not owned. + int size_; + + public: + VectorMap() : data_(nullptr), size_(0) {} + VectorMap(Scalar* data, int size) : data_(data), size_(size) {} + VectorMap(const VectorMap& other) : data_(other.data_), size_(other.size_) {} + + int size() const { return size_; } + Scalar* data() const { return data_; } + Scalar* data(int index) const { return data_ + index; } + Scalar& operator()(int index) const { return *data(index); } + + VectorMap block(int start, int len) const { + assert(start >= 0); + assert(start + len <= size_); + + return VectorMap(data(start), len); + } +}; + +// A VectorDup is a (duplicated value) vector where all components are the same. 
+template <typename tScalar, VectorShape tShape> +class VectorDup { + public: + typedef tScalar Scalar; + static const VectorShape kShape = tShape; + + protected: + Scalar data_; + int size_; + + public: + VectorDup() : data_(0), size_(0) {} + VectorDup(Scalar data, int size) : data_(data), size_(size) {} + VectorDup(const VectorDup& other) : data_(other.data_), size_(other.size_) {} + + int size() const { return size_; } + Scalar& operator()(int) const { return data_; } + + VectorDup block(int start, int len) const { + assert(start >= 0); + assert(start + len <= size_); + + return VectorDup(data_, len); + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_PUBLIC_MAP_H_ diff --git a/runtimes/nn/depend/external/gemmlowp/public/output_stages.h b/runtimes/nn/depend/external/gemmlowp/public/output_stages.h new file mode 100644 index 000000000..23bcdc05f --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/public/output_stages.h @@ -0,0 +1,185 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// output_stages.h: public definitions of the output stages that can +// be assembled into an output pipeline, to control how internal +// 32-bit accumulators are transformed to obtain the final uint8 +// result matrix entries. + +#ifndef GEMMLOWP_PUBLIC_OUTPUT_STAGES_H_ +#define GEMMLOWP_PUBLIC_OUTPUT_STAGES_H_ + +#include <tuple> + +#include "../internal/common.h" + +namespace gemmlowp { + +// This output stage takes int32 values and returns still int32 values, +// but "quantized down" to the uint8 scale; in other words, its output +// is typically what one would then clamp to [0..255] and cast to uint8 +// (see OutputStageSaturatingCastToUint8). +// +// This "quantization down" process depends on 3 parameters, +// result_offset, result_mult_int, result_shift, +// and the result is: +// ((input + result_offset) * result_mult_int + rounding) >> result_shift +// where +// rounding = (result_shift < 1) ? 0 : (1 << (result_shift - 1)); +struct OutputStageQuantizeDownInt32ToUint8Scale { + std::int32_t result_offset; + std::int32_t result_mult_int; + std::int32_t result_shift; +}; + +// This output stage takes int32 values and returns still int32 values, +// but "quantized down" to the uint8 scale; in other words, its output +// is typically what one would then clamp to [0..255] and cast to uint8 +// (see OutputStageSaturatingCastToUint8). +// +// This "quantization down" process depends on 3 parameters, +// result_offset, result_mult_int, result_shift, +// and the result is: +// ((input + result_offset) * result_mult_int + rounding) >> result_shift +// where +// rounding = (result_shift < 1) ? 0 : (1 << (result_shift - 1)); +// +// Difference from OutputStageQuantizeDownInt32ToUint8Scale here is that each +// row or column of the output (depending on tShape) has its own result_offset +// and result_mult_int numbers. 
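// A concrete numerical illustration of the quantize-down formula above (the
// numbers are illustrative, not taken from the upstream header): with
// result_offset = 50, result_mult_int = 3 and result_shift = 8, we get
// rounding = 1 << 7 = 128, so an int32 accumulator value of 100 is mapped to
//   ((100 + 50) * 3 + 128) >> 8 = 578 >> 8 = 2.
// The per-channel variant declared next applies exactly the same arithmetic,
// except that result_offset and result_mult_int are read from a per-row or
// per-column vector rather than being single scalars.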
+template <VectorShape tShape> +struct OutputStageQuantizeDownInt32ToUint8ScalePC { + VectorMap<const std::int32_t, tShape> result_offset; + VectorMap<const std::int32_t, tShape> result_mult_int; + std::int32_t result_shift; +}; + +// This output stage takes int32 values and returns still int32 values, +// but "quantized down" to the uint8 scale; in other words, its output +// is typically what one would then clamp to [0..255] and cast to uint8 +// (see OutputStageSaturatingCastToUint8). +// +// This "quantization down" process depends on 3 parameters, +// result_offset, result_fixedpoint_multiplier, result_shift, +// and the result is: +// ((FixedPointMul(input, result_fixedpoint_multiplier) + +// rounding) >> result_shift) + result_offset_after_shift +// where +// rounding = (result_shift < 1) ? 0 : (1 << (result_shift - 1)); +// and where FixedPointMul(x, y) is the nearest integer to the following +// mathematical expression, evaluated without overflow or intermediate +// rounding: +// (x * y) / 2^31 +// In practice, it is expected that FixedPointMul will be implemented +// using hardware "rounding doubling int32 multiply high" instructions, +// such as VQRDMULH on ARM. See in fixedpoint.h the generic function, +// SaturatingRoundingDoublingHighMul. +// +// Notice that the other difference from +// OutputStageQuantizeDownInt32ToUint8Scale is that the result offset +// is applied after the multiplier and shift, not before. This ensures +// that no matter what the multiplier and shift are, the result offset +// is effectively integral: offsetting the final result by an integer. +// The motivation for this is to faithfully support quantization schemes +// where the formula linking quantized values to the real mathematical +// values that they represent, is of the form +// +// real_value = scale * (quantized_value - zero_point) +// +// where scale is a real number (represented in quantized form by +// result_fixedpoint_multiplier and result_shift) and zero_point +// is an integer telling which quantized value correspond to the +// real value 0, and is represented here by (the opposite of) +// result_offset_after_shift. +// The motivation for such a quantization scheme, designed to +// ensure that 0 is always a representable value, is that in +// many applications, we need to 0-pad arrays and that can only be +// done for quantized arrays if 0 is a representable value in +// quantized form. In particular, convolution-like operations +// are often implemented using 0-padding, or "im2col"-like +// expansions that implicitly rely on 0-padding. If 0 were not +// a representable value, such operations would have to pad +// using a nonzero value, introducing bias in the computation. +struct OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint { + std::int32_t result_fixedpoint_multiplier; + std::int32_t result_shift; + std::int32_t result_offset_after_shift; +}; + +// This output stage takes int32 values that are expected to be already +// on the final uint8 scale, but not necessarily in the [0..255] range. +// It clamps them to the [0..255] range and returns them casted to uint8. +struct OutputStageSaturatingCastToUint8 {}; + +// This output stage depends on a "bias vector" that should contain int32 +// entries, and be either a row-vector of the same number of columns as the +// result matrix, or a column-vector of the same number of rows as the +// result matrix. 
This output stage takes int32 values and adds to them
+// the corresponding entry of the bias vector (broadcasted in the other
+// direction to fit the matrix's shape), outputting int32 values.
+template <typename VectorType>
+struct OutputStageBiasAddition {
+  VectorType bias_vector;
+};
+
+// This output stage clamps values between the specified min and max bounds.
+// It can be used to implement "rectified linear unit" activation functions
+// in neural networks.
+struct OutputStageClamp {
+  std::int32_t min;
+  std::int32_t max;
+};
+
+struct OutputStageTanh {
+  std::int32_t real_zero_as_int32;
+  std::int32_t real_amplitude_as_int32;
+};
+
+// An output pipeline is just a std::tuple of output stages.
+// This function generates a standard output pipeline consisting of two stages:
+// OutputStageQuantizeDownInt32ToUint8Scale, OutputStageSaturatingCastToUint8.
+inline std::tuple<OutputStageQuantizeDownInt32ToUint8Scale,
+                  OutputStageSaturatingCastToUint8>
+MakeStandardOutputPipeline(std::int32_t result_offset,
+                           std::int32_t result_mult_int,
+                           std::int32_t result_shift) {
+  OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
+  quantize_down_stage.result_offset = result_offset;
+  quantize_down_stage.result_mult_int = result_mult_int;
+  quantize_down_stage.result_shift = result_shift;
+  OutputStageSaturatingCastToUint8 saturating_cast_stage;
+  return std::make_tuple(quantize_down_stage, saturating_cast_stage);
+}
+
+// An output pipeline is just a std::tuple of output stages.
+// This function generates a standard output pipeline consisting of two stages:
+// OutputStageQuantizeDownInt32ToUint8ScalePC, OutputStageSaturatingCastToUint8.
+template <VectorShape tShape>
+inline std::tuple<OutputStageQuantizeDownInt32ToUint8ScalePC<tShape>,
+                  OutputStageSaturatingCastToUint8>
+MakeStandardOutputPipeline(
+    const VectorMap<const std::int32_t, tShape>& result_offset,
+    const VectorMap<const std::int32_t, tShape>& result_mult_int,
+    std::int32_t result_shift) {
+  OutputStageQuantizeDownInt32ToUint8ScalePC<tShape> quantize_down_stage;
+  quantize_down_stage.result_offset = result_offset;
+  quantize_down_stage.result_mult_int = result_mult_int;
+  quantize_down_stage.result_shift = result_shift;
+  OutputStageSaturatingCastToUint8 saturating_cast_stage;
+  return std::make_tuple(quantize_down_stage, saturating_cast_stage);
+}
+
+}  // namespace gemmlowp
+
+#endif  // GEMMLOWP_PUBLIC_OUTPUT_STAGES_H_
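Because an output pipeline is simply a std::tuple of the stage structs defined above, pipelines other than the two MakeStandardOutputPipeline variants can be assembled by hand. Below is a minimal sketch, with illustrative names, dimensions, and quantization parameters, of a pipeline that adds a per-row int32 bias, requantizes with the fixed-point stage, clamps (a ReLU-style activation), and saturating-casts to uint8, and is then passed to GemmWithOutputPipeline from gemmlowp.h earlier in this diff. The include path is assumed for this vendored tree.

#include <cstdint>
#include <tuple>
#include "public/gemmlowp.h"  // assumed include path within this tree

void GemmWithCustomPipeline(gemmlowp::GemmContext* context,
                            const std::uint8_t* lhs_data,
                            const std::uint8_t* rhs_data,
                            const std::int32_t* bias_data,
                            std::uint8_t* result_data, int m, int k, int n) {
  using namespace gemmlowp;
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(lhs_data, m, k);
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(rhs_data, k, n);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(result_data, m, n);

  // Per-row bias, added to the int32 accumulators before requantization.
  typedef VectorMap<const std::int32_t, VectorShape::Col> ColVectorMap;
  OutputStageBiasAddition<ColVectorMap> bias_stage;
  bias_stage.bias_vector = ColVectorMap(bias_data, m);

  // Requantize to the uint8 scale; multiplier/shift/offset are illustrative.
  OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint quantize_stage;
  quantize_stage.result_fixedpoint_multiplier = 1 << 30;  // about 0.5 in Q31
  quantize_stage.result_shift = 8;
  quantize_stage.result_offset_after_shift = 128;  // the uint8 zero point

  OutputStageClamp clamp_stage;
  clamp_stage.min = 128;  // quantized 0.0 for this zero point
  clamp_stage.max = 255;

  OutputStageSaturatingCastToUint8 cast_stage;

  // The pipeline is just the tuple of stages, applied left to right.
  const auto pipeline =
      std::make_tuple(bias_stage, quantize_stage, clamp_stage, cast_stage);

  const int lhs_offset = -128, rhs_offset = -128;  // illustrative
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
                         DefaultL8R8BitDepthParams>(
      context, lhs, rhs, &result, lhs_offset, rhs_offset, pipeline);
}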