diff options
Diffstat (limited to 'runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h')
-rw-r--r-- | runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h | 508 |
1 files changed, 508 insertions, 0 deletions
diff --git a/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h new file mode 100644 index 000000000..e39eaf89f --- /dev/null +++ b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h @@ -0,0 +1,508 @@ +// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// simd_wrappers.h: some inline functions wrapping SIMD intrinsics, +// extending the set of such functions from fixedpoint.h. + +#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ +#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ + +#include <algorithm> +#include <type_traits> +#include "../fixedpoint/fixedpoint.h" + +namespace gemmlowp { + +template <typename ScalarType, int ScalarCount> +struct RegisterType { + using Type = ScalarType; +}; + +inline std::int32_t Min(std::int32_t a, std::int32_t b) { + return std::min(a, b); +} + +inline std::int32_t Max(std::int32_t a, std::int32_t b) { + return std::max(a, b); +} + +inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) { + *acc += lhs * rhs; +} + +template <typename tScalarType, int tScalarCount> +struct RegisterBuffer { + using ScalarType = tScalarType; + static constexpr int kScalarCount = tScalarCount; + using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type; + static_assert((kScalarCount & (kScalarCount - 1)) == 0, + "kScalarCount must be a power of two"); + static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, ""); + static constexpr int kRegisterLanes = + sizeof(RegisterType) / sizeof(ScalarType); + static constexpr int kRegisterCount = + (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) / + sizeof(RegisterType); + + RegisterType reg[kRegisterCount]; +}; + +template <typename tScalarType, int tRows, int tCols> +struct RegisterBlock { + using ScalarType = tScalarType; + static constexpr int kRows = tRows; + static constexpr int kCols = tCols; + static constexpr int kScalarCount = kRows * kCols; + using BufferType = RegisterBuffer<ScalarType, kScalarCount>; + using RegisterType = typename BufferType::RegisterType; + static constexpr int kRegisterCount = BufferType::kRegisterCount; + static constexpr int kRegisterLanes = BufferType::kRegisterLanes; + + BufferType buf; +}; + +template <typename RegisterBlockType> +struct RegisterBlockAddImpl { + static RegisterBlockType Run(const RegisterBlockType& lhs, + const RegisterBlockType& rhs) { + RegisterBlockType result; + for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { + result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +template <typename RegisterBlockType> +RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs, + const RegisterBlockType& rhs) { + return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs); +} + +template <typename LhsType, typename RhsType> +struct ShouldFlipLhsRhs { + static constexpr bool kValue = + (LhsType::kScalarCount < RhsType::kScalarCount) || + (LhsType::kScalarCount == RhsType::kScalarCount && + (LhsType::kRows < RhsType::kRows)); +}; + +template <typename LhsType, typename RhsType, + bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue> +struct FlipLhsRhs { + using FlippedLhsType = LhsType; + using FlippedRhsType = RhsType; + static const FlippedLhsType& FlippedLhs(const LhsType& lhs, + const RhsType& rhs) { + return lhs; + } + static const FlippedRhsType& FlippedRhs(const LhsType& lhs, + const RhsType& rhs) { + return rhs; + } +}; + +template <typename LhsType, typename RhsType> +struct FlipLhsRhs<LhsType, RhsType, true> { + using FlippedLhsType = RhsType; + using FlippedRhsType = LhsType; + static const FlippedLhsType& FlippedLhs(const LhsType& lhs, + const RhsType& rhs) { + return rhs; + } + static const FlippedRhsType& FlippedRhs(const LhsType& lhs, + const RhsType& rhs) { + return lhs; + } +}; + +template <typename Lhs, typename Rhs> +struct BroadcastBinaryOpShape { + static constexpr int kRows = + Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows; + static constexpr int kCols = + Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols; +}; + +template <typename Lhs, typename Rhs> +struct BroadcastBinaryOpRegisterBlock { + using Shape = BroadcastBinaryOpShape<Lhs, Rhs>; + using ScalarType = typename Lhs::ScalarType; + using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; +}; + +template <typename Lhs, typename Rhs> +struct BroadcastAddImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + result.buf.reg[r + c * Rows] = + Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template <typename Lhs, typename Rhs> +typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd( + const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + return BroadcastAddImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template <typename Lhs, typename Rhs> +struct BroadcastMulImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + result.buf.reg[r + c * Rows] = + Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template <typename Lhs, typename Rhs> +typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul( + const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + return BroadcastMulImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template <typename Lhs, typename Rhs, typename Acc> +struct BroadcastMulAddImpl { + static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) { + static constexpr int Rows = Acc::kRows; + static constexpr int Cols = Acc::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + static_assert(Acc::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows], + &acc->buf.reg[r + c * Rows]); + } + } + } +}; + +template <typename Lhs, typename Rhs, typename Acc> +void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + BroadcastMulAddImpl<typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType, + Acc>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs), acc); +} + +template <typename RegisterBlockType, typename SrcObjectType> +struct LoadImpl { + static_assert(std::is_same<SrcObjectType, void>::value, + "This generic impl should never be hit"); +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType> +struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, + MatrixMap<SrcScalarType, MapOrder::ColMajor>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>; + static RegisterBlockType Run(const SrcObjectType& src, int row, int col) { + RegisterBlockType result; + int i = 0; + for (int c = 0; c < Cols; c++) { + const ScalarType* src_ptr = src.data(row, col + c); + for (int r = 0; r < Rows; r++) { + result.buf.reg[i++] = *src_ptr++; + } + } + return result; + } +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorMap<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorMap<SrcScalarType, Shape>; + static RegisterBlockType Run(const SrcObjectType& src, int pos) { + static_assert(Shape == VectorShape::Col || Rows == 1, ""); + static_assert(Shape == VectorShape::Row || Cols == 1, ""); + RegisterBlockType result; + for (int i = 0; i < Rows * Cols; i++) { + result.buf.reg[i] = src(pos + i); + } + return result; + } +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorDup<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorDup<SrcScalarType, Shape>; + static RegisterBlockType Run(const SrcObjectType& src, int) { + static_assert(Shape == VectorShape::Col || Rows == 1, ""); + static_assert(Shape == VectorShape::Row || Cols == 1, ""); + RegisterBlockType result; + for (int i = 0; i < Rows * Cols; i++) { + result.buf.reg[i] = src(0); + } + return result; + } +}; + +template <typename RegisterBlockType, typename SrcObjectType> +RegisterBlockType Load(const SrcObjectType& src, int row, int col) { + return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col); +} + +template <typename RegisterBlockType, typename SrcObjectType> +RegisterBlockType Load(const SrcObjectType& src, int pos) { + return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos); +} + +template <typename RegisterBlockType> +struct LoadContiguousImpl { + using ScalarType = typename RegisterBlockType::ScalarType; + static_assert(RegisterBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static RegisterBlockType Run(const ScalarType* src) { + RegisterBlockType result; + for (int i = 0; i < RegisterBlockType::kScalarCount; i++) { + result.buf.reg[i] = src[i]; + } + return result; + } +}; + +template <typename RegisterBlockType> +RegisterBlockType LoadContiguous( + const typename RegisterBlockType::ScalarType* src) { + return LoadContiguousImpl<RegisterBlockType>::Run(src); +} + +template <int BroadcastRows, int BroadcastCols, typename SrcObjectType> +struct LoadForBroadcastingShape {}; + +template <int BroadcastRows, int BroadcastCols, typename ScalarType, + VectorShape Shape> +struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, + VectorMap<ScalarType, Shape>> { + static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1; + static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1; +}; + +template <int BroadcastRows, int BroadcastCols, typename ScalarType, + VectorShape Shape> +struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, + VectorDup<ScalarType, Shape>> { + static constexpr int kRows = 1; + static constexpr int kCols = 1; +}; + +template <typename RegisterBlockType, typename SrcObjectType> +struct LoadForBroadcastingRegisterBlock { + using Shape = + LoadForBroadcastingShape<RegisterBlockType::kRows, + RegisterBlockType::kCols, SrcObjectType>; + using ScalarType = typename RegisterBlockType::ScalarType; + using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; +}; + +template <typename RegisterBlockType, typename SrcObjectType> +struct LoadForBroadcastingImpl { + static_assert(std::is_same<SrcObjectType, void>::value, + "This generic impl should never be hit"); +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorMap<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorMap<SrcScalarType, Shape>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static ResultBlockType Run(const SrcObjectType& src, int pos) { + ResultBlockType result; + for (int c = 0; c < ResultBlockType::kCols; c++) { + for (int r = 0; r < ResultBlockType::kRows; r++) { + const int i = Shape == VectorShape::Col ? r : c; + result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i); + } + } + return result; + } +}; + +template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, + VectorShape Shape> +struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, + VectorDup<SrcScalarType, Shape>> { + using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; + using SrcObjectType = VectorDup<SrcScalarType, Shape>; + using ResultBlockType = + typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type; + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static ResultBlockType Run(const SrcObjectType& src, int) { + ResultBlockType result; + for (int c = 0; c < ResultBlockType::kCols; c++) { + for (int r = 0; r < ResultBlockType::kRows; r++) { + result.buf.reg[r + c * ResultBlockType::kRows] = src(0); + } + } + return result; + } +}; + +template <typename RegisterBlockType, typename SrcObjectType> +typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type +LoadForBroadcasting(const SrcObjectType& src, int row, int col) { + return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run( + src, row, col); +} + +template <typename RegisterBlockType, typename SrcObjectType> +typename LoadForBroadcastingRegisterBlock<RegisterBlockType, + SrcObjectType>::Type +LoadForBroadcasting(const SrcObjectType& src, int pos) { + return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src, + pos); +} + +template <int ConstantValue, typename RegisterBlockType> +struct AddConstantImpl { + static void Run(RegisterBlockType* block) { + using RegisterType = typename RegisterBlockType::RegisterType; + const RegisterType dup = Dup<RegisterType>(ConstantValue); + for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { + block->buf.reg[i] = Add(block->buf.reg[i], dup); + } + } +}; + +template <typename RegisterBlockType> +struct AddConstantImpl<0, RegisterBlockType> { + static void Run(RegisterBlockType*) { + // This is a no-op. + } +}; + +template <int ConstantValue, typename RegisterBlockType> +void AddConstant(RegisterBlockType* block) { + AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block); +} + +template <int N> +using RegBufferInt32 = RegisterBuffer<std::int32_t, N>; +template <int N> +using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>; +template <int R, int C> +using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>; +template <int R, int C> +using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>; + +} // end namespace gemmlowp + +#if defined GEMMLOWP_NEON +#include "simd_wrappers_neon.h" +#elif defined GEMMLOWP_SSE4 +#include "simd_wrappers_sse.h" +#endif + +#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ |