summaryrefslogtreecommitdiff
path: root/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h')
-rw-r--r--runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h508
1 files changed, 508 insertions, 0 deletions
diff --git a/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h
new file mode 100644
index 000000000..e39eaf89f
--- /dev/null
+++ b/runtimes/nn/depend/external/gemmlowp/internal/simd_wrappers.h
@@ -0,0 +1,508 @@
+// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
+// extending the set of such functions from fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
+#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
+
+#include <algorithm>
+#include <type_traits>
+#include "../fixedpoint/fixedpoint.h"
+
+namespace gemmlowp {
+
+// Trait selecting the register type used to store ScalarCount scalars of
+// type ScalarType. This primary template ignores ScalarCount and uses the
+// scalar type itself as the register type.
+template <typename ScalarType, int ScalarCount>
+struct RegisterType {
+  using Type = ScalarType;
+};
+
+// Returns the smaller of the two given 32-bit signed integers.
+inline std::int32_t Min(std::int32_t a, std::int32_t b) {
+  return (b < a) ? b : a;
+}
+
+// Returns the larger of the two given 32-bit signed integers.
+inline std::int32_t Max(std::int32_t a, std::int32_t b) {
+  return (a < b) ? b : a;
+}
+
+// Scalar multiply-accumulate: adds lhs * rhs to *acc.
+//
+// The product and the sum are computed in unsigned 32-bit arithmetic, which
+// wraps modulo 2^32, and the result is converted back to int32. For inputs
+// that do not overflow this gives exactly the same result as plain signed
+// arithmetic; for inputs that do overflow it avoids the undefined behaviour
+// of signed integer overflow. (The final unsigned-to-signed conversion is
+// implementation-defined before C++20 when the value is out of range.)
+//
+// acc must point to a valid accumulator; it is read and then overwritten.
+inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
+  const std::uint32_t product =
+      static_cast<std::uint32_t>(lhs) * static_cast<std::uint32_t>(rhs);
+  const std::uint32_t sum = static_cast<std::uint32_t>(*acc) + product;
+  *acc = static_cast<std::int32_t>(sum);
+}
+
+// Fixed-size storage for tScalarCount scalars of type tScalarType, held as
+// an array of registers whose type is given by the RegisterType trait.
+template <typename tScalarType, int tScalarCount>
+struct RegisterBuffer {
+  using ScalarType = tScalarType;
+  static constexpr int kScalarCount = tScalarCount;
+  using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
+  // The bit trick (x & (x - 1)) == 0 is also satisfied by x == 0, so the
+  // count is additionally required to be positive.
+  static_assert(kScalarCount > 0 && (kScalarCount & (kScalarCount - 1)) == 0,
+                "kScalarCount must be a power of two");
+  static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
+  // Number of scalars that fit in one register.
+  static constexpr int kRegisterLanes =
+      sizeof(RegisterType) / sizeof(ScalarType);
+  // Number of registers needed to hold kScalarCount scalars, rounded up.
+  static constexpr int kRegisterCount =
+      (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
+      sizeof(RegisterType);
+
+  // The underlying register storage.
+  RegisterType reg[kRegisterCount];
+};
+
+// A tRows x tCols block of scalars of type tScalarType, stored in a
+// RegisterBuffer holding tRows * tCols scalars. The register-related types
+// and constants are re-exported from that buffer type.
+template <typename tScalarType, int tRows, int tCols>
+struct RegisterBlock {
+  using ScalarType = tScalarType;
+
+  // Block dimensions and total number of scalars.
+  static constexpr int kRows = tRows;
+  static constexpr int kCols = tCols;
+  static constexpr int kScalarCount = kRows * kCols;
+
+  // Backing storage type and the properties it defines.
+  using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
+  using RegisterType = typename BufferType::RegisterType;
+  static constexpr int kRegisterCount = BufferType::kRegisterCount;
+  static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
+
+  BufferType buf;
+};
+
+// Adds two register blocks of the same type, one register at a time, by
+// calling Add on each pair of corresponding registers.
+template <typename RegisterBlockType>
+struct RegisterBlockAddImpl {
+  static RegisterBlockType Run(const RegisterBlockType& lhs,
+                               const RegisterBlockType& rhs) {
+    RegisterBlockType sum;
+    for (int reg_index = 0; reg_index < RegisterBlockType::kRegisterCount;
+         ++reg_index) {
+      sum.buf.reg[reg_index] =
+          Add(lhs.buf.reg[reg_index], rhs.buf.reg[reg_index]);
+    }
+    return sum;
+  }
+};
+
+// Returns the register-wise sum of lhs and rhs; forwards to
+// RegisterBlockAddImpl<RegisterBlockType>::Run.
+template <typename RegisterBlockType>
+RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
+                                   const RegisterBlockType& rhs) {
+  using Impl = RegisterBlockAddImpl<RegisterBlockType>;
+  return Impl::Run(lhs, rhs);
+}
+
+// Decides whether the two operands of a broadcasting binary op should be
+// swapped. kValue is true when LhsType has fewer scalars than RhsType, or
+// when both have the same number of scalars and LhsType has fewer rows.
+template <typename LhsType, typename RhsType>
+struct ShouldFlipLhsRhs {
+  static constexpr bool kValue =
+      (LhsType::kScalarCount != RhsType::kScalarCount)
+          ? (LhsType::kScalarCount < RhsType::kScalarCount)
+          : (LhsType::kRows < RhsType::kRows);
+};
+
+// Helper that optionally swaps the two operands of a binary op. This primary
+// template is the non-flipping case: the "flipped" types and accessors are
+// simply the original lhs and rhs. By default Flip is taken from
+// ShouldFlipLhsRhs.
+template <typename LhsType, typename RhsType,
+          bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
+struct FlipLhsRhs {
+  using FlippedLhsType = LhsType;
+  using FlippedRhsType = RhsType;
+
+  // Returns the operand to use on the left-hand side (here: lhs).
+  static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
+                                          const RhsType& /*rhs*/) {
+    return lhs;
+  }
+
+  // Returns the operand to use on the right-hand side (here: rhs).
+  static const FlippedRhsType& FlippedRhs(const LhsType& /*lhs*/,
+                                          const RhsType& rhs) {
+    return rhs;
+  }
+};
+
+// Flipping case of FlipLhsRhs: the operand types are exchanged, and the
+// accessors return the opposite operand.
+template <typename LhsType, typename RhsType>
+struct FlipLhsRhs<LhsType, RhsType, true> {
+  using FlippedLhsType = RhsType;
+  using FlippedRhsType = LhsType;
+
+  // Returns the operand to use on the left-hand side (here: rhs).
+  static const FlippedLhsType& FlippedLhs(const LhsType& /*lhs*/,
+                                          const RhsType& rhs) {
+    return rhs;
+  }
+
+  // Returns the operand to use on the right-hand side (here: lhs).
+  static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
+                                          const RhsType& /*rhs*/) {
+    return lhs;
+  }
+};
+
+// Shape of the result of a broadcasting binary op: in each dimension, the
+// larger of the two operand extents.
+template <typename Lhs, typename Rhs>
+struct BroadcastBinaryOpShape {
+  static constexpr int kRows =
+      (Lhs::kRows > Rhs::kRows) ? Lhs::kRows : Rhs::kRows;
+  static constexpr int kCols =
+      (Lhs::kCols > Rhs::kCols) ? Lhs::kCols : Rhs::kCols;
+};
+
+// Result block type of a broadcasting binary op on Lhs and Rhs: a
+// RegisterBlock with the broadcast shape and the scalar type of Lhs.
+template <typename Lhs, typename Rhs>
+struct BroadcastBinaryOpRegisterBlock {
+  using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
+  using ScalarType = typename Lhs::ScalarType;
+  using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
+};
+
+// Generic (one scalar per register) implementation of broadcasting addition.
+// Each operand dimension must either match the result dimension or be 1; a
+// dimension of 1 is broadcast by always reading index 0 along it. Blocks are
+// indexed as reg[row + col * rows].
+template <typename Lhs, typename Rhs>
+struct BroadcastAddImpl {
+  using ResultBlockType =
+      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+    constexpr int kOutRows = ResultBlockType::kRows;
+    constexpr int kOutCols = ResultBlockType::kCols;
+    constexpr int kLhsRows = Lhs::kRows;
+    constexpr int kLhsCols = Lhs::kCols;
+    constexpr int kRhsRows = Rhs::kRows;
+    constexpr int kRhsCols = Rhs::kCols;
+
+    static_assert(kLhsRows == kOutRows || kLhsRows == 1, "");
+    static_assert(kRhsRows == kOutRows || kRhsRows == 1, "");
+    static_assert(kLhsCols == kOutCols || kLhsCols == 1, "");
+    static_assert(kRhsCols == kOutCols || kRhsCols == 1, "");
+    static_assert(ResultBlockType::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    ResultBlockType result;
+    for (int col = 0; col < kOutCols; ++col) {
+      // Offset of the source column within each operand (column 0 when the
+      // operand is broadcast along columns).
+      const int lhs_col_base = (kLhsCols == kOutCols ? col : 0) * kLhsRows;
+      const int rhs_col_base = (kRhsCols == kOutCols ? col : 0) * kRhsRows;
+      for (int row = 0; row < kOutRows; ++row) {
+        const int lhs_index = lhs_col_base + (kLhsRows == kOutRows ? row : 0);
+        const int rhs_index = rhs_col_base + (kRhsRows == kOutRows ? row : 0);
+        result.buf.reg[row + col * kOutRows] =
+            Add(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index]);
+      }
+    }
+    return result;
+  }
+};
+
+// Broadcasting addition of two register blocks. The operands are first put
+// in the order chosen by FlipLhsRhs, then passed to BroadcastAddImpl.
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
+    const Lhs& lhs, const Rhs& rhs) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  using Impl = BroadcastAddImpl<typename Flip::FlippedLhsType,
+                                typename Flip::FlippedRhsType>;
+  return Impl::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs));
+}
+
+// Generic (one scalar per register) implementation of broadcasting
+// multiplication. Each operand dimension must either match the result
+// dimension or be 1; a dimension of 1 is broadcast by always reading index 0
+// along it. Blocks are indexed as reg[row + col * rows].
+template <typename Lhs, typename Rhs>
+struct BroadcastMulImpl {
+  using ResultBlockType =
+      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+    constexpr int kOutRows = ResultBlockType::kRows;
+    constexpr int kOutCols = ResultBlockType::kCols;
+    constexpr int kLhsRows = Lhs::kRows;
+    constexpr int kLhsCols = Lhs::kCols;
+    constexpr int kRhsRows = Rhs::kRows;
+    constexpr int kRhsCols = Rhs::kCols;
+
+    static_assert(ResultBlockType::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    static_assert(kLhsRows == kOutRows || kLhsRows == 1, "");
+    static_assert(kRhsRows == kOutRows || kRhsRows == 1, "");
+    static_assert(kLhsCols == kOutCols || kLhsCols == 1, "");
+    static_assert(kRhsCols == kOutCols || kRhsCols == 1, "");
+
+    ResultBlockType result;
+    for (int col = 0; col < kOutCols; ++col) {
+      // Offset of the source column within each operand (column 0 when the
+      // operand is broadcast along columns).
+      const int lhs_col_base = (kLhsCols == kOutCols ? col : 0) * kLhsRows;
+      const int rhs_col_base = (kRhsCols == kOutCols ? col : 0) * kRhsRows;
+      for (int row = 0; row < kOutRows; ++row) {
+        const int lhs_index = lhs_col_base + (kLhsRows == kOutRows ? row : 0);
+        const int rhs_index = rhs_col_base + (kRhsRows == kOutRows ? row : 0);
+        result.buf.reg[row + col * kOutRows] =
+            Mul(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index]);
+      }
+    }
+    return result;
+  }
+};
+
+// Broadcasting multiplication of two register blocks. The operands are first
+// put in the order chosen by FlipLhsRhs, then passed to BroadcastMulImpl.
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
+    const Lhs& lhs, const Rhs& rhs) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  using Impl = BroadcastMulImpl<typename Flip::FlippedLhsType,
+                                typename Flip::FlippedRhsType>;
+  return Impl::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs));
+}
+
+// Generic (one scalar per register) implementation of broadcasting
+// multiply-accumulate: for every entry of *acc, calls MulAdd with the
+// matching (or broadcast) entries of lhs and rhs. The result shape is that
+// of Acc; each operand dimension must either match it or be 1. Blocks are
+// indexed as reg[row + col * rows].
+template <typename Lhs, typename Rhs, typename Acc>
+struct BroadcastMulAddImpl {
+  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
+    constexpr int kAccRows = Acc::kRows;
+    constexpr int kAccCols = Acc::kCols;
+    constexpr int kLhsRows = Lhs::kRows;
+    constexpr int kLhsCols = Lhs::kCols;
+    constexpr int kRhsRows = Rhs::kRows;
+    constexpr int kRhsCols = Rhs::kCols;
+
+    static_assert(Acc::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Lhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+    static_assert(Rhs::kRegisterLanes == 1,
+                  "This path is only for scalar values");
+
+    static_assert(kLhsRows == kAccRows || kLhsRows == 1, "");
+    static_assert(kRhsRows == kAccRows || kRhsRows == 1, "");
+    static_assert(kLhsCols == kAccCols || kLhsCols == 1, "");
+    static_assert(kRhsCols == kAccCols || kRhsCols == 1, "");
+
+    for (int col = 0; col < kAccCols; ++col) {
+      // Offset of the source column within each operand (column 0 when the
+      // operand is broadcast along columns).
+      const int lhs_col_base = (kLhsCols == kAccCols ? col : 0) * kLhsRows;
+      const int rhs_col_base = (kRhsCols == kAccCols ? col : 0) * kRhsRows;
+      for (int row = 0; row < kAccRows; ++row) {
+        const int lhs_index = lhs_col_base + (kLhsRows == kAccRows ? row : 0);
+        const int rhs_index = rhs_col_base + (kRhsRows == kAccRows ? row : 0);
+        MulAdd(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index],
+               &acc->buf.reg[row + col * kAccRows]);
+      }
+    }
+  }
+};
+
+// Broadcasting multiply-accumulate into *acc. The two multiplicands are
+// first put in the order chosen by FlipLhsRhs, then passed to
+// BroadcastMulAddImpl together with the accumulator.
+template <typename Lhs, typename Rhs, typename Acc>
+void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
+  using Flip = FlipLhsRhs<Lhs, Rhs>;
+  using Impl = BroadcastMulAddImpl<typename Flip::FlippedLhsType,
+                                   typename Flip::FlippedRhsType, Acc>;
+  Impl::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs), acc);
+}
+
+// Primary template for loading a register block from a source object. It
+// only exists to be specialized: instantiating it trips the static_assert
+// below unless SrcObjectType is void.
+template <typename RegisterBlockType, typename SrcObjectType>
+struct LoadImpl {
+  static_assert(std::is_same<SrcObjectType, void>::value,
+                "This generic impl should never be hit");
+};
+
+// Loads a Rows x Cols block from a column-major MatrixMap. For each of the
+// Cols columns starting at `col`, reads Rows consecutive scalars beginning
+// at the pointer returned by src.data(row, col + c), and stores them in
+// column-major order in the result.
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
+struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
+  static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
+    RegisterBlockType result;
+    for (int c = 0; c < Cols; ++c) {
+      const ScalarType* column = src.data(row, col + c);
+      for (int r = 0; r < Rows; ++r) {
+        result.buf.reg[c * Rows + r] = column[r];
+      }
+    }
+    return result;
+  }
+};
+
+// Loads a register block from a VectorMap: copies Rows * Cols consecutive
+// entries src(pos), src(pos + 1), ... into the result. The block must be a
+// single column for a Col-shaped vector and a single row for a Row-shaped
+// one.
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                VectorMap<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorMap<SrcScalarType, Shape>;
+  static RegisterBlockType Run(const SrcObjectType& src, int pos) {
+    static_assert(Shape == VectorShape::Col || Rows == 1, "");
+    static_assert(Shape == VectorShape::Row || Cols == 1, "");
+    constexpr int kEntryCount = Rows * Cols;
+    RegisterBlockType result;
+    for (int offset = 0; offset < kEntryCount; ++offset) {
+      result.buf.reg[offset] = src(pos + offset);
+    }
+    return result;
+  }
+};
+
+// Loads a register block from a VectorDup: every one of the Rows * Cols
+// entries of the result is set to src(0); the position argument is ignored.
+// The block must be a single column for a Col-shaped vector and a single row
+// for a Row-shaped one.
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                VectorDup<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorDup<SrcScalarType, Shape>;
+  static RegisterBlockType Run(const SrcObjectType& src, int) {
+    static_assert(Shape == VectorShape::Col || Rows == 1, "");
+    static_assert(Shape == VectorShape::Row || Cols == 1, "");
+    constexpr int kEntryCount = Rows * Cols;
+    RegisterBlockType result;
+    for (int entry = 0; entry < kEntryCount; ++entry) {
+      result.buf.reg[entry] = src(0);
+    }
+    return result;
+  }
+};
+
+// Loads a register block from `src` at the given (row, col); forwards to the
+// matching LoadImpl specialization.
+template <typename RegisterBlockType, typename SrcObjectType>
+RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
+  using Impl = LoadImpl<RegisterBlockType, SrcObjectType>;
+  return Impl::Run(src, row, col);
+}
+
+// Loads a register block from `src` at the given position; forwards to the
+// matching LoadImpl specialization.
+template <typename RegisterBlockType, typename SrcObjectType>
+RegisterBlockType Load(const SrcObjectType& src, int pos) {
+  using Impl = LoadImpl<RegisterBlockType, SrcObjectType>;
+  return Impl::Run(src, pos);
+}
+
+// Generic (one scalar per register) implementation of loading a register
+// block from contiguous memory: copies kScalarCount scalars from src[0..]
+// into the block's registers, in order.
+template <typename RegisterBlockType>
+struct LoadContiguousImpl {
+  using ScalarType = typename RegisterBlockType::ScalarType;
+  static_assert(RegisterBlockType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+  static RegisterBlockType Run(const ScalarType* src) {
+    RegisterBlockType block;
+    for (int index = 0; index < RegisterBlockType::kScalarCount; ++index) {
+      block.buf.reg[index] = src[index];
+    }
+    return block;
+  }
+};
+
+// Loads a register block from the contiguous scalars starting at `src`;
+// forwards to LoadContiguousImpl<RegisterBlockType>::Run.
+template <typename RegisterBlockType>
+RegisterBlockType LoadContiguous(
+    const typename RegisterBlockType::ScalarType* src) {
+  using Impl = LoadContiguousImpl<RegisterBlockType>;
+  return Impl::Run(src);
+}
+
+// Shape of the block loaded from SrcObjectType for broadcasting against a
+// BroadcastRows x BroadcastCols block. The primary template is intentionally
+// empty; the shape is provided by the specializations.
+template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
+struct LoadForBroadcastingShape {
+};
+
+// Broadcasting-load shape for a VectorMap: a Col-shaped vector yields a
+// BroadcastRows x 1 block, a Row-shaped vector a 1 x BroadcastCols block.
+template <int BroadcastRows, int BroadcastCols, typename ScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
+                                VectorMap<ScalarType, Shape>> {
+  static constexpr int kRows =
+      (Shape == VectorShape::Col) ? BroadcastRows : 1;
+  static constexpr int kCols =
+      (Shape == VectorShape::Row) ? BroadcastCols : 1;
+};
+
+// Broadcasting-load shape for a VectorDup: always a 1 x 1 block, whatever
+// the vector shape or the broadcast dimensions.
+template <int BroadcastRows, int BroadcastCols, typename ScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
+                                VectorDup<ScalarType, Shape>> {
+  static constexpr int kRows = 1;
+  static constexpr int kCols = 1;
+};
+
+// Type of the block produced by LoadForBroadcasting: a RegisterBlock with
+// the scalar type of RegisterBlockType and the shape given by
+// LoadForBroadcastingShape for RegisterBlockType's dimensions and the source
+// object type.
+template <typename RegisterBlockType, typename SrcObjectType>
+struct LoadForBroadcastingRegisterBlock {
+  using Shape = LoadForBroadcastingShape<RegisterBlockType::kRows,
+                                         RegisterBlockType::kCols,
+                                         SrcObjectType>;
+  using ScalarType = typename RegisterBlockType::ScalarType;
+  using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
+};
+
+// Primary template for broadcasting loads. It only exists to be specialized:
+// instantiating it trips the static_assert below unless SrcObjectType is
+// void.
+template <typename RegisterBlockType, typename SrcObjectType>
+struct LoadForBroadcastingImpl {
+  static_assert(std::is_same<SrcObjectType, void>::value,
+                "This generic impl should never be hit");
+};
+
+// Generic (one scalar per register) broadcasting load from a VectorMap.
+// Fills the result block, whose shape comes from
+// LoadForBroadcastingRegisterBlock, with src(pos + i), where i is the row
+// index for a Col-shaped vector and the column index otherwise.
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                               VectorMap<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorMap<SrcScalarType, Shape>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+  static_assert(ResultBlockType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+  static ResultBlockType Run(const SrcObjectType& src, int pos) {
+    constexpr int kOutRows = ResultBlockType::kRows;
+    constexpr int kOutCols = ResultBlockType::kCols;
+    constexpr bool kIsColVector = (Shape == VectorShape::Col);
+    ResultBlockType result;
+    for (int col = 0; col < kOutCols; ++col) {
+      for (int row = 0; row < kOutRows; ++row) {
+        const int offset = kIsColVector ? row : col;
+        result.buf.reg[row + col * kOutRows] = src(pos + offset);
+      }
+    }
+    return result;
+  }
+};
+
+// Generic (one scalar per register) broadcasting load from a VectorDup.
+// Every entry of the result block, whose shape comes from
+// LoadForBroadcastingRegisterBlock, is set to src(0); the position argument
+// is ignored.
+template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
+          VectorShape Shape>
+struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
+                               VectorDup<SrcScalarType, Shape>> {
+  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
+  using SrcObjectType = VectorDup<SrcScalarType, Shape>;
+  using ResultBlockType =
+      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                                SrcObjectType>::Type;
+  static_assert(ResultBlockType::kRegisterLanes == 1,
+                "This path is only for scalar values");
+  static ResultBlockType Run(const SrcObjectType& src, int) {
+    constexpr int kOutRows = ResultBlockType::kRows;
+    constexpr int kOutCols = ResultBlockType::kCols;
+    ResultBlockType result;
+    for (int col = 0; col < kOutCols; ++col) {
+      for (int row = 0; row < kOutRows; ++row) {
+        result.buf.reg[row + col * kOutRows] = src(0);
+      }
+    }
+    return result;
+  }
+};
+
+// Broadcasting load addressed by (row, col); forwards to
+// LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src, row,
+// col).
+// NOTE(review): the specializations visible in this header only provide
+// Run(src, pos); a three-argument Run is presumably supplied elsewhere.
+template <typename RegisterBlockType, typename SrcObjectType>
+typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                          SrcObjectType>::Type
+LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
+  using Impl = LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>;
+  return Impl::Run(src, row, col);
+}
+
+// Broadcasting load addressed by a single position; forwards to
+// LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src, pos).
+template <typename RegisterBlockType, typename SrcObjectType>
+typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
+                                          SrcObjectType>::Type
+LoadForBroadcasting(const SrcObjectType& src, int pos) {
+  using Impl = LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>;
+  return Impl::Run(src, pos);
+}
+
+// Adds the compile-time constant ConstantValue to a register block in place:
+// builds one register with Dup<RegisterType>(ConstantValue) and adds it to
+// each register of the block with Add.
+template <int ConstantValue, typename RegisterBlockType>
+struct AddConstantImpl {
+  static void Run(RegisterBlockType* block) {
+    using RegisterType = typename RegisterBlockType::RegisterType;
+    const RegisterType constant_reg = Dup<RegisterType>(ConstantValue);
+    for (int reg_index = 0; reg_index < RegisterBlockType::kRegisterCount;
+         ++reg_index) {
+      block->buf.reg[reg_index] =
+          Add(block->buf.reg[reg_index], constant_reg);
+    }
+  }
+};
+
+// Specialization for a constant of zero: the block is left untouched.
+template <typename RegisterBlockType>
+struct AddConstantImpl<0, RegisterBlockType> {
+  static void Run(RegisterBlockType* /*block*/) {
+    // This is a no-op.
+  }
+};
+
+// Adds ConstantValue to `block` in place; forwards to
+// AddConstantImpl<ConstantValue, RegisterBlockType>::Run.
+template <int ConstantValue, typename RegisterBlockType>
+void AddConstant(RegisterBlockType* block) {
+  using Impl = AddConstantImpl<ConstantValue, RegisterBlockType>;
+  Impl::Run(block);
+}
+
+// Shorthand aliases for the int32 and uint8 instantiations of RegisterBuffer
+// (parameterized by scalar count) and RegisterBlock (parameterized by rows
+// and columns).
+template <int ScalarCount>
+using RegBufferInt32 = RegisterBuffer<std::int32_t, ScalarCount>;
+template <int ScalarCount>
+using RegBufferUint8 = RegisterBuffer<std::uint8_t, ScalarCount>;
+template <int Rows, int Cols>
+using RegBlockInt32 = RegisterBlock<std::int32_t, Rows, Cols>;
+template <int Rows, int Cols>
+using RegBlockUint8 = RegisterBlock<std::uint8_t, Rows, Cols>;
+
+} // end namespace gemmlowp
+
+#if defined GEMMLOWP_NEON
+#include "simd_wrappers_neon.h"
+#elif defined GEMMLOWP_SSE4
+#include "simd_wrappers_sse.h"
+#endif
+
+#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_