diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2018-05-04 17:57:16 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2018-05-04 17:57:16 +0900 |
commit | 07659ccd9fe7b1cf1547cc6cad78bcf489f0a361 (patch) | |
tree | cf3a123812b7f1ad8b50d7d0ace891e0c03c6110 /runtimes/nn/depend/external/gemmlowp/internal/output_neon.h | |
parent | da6f7a3e8360a49fd073a6e0031a4da134d9d984 (diff) | |
download | nnfw-07659ccd9fe7b1cf1547cc6cad78bcf489f0a361.tar.gz nnfw-07659ccd9fe7b1cf1547cc6cad78bcf489f0a361.tar.bz2 nnfw-07659ccd9fe7b1cf1547cc6cad78bcf489f0a361.zip |
Imported Upstream version 0.1 (tags: upstream/0.1, submit/tizen/20180504.091146)
Diffstat (limited to 'runtimes/nn/depend/external/gemmlowp/internal/output_neon.h')
-rw-r--r-- | runtimes/nn/depend/external/gemmlowp/internal/output_neon.h | 432 |
1 files changed, 432 insertions, 0 deletions
// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_neon.h: optimized NEON specializations of the templates in output.h.

#ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
#define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_

#include "output.h"

#include <arm_neon.h>

namespace gemmlowp {

// Saturating cast of 4 int32 values to uint8: int32 -> int16 via the
// saturating narrow vqmovn_s32, then int16 -> uint8 via the saturating
// unsigned narrow vqmovun_s16. The 4 result bytes are packed into a single
// 32-bit scalar register (lane 0 of the reinterpreted vector).
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x4_t res_16 = vqmovn_s32(input.reg[0]);
    // res_16 is duplicated into both halves; only the low 4 bytes are kept.
    uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
    output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
    return output;
  }
};

// Saturating cast of 8 int32 values to uint8 (one uint8x8_t output vector).
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[0] = vqmovun_s16(res_16);
    return output;
  }
};

// Saturating cast of 16 int32 values to uint8 (two uint8x8_t output vectors).
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16_0 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    int16x8_t res_16_1 =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    output.reg[0] = vqmovun_s16(res_16_0);
    output.reg[1] = vqmovun_s16(res_16_1);
    return output;
  }
};

// Saturating cast of 32 int32 values to uint8 (four uint8x8_t output vectors).
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16[4];
    for (int i = 0; i < 4; i++) {
      res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
                               vqmovn_s32(input.reg[2 * i + 1]));
    }
    for (int i = 0; i < 4; i++) {
      output.reg[i] = vqmovun_s16(res_16[i]);
    }
    return output;
  }
};

// Stores an 8x1 int32 block: two contiguous vst1q stores when the destination
// is column-major; otherwise a per-lane scatter via GetLane.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    }
  }
};

// In-register 4x4 int32 transpose: vtrnq_s32 interleaves lane pairs of two
// rows; vget_low/vget_high + vcombine then stitch the halves into the
// transposed rows.
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
  const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
  RegBlockInt32<4, 4> result;
  result.buf.reg[0] =
      vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
  result.buf.reg[1] =
      vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
  result.buf.reg[2] =
      vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
  result.buf.reg[3] =
      vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
  return result;
}

// Stores a 4x4 int32 block. The block is register-resident either way: for a
// row-major destination it is transposed first, so the four vst1q stores are
// always contiguous along the destination's fastest-varying dimension.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::int32_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 4; i++) {
      vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

// Stores an 8x4 int32 block. Column-major: two vst1q per column. Row-major:
// the block is split into a 4x4 top half (even registers) and 4x4 bottom half
// (odd registers), each transposed and stored row by row.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
      }
    }
  }
};

// Stores an 8x8 int32 block. Row-major destinations are handled as four 4x4
// quadrants, each transposed independently before storing. Register layout:
// even regs 0..7 hold the left 8x4 half, regs 8..15 the right half.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 8; i++) {
        vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride,
                  transpose_bottom_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
                  transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};

// Stores a 4x1 int32 block: one contiguous store, or a per-lane scatter.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int row_stride = dst->rows_stride();
      vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
    }
  }
};

// Stores a 1x4 int32 block: mirror of the 4x1 case (contiguous when the
// destination is row-major).
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::RowMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int col_stride = dst->cols_stride();
      vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
    }
  }
};

// Stores a 4x1 uint8 block. The 4 bytes live in one 32-bit scalar register
// (see the RegBufferInt32<4> cast above); unpack them with shifts.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
                  int col) {
    const std::uint32_t src_reg = src.buf.reg[0];
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src_reg >> (8 * i));
    }
  }
};

// Stores a 1x4 uint8 block, same byte unpacking as the 4x1 case.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    }
  }
};

// Stores an 8x1 uint8 block: contiguous vst1_u8, or per-lane scatter.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1_u8(dst_ptr, src.buf.reg[0]);
    } else {
      const int row_stride = dst->rows_stride();
      vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
      vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
      vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
      vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
      vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
    }
  }
};

// Stores a 4x4 uint8 block. Each uint8x8_t register holds two 4-byte columns
// (lanes 0-3 -> column 2*i, lanes 4-7 -> column 2*i+1); stores are per-lane
// in both orders since strides are applied on both axes.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    const int row_stride = dst->rows_stride();
    const int col_stride = dst->cols_stride();
    for (int i = 0; i < 2; i++) {
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 3);
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 4);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 5);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 6);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 7);
    }
  }
};

// Stores an 8x4 uint8 block: one vst1_u8 per column when column-major,
// otherwise a per-lane scatter of each column register.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
      }
    } else {
      for (int i = 0; i < 4; i++) {
        // NOTE(review): rows_stride() is loop-invariant and could be hoisted
        // out of the loop; kept in place to match upstream.
        int row_stride = dst->rows_stride();
        std::uint8_t* col_ptr = dst_ptr + i;
        vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
        vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
        vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
        vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
        vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
        vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
        vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
        vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
      }
    }
  }
};

// In-register 8x8 uint8 transpose: three rounds of vtrn at element widths
// 8, 16 and 32 bits (with reinterprets between rounds), then reassembly.
inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
  uint8x8x2_t a[4];
  a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
  a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
  a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
  a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
  uint16x4x2_t b[4];
  b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
                  vreinterpret_u16_u8(a[1].val[0]));
  b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
                  vreinterpret_u16_u8(a[1].val[1]));
  b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
                  vreinterpret_u16_u8(a[3].val[0]));
  b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
                  vreinterpret_u16_u8(a[3].val[1]));
  uint32x2x2_t c[4];
  c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
                  vreinterpret_u32_u16(b[2].val[0]));
  c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
                  vreinterpret_u32_u16(b[3].val[0]));
  c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
                  vreinterpret_u32_u16(b[2].val[1]));
  c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
                  vreinterpret_u32_u16(b[3].val[1]));
  RegBlockUint8<8, 8> result;
  result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
  result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
  result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
  result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
  result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
  result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
  result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
  result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
  return result;
}

// Stores an 8x8 uint8 block: transpose first if the destination is
// row-major, then eight contiguous vst1_u8 stores along the stride.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::uint8_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 8; i++) {
      vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_