Diffstat (limited to 'compute/ncnn/src/srcn')
25 files changed, 0 insertions, 19736 deletions
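The removed backend implements three convolution paths (im2col + blocked SGEMM in the conv_sgemm_* files, a sparse kernel in conv_sparse, and Winograd in conv_winograd*), with shared fixed-point requantization helpers in common.h, whose diff appears first below. As a quick reference for that arithmetic, the following standalone sketch reproduces the scalar rounding-doubling high multiply and rounding right shift; the function names and example values are illustrative only and are not part of the removed API.

#include <cstdint>
#include <cstdio>
#include <limits>

// (a * b * 2) >> 31 with rounding; saturates only for a == b == INT32_MIN
// (cf. SaturatingRoundingDoublingHighMul in common.h).
static int32_t rounding_doubling_high_mul(int32_t a, int32_t b)
{
  bool overflow = (a == b) && (a == std::numeric_limits<int32_t>::min());
  int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
  int32_t high = static_cast<int32_t>((ab + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<int32_t>::max() : high;
}

// x / 2^exponent, rounding half away from zero (cf. RoundingDivideByPOT in common.h).
static int32_t rounding_divide_by_pot(int32_t x, int exponent)
{
  int32_t mask = static_cast<int32_t>((1ll << exponent) - 1);
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

int main()
{
  // Requantize an accumulator by the real multiplier 0.5 (Q0.31 value 1 << 30), shift 0.
  int32_t acc = 1000;
  int32_t scaled = rounding_divide_by_pot(rounding_doubling_high_mul(acc, 1 << 30), 0);
  printf("%d\n", scaled); // prints 500
  return 0;
}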
diff --git a/compute/ncnn/src/srcn/common.h b/compute/ncnn/src/srcn/common.h
deleted file mode 100644
index 778a17a80..000000000
--- a/compute/ncnn/src/srcn/common.h
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __NNFW_SRCN_COMMON_H__
-#define __NNFW_SRCN_COMMON_H__
-
-#include <string.h>
-#include <limits>
-#include <arm_neon.h>
-
-#include "ncnn/srcn/conv_type.h"
-
-namespace nnfw
-{
-namespace srcn
-{
-
-#define sizeof_RhsScalar 4
-#define sizeof_LhsScalar 4
-#define sizeof_ResScalar 4
-
-#define MIN(a, b) (a) > (b) ? (b) : (a)
-#define MAX(a, b) (a) > (b) ? (a) : (b)
-
-enum shardType_t
-{
-  shardByCol = 0,
-  shardByRow
-};
-
-#ifdef TIZEN
-#define L1_CACHE_SIZE (16536 * 2)
-#define L2_CACHE_SIZE (524288 * 2)
-#define L3_CACHE_SIZE (0) // no L3
-#define MAX_K (512)
-// single-thread
-#define GEN_COL (1440)
-// multi-threads
-#define MAX_COL (90)
-#define MIN_COL (32)
-#elif defined ANDROID
-#define L1_CACHE_SIZE (16536 * 4)
-#define L2_CACHE_SIZE (524288 * 8)
-#define L3_CACHE_SIZE (0) //(524288 * 8) //no L3
-#define MAX_K (512 * 2)
-// single-thread
-#define GEN_COL (1440)
-// multi-threads
-#if __aarch64__
-#define MAX_COL (1024)
-#else
-#define MAX_COL (90)
-#endif
-#define MIN_COL (32)
-#endif
-
-enum
-{
-  USE_COMMON_KENEL = 0,
-  USE_12BIT_KERNEL,
-  USE_NONZERO_KERENL
-};
-
-template <typename T> static T divup(const T &x, const T &y)
-{
-  return static_cast<T>((x + y - 1) / y);
-}
-
-#ifdef NCNN
-static inline size_t alignSize(size_t sz, int n) { return (sz + n - 1) / n * n; }
-
-static inline size_t alignBy2(size_t sz) { return (sz + 1) & -2; }
-#endif // NCNN
-
-static inline int32_t BitNot(int32_t a) { return ~a; }
-
-static inline int32_t MaskIfNonZero(int32_t a)
-{
-  static int32_t zero = 0;
-  return a ? BitNot(zero) : zero;
-}
-
-static inline int32_t BitAnd(int32_t a, int32_t b) { return a & b; }
-
-static inline int32_t ShiftRight(int32_t a, int offset) { return a >> offset; }
-
-static inline int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero(a < b); }
-
-static inline int32_t MaskIfGreaterThan(int32_t a, int32_t b) { return MaskIfNonZero(a > b); }
-
-static inline int32_t Add(int32_t a, int32_t b) { return a + b; }
-
-static inline int32_t RoundingDivideByPOT(int32_t x, int exponent)
-{
-  const int32_t mask = (1ll << exponent) - 1;
-  const int32_t zero = 0;
-  const int32_t one = 1;
-  const int32_t remainder = BitAnd(x, mask);
-  const int32_t threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
-  return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one));
-}
-static inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b)
-{
-  bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
-  int64_t a_64(a);
-  int64_t b_64(b);
-  int64_t ab_64 = a_64 * b_64;
-  int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
-  int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31));
-  return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32;
-}
-
-static inline int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier,
-                                                    int shift)
-{
-  int left_shift = shift > 0 ? shift : 0;
-  int right_shift = shift > 0 ? 0 : -shift;
-  return RoundingDivideByPOT(
-      SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift);
-}
-
-static inline int32x4_t SaturatingRoundingDoublingHighMulV(int32x4_t a, int32x4_t b)
-{
-  return vqrdmulhq_s32(a, b);
-}
-
-static inline int32x4_t RoundingDivideByPOTV(int32x4_t x, int exponent)
-{
-  const int32x4_t shift_vec = vdupq_n_s32(-exponent);
-  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
-  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
-  return vrshlq_s32(fixed_up_x, shift_vec);
-}
-
-static inline int32x4_t MultiplyByQuantizedMultiplierV(int32x4_t x, int32_t quantized_multiplier,
-                                                       int shift)
-{
-  int left_shift = shift > 0 ? shift : 0;
-  int right_shift = shift > 0 ? 0 : -shift;
-  return RoundingDivideByPOTV(
-      SaturatingRoundingDoublingHighMulV(vrshlq_s32(x, vdupq_n_s32(left_shift)),
-                                         vdupq_n_s32(quantized_multiplier)),
-      right_shift);
-}
-
-} // namespace srcn
-} // namespace nnfw
-
-#endif // __NNFW_SRCN_COMMON_H__
diff --git a/compute/ncnn/src/srcn/conv_sgemm_multithreads.cc b/compute/ncnn/src/srcn/conv_sgemm_multithreads.cc
deleted file mode 100644
index 21083f677..000000000
--- a/compute/ncnn/src/srcn/conv_sgemm_multithreads.cc
+++ /dev/null
@@ -1,483 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifdef _OPENMP -#include <omp.h> -#endif - -#include "ncnn/srcn/conv_type.h" -#include "common.h" -#include "sgemm_kernel.h" -#include "sgemm_pack.h" -#include "conv_sgemm_multithreads.h" - -namespace nnfw -{ -namespace srcn -{ - -void conv_sgemm_multithreads::param_init() -{ -#if __aarch64__ - if (conv_type_ == row_major) - { - mr_ = 8; - nr_ = 12; - } - else if (conv_type_ == col_major) - { -#ifdef BATCH_DILATION_FIX - if (out_mat_.n > 1) - { - - mr_ = 24; - nr_ = 4; - } - else -#endif // BATCH_DILATION_FIX - { - if (m_ > n_) - { - mr_ = 24; - nr_ = 4; - } - else - { - mr_ = 12; - nr_ = 8; - } - } - } -#else // __aarch64__ - if (conv_type_ == row_major) - { - mr_ = 6; - nr_ = 8; - } - else if (conv_type_ == col_major) - { - mr_ = 8; - nr_ = 6; - } -#endif // __aarch64__ - int col = n_; - - if (m_ > n_) - { - shard_type_ = shardByRow; - col = m_; - } - else - { - shard_type_ = shardByCol; - } - - int th_base = divup(col, num_threads_); - - th_base = MIN(MAX(th_base, MIN_COL), MAX_COL); - - int k_div = (nr_ * sizeof_RhsScalar); - int k_sub = (mr_ * nr_ * sizeof_ResScalar); - - const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div * 2), MAX_K); - bk_ = MIN(k_cache, k_); - - if (shard_type_ == shardByCol) - { - int m_sub = (bk_ * nr_ * sizeof_RhsScalar); - int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - if (L3_CACHE_SIZE) - m_div = (sizeof_LhsScalar * bk_ * 2); - int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); - bm_ = MIN(m_cache, m_); - - bn_ = MIN(th_base, n_); - if (L3_CACHE_SIZE) - { - int n_sub = (bk_ * bm_ * sizeof_RhsScalar); - int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - int n_cache = divup((L3_CACHE_SIZE - n_sub), n_div); - bn_ = MIN(n_cache, bn_); - } - } - else - { - int n_sub = (bk_ * mr_ * sizeof_LhsScalar); - int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - if (L3_CACHE_SIZE) - n_div = (sizeof_LhsScalar * bk_ * 2); - int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); - bn_ = MIN(n_cache, n_); - - bm_ = MIN(th_base, m_); - if (L3_CACHE_SIZE) - { - int m_sub = (bk_ * bn_ * sizeof_RhsScalar); - int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - int m_cache = divup((L3_CACHE_SIZE - m_sub), m_div); - bm_ = MIN(m_cache, bm_); - } - } - - nm_ = divup(m_, bm_); - nn_ = divup(n_, bn_); - nk_ = divup(k_, bk_); - - rm_ = m_ % bm_; - rn_ = n_ % bn_; - rk_ = k_ % bk_; -} - -conv_sgemm_multithreads::conv_sgemm_multithreads(const convMat_t &in_mat, - const convMat_t &weights_mat, convMat_t &out_mat, - const convParams_t &in_param, int num_threads, - convType_t conv_type) - - : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), - conv_type_(conv_type), num_threads_(num_threads) -{ - m_ = out_mat_.c; -#ifdef NCNN -#ifdef WITH_DPU - np_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); - n_ = (np_ + 1) / 2; -#else // WITH_DPU - n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); -#endif // WITH_DPU -#else // NCNN -#ifdef WITH_DPU - np_ = out_mat_.n * out_mat_.w * out_mat_.h; - n_ = (np_ + 1) / 2; -#else // WITH_DPU - n_ = out_mat_.n * out_mat_.w * out_mat_.h; -#endif // WITH_DPU -#endif // NCNN - k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c; - - param_init(); - - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - if (shard_type_ == shardByCol) - { - plhs_buffer_ = new float[lhs_stride * 1 * nm_]; - prhs_buffer_ = new float[rhs_stride * num_threads_]; - } - else - { - 
plhs_buffer_ = new float[lhs_stride * num_threads_]; - prhs_buffer_ = new float[rhs_stride * 1 * nn_]; - } - - if (plhs_buffer_ == NULL || prhs_buffer_ == NULL) - { - error_ = 1; - } - - if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || - in_param_.stride_h != 1 || in_param_.padding != 0) - { - need_im2col_ = 1; - } - else - { - need_im2col_ = 0; - } - - omp_set_num_threads(num_threads_); - - error_ = 0; -} - -conv_sgemm_multithreads::~conv_sgemm_multithreads() -{ - if (plhs_buffer_) - delete[] plhs_buffer_; - if (prhs_buffer_) - delete[] prhs_buffer_; -} - -void conv_sgemm_multithreads::run() -{ - if (error_) - return; - - if (shard_type_ == shardByCol && conv_type_ == col_major) - { - compute_colmajor_colshard(); - } - else if (shard_type_ == shardByRow && conv_type_ == col_major) - { - compute_colmajor_rowshard(); - } - else if (shard_type_ == shardByCol && conv_type_ == row_major) - { - compute_rowmajor_colshard(); - } - else if (shard_type_ == shardByRow && conv_type_ == row_major) - { - compute_rowmajor_rowshard(); - } -} - -void conv_sgemm_multithreads::compute_rowmajor_colshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], - &plhs_buffer_[i * lhs_stride]); - } - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - int thread_num = omp_get_thread_num(); - // float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; - float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; - - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - else - { - _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - } - else - { -#ifdef WITH_DPU - _pack_rowmajor_notrans_rhs(nr_, bn, bk, np_, &in_mat_.data[n_ + l * bk_ * np_ + j * bn_], - prhs_ptr); -#else - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], - prhs_ptr); -#endif - } - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - -#ifdef WITH_DPU - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], - prhs_ptr, &out_mat_.data[n_ + i * bm_ * np_ + j * bn_], - l, np_, bk); -#else // WITH_DPU - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], - prhs_ptr, &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, - bk); -#endif // WITH_DPU - } - } - } -} - -void conv_sgemm_multithreads::compute_rowmajor_rowshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), - &prhs_buffer_[j * rhs_stride]); - } - else - { - _pack_rowmajor_image_rhs_batch( - nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]); - } - } - else - { - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], - &prhs_buffer_[j * rhs_stride]); - } - } - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - int thread_num = omp_get_thread_num(); - float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; - - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], - plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, - &prhs_buffer_[j * rhs_stride], - &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } -} - -void conv_sgemm_multithreads::compute_colmajor_colshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], - &plhs_buffer_[i * lhs_stride]); - } - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - int thread_num = omp_get_thread_num(); - float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; - - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - else - { - _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - } - else - { - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], - prhs_ptr); - } - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], - prhs_ptr, &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, - bk); - } - } - } -} - -void conv_sgemm_multithreads::compute_colmajor_rowshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), - &prhs_buffer_[j * rhs_stride]); - } - else - { - _pack_colmajor_image_rhs_batch( - nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]); - } - } - else - { - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], - &prhs_buffer_[j * rhs_stride]); - } - } - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - int thread_num = omp_get_thread_num(); - float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; - - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], - plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, - &prhs_buffer_[j * rhs_stride], - &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_sgemm_multithreads.h b/compute/ncnn/src/srcn/conv_sgemm_multithreads.h deleted file mode 100644 index 9c9ce7437..000000000 --- a/compute/ncnn/src/srcn/conv_sgemm_multithreads.h +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __NNFW_SRCN_CONV_SGEMM_MULTITHREADS_H__ -#define __NNFW_SRCN_CONV_SGEMM_MULTITHREADS_H__ - -#include "ncnn/srcn/conv_type.h" -#include "common.h" - -namespace nnfw -{ -namespace srcn -{ - -class conv_sgemm_multithreads -{ -public: - conv_sgemm_multithreads(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, - const convParams_t &in_param, int num_threads, convType_t conv_type); - ~conv_sgemm_multithreads(); - - void run(); - -private: - void param_init(); - - void compute_rowmajor_colshard(); - void compute_rowmajor_rowshard(); - void compute_colmajor_colshard(); - void compute_colmajor_rowshard(); - - const convMat_t in_mat_; - const convMat_t weights_mat_; - convMat_t out_mat_; - const convParams_t in_param_; - convType_t conv_type_; - int num_threads_; - - int m_; - int n_; -#ifdef WITH_DPU - int np_; -#endif - int k_; - - int bm_; - int bn_; - int bk_; - - int rm_; - int rn_; - int rk_; - - int nm_; - int nn_; - int nk_; - - int mr_; - int nr_; - - int need_im2col_; - shardType_t shard_type_; - - float *prhs_buffer_; - float *plhs_buffer_; - - int error_; -}; - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_CONV_SGEMM_MULTITHREADS_H__ diff --git a/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc b/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc deleted file mode 100644 index 4cbbf217f..000000000 --- a/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <stdexcept> - -#include "common.h" -#include "sgemm_kernel.h" -#include "sgemm_pack.h" -#include "conv_sgemm_singlethread.h" - -namespace nnfw -{ -namespace srcn -{ - -void conv_sgemm_singlethread::param_init() -{ - if (n_ > 3 * m_) - { - shard_type_ = shardByRow; - } - else - { - shard_type_ = shardByCol; - } - -#if __aarch64__ - if (conv_type_ == row_major) - { - if (shard_type_ == shardByRow) - { - mr_ = 8; - nr_ = 12; - } - else - { - mr_ = 12; - nr_ = 8; - } - } - else if (conv_type_ == col_major) - { -#ifndef BATCH_DILATION_FIX - mr_ = 12; - nr_ = 8; -#else // BATCH_DILATION_FIX - // TODO: batch(dilation) + inw * inh - if (out_mat_.n > 1) - { - mr_ = 24; - nr_ = 4; - } - else - { - mr_ = 12; - nr_ = 8; - } -#endif // BATCH_DILATION_FIX - } -#else // __aarch64__ - if (conv_type_ == row_major) - { - mr_ = 6; - nr_ = 8; - } - else if (conv_type_ == col_major) - { - mr_ = 8; - nr_ = 6; - } -#endif // __aarch64__ - - int k_div = (nr_ * sizeof_RhsScalar); - int k_sub = (mr_ * nr_ * sizeof_ResScalar); - - const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div), MAX_K); - bk_ = MIN(k_cache, k_); - - if (shard_type_ == shardByCol) - { - int m_sub = (bk_ * nr_ * sizeof_RhsScalar); - int m_cache = divup((L2_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2)); - bm_ = MIN(m_cache, m_); - - bn_ = MIN(GEN_COL, n_); - if (L3_CACHE_SIZE) - { - int n_sub = (bk_ * bm_ * sizeof_RhsScalar); - int n_cache = divup((L3_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2)); - bn_ = MIN(n_cache, bn_); - } - } - else - { - int n_sub = (bk_ * mr_ * sizeof_RhsScalar); - int n_cache = divup((L2_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2)); - bn_ = MIN(n_cache, n_); - - bm_ = MIN(GEN_COL, m_); - if (L3_CACHE_SIZE) - { - int m_sub = (bk_ * bn_ * sizeof_RhsScalar); - int m_cache = divup((L3_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2)); - bm_ = MIN(m_cache, bm_); - } - } - - nm_ = divup(m_, bm_); - nn_ = divup(n_, bn_); - nk_ = divup(k_, bk_); - - rm_ = m_ % bm_; - rn_ = n_ % bn_; - rk_ = k_ % bk_; -} - -conv_sgemm_singlethread::conv_sgemm_singlethread(const convMat_t &in_mat, - const convMat_t &weights_mat, convMat_t &out_mat, - const convParams_t &in_param, convType_t conv_type) - : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), - conv_type_(conv_type) -{ - m_ = out_mat_.c; -#ifdef NCNN - n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); -#else - n_ = out_mat_.n * out_mat_.w * out_mat_.h; -#endif - k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c; - - param_init(); - - if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || - in_param_.stride_h != 1 || in_param_.padding != 0 || out_mat_.n > 1) - { - need_im2col_ = 1; - } - else - { - need_im2col_ = 0; - } -} - -conv_sgemm_singlethread::~conv_sgemm_singlethread() {} - -void conv_sgemm_singlethread::run() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float *plhs_ptr = new float[mstride * bk_]; - float *prhs_ptr = new float[nstride * bk_]; - - if (conv_type_ == row_major) - { - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? 
bk_ : rk_; - - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - else - { - _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - } - else - { - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], - prhs_ptr); - } - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], - plhs_ptr); - - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], - plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - else - { - _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - } - else - { - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], - prhs_ptr); - } - - _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else - { - throw std::runtime_error{"Error shrad type!"}; - } - } - else if (conv_type_ == col_major) - { - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - else - { - _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - } - else - { - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], - prhs_ptr); - } - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], - plhs_ptr); - - _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? 
bk_ : rk_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], - plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - if (need_im2col_) - { - if (out_mat_.n == 1) - { - _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - else - { - _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, - const_cast<convMat_t *>(&in_mat_), &out_mat_, - const_cast<convParams_t *>(&in_param_), prhs_ptr); - } - } - else - { - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], - prhs_ptr); - } - - _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else - { - throw std::runtime_error{"Error shrad type!"}; - } - } - else - { - throw std::runtime_error{"Error conv type!"}; - } - - delete[] plhs_ptr; - delete[] prhs_ptr; -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_sgemm_singlethread.h b/compute/ncnn/src/srcn/conv_sgemm_singlethread.h deleted file mode 100644 index 63f8b6e66..000000000 --- a/compute/ncnn/src/srcn/conv_sgemm_singlethread.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_SRCN_CONV_SGEMM_SINGLETHREAD_H__ -#define __NNFW_SRCN_CONV_SGEMM_SINGLETHREAD_H__ - -#include "ncnn/srcn/conv_type.h" -#include "common.h" - -namespace nnfw -{ -namespace srcn -{ - -class conv_sgemm_singlethread -{ -public: - conv_sgemm_singlethread(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, - const convParams_t &in_param, convType_t conv_type); - ~conv_sgemm_singlethread(); - - void run(); - -private: - void param_init(); - - const convMat_t in_mat_; - const convMat_t weights_mat_; - convMat_t out_mat_; - const convParams_t in_param_; - convType_t conv_type_; - - int m_; - int n_; - int k_; - - int bm_; - int bn_; - int bk_; - - int rm_; - int rn_; - int rk_; - - int nm_; - int nn_; - int nk_; - - int mr_; - int nr_; - - int need_im2col_; - - shardType_t shard_type_; -}; - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_CONV_SGEMM_SINGLETHREAD_H__ diff --git a/compute/ncnn/src/srcn/conv_sparse.cc b/compute/ncnn/src/srcn/conv_sparse.cc deleted file mode 100644 index 10e2a2b93..000000000 --- a/compute/ncnn/src/srcn/conv_sparse.cc +++ /dev/null @@ -1,271 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef _OPENMP -#include <omp.h> -#endif - -#include <stdexcept> - -#include "common.h" -#include "sgemm_kernel.h" -#include "sgemm_pack.h" -#include "conv_sparse.h" - -namespace nnfw -{ -namespace srcn -{ - -void conv_sparse::param_init() -{ -#ifdef NCNN - n_ = alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); -#else - n_ = out_mat_.w * out_mat_.h; -#endif - - bch_ = BCH; - nch_ = (out_mat_.c + bch_ - 1) / bch_; - - rch_ = out_mat_.c % bch_; - - bn_ = MIN(n_, L1_CACHE_SIZE / (sizeof(float) * 2)); - bn_ = MIN(bn_, (L2_CACHE_SIZE / 2 - bch_ * sizeof(weight_data_t)) / ((bch_ + 1) * sizeof(float)) / - num_threads_); - nn_ = (n_ + bn_ - 1) / bn_; - rn_ = n_ % bn_; - - if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || - in_param_.stride_h != 1 || in_param_.padding != 0) - { - need_im2col_ = 1; - } - else - { - need_im2col_ = 0; - } -} - -conv_sparse::conv_sparse(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, - const sparse_weight_t *weights, int num_threads, convType_t conv_type) - : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), weights_(weights), - num_threads_(num_threads), conv_type_(conv_type) -{ - param_init(); -} - -conv_sparse::~conv_sparse() {} - -void conv_sparse::compute_singlethread() -{ - if (need_im2col_) - { - for (int i = 0; i < nch_; i++) - { - const sparse_weight_t *weight_ptr = weights_ + i; - const int mxk = weight_ptr->mxk; - float prhs_ptr[bn_]; - - for (int j = 0; j < nn_; j++) - { - int k = -1; - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - weight_data_t *lhs_ptr = weight_ptr->wdata; - - for (int l = 0; l < mxk; l++) - { - if (k != lhs_ptr->k) - { - k = lhs_ptr->k; - _sparse_pack_rowmajor_image(bn, k, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), - prhs_ptr); - } - - // Why n_ = 64 x 64 is too much slower on Tizen??? - _sparse_sgemm_kernel(bn, lhs_ptr->data, prhs_ptr, - &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); - - lhs_ptr++; - } - } - } - } - else - { - for (int i = 0; i < nch_; i++) - { - const sparse_weight_t *weight_ptr = weights_ + i; - const int mxk = weight_ptr->mxk; - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - weight_data_t *lhs_ptr = weight_ptr->wdata; - float *rhs_ptr = in_mat_.data + j * bn_; - - for (int l = 0; l < mxk; l++) - { - // Why n_ = 64 x 64 is too much slower on Tizen??? - _sparse_sgemm_kernel(bn, lhs_ptr->data, rhs_ptr + lhs_ptr->k * n_, - &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); - - lhs_ptr++; - } - } - } - } -} - -void conv_sparse::compute_multithreads() -{ - omp_set_num_threads(num_threads_); - - if (nch_ >= num_threads_ || nch_ >= nn_) - { - if (need_im2col_) - { -#pragma omp parallel for - for (int i = 0; i < nch_; i++) - { - const sparse_weight_t *weight_ptr = weights_ + i; - const int mxk = weight_ptr->mxk; - float prhs_ptr[bn_]; - - for (int j = 0; j < nn_; j++) - { - int k = -1; - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - weight_data_t *lhs_ptr = weight_ptr->wdata; - - for (int l = 0; l < mxk; l++) - { - if (k != lhs_ptr->k) - { - k = lhs_ptr->k; - _sparse_pack_rowmajor_image(bn, k, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), - prhs_ptr); - } - - _sparse_sgemm_kernel(bn, lhs_ptr->data, prhs_ptr, - &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); - - lhs_ptr++; - } - } - } - } - else - { -#pragma omp parallel for - for (int i = 0; i < nch_; i++) - { - const sparse_weight_t *weight_ptr = weights_ + i; - const int mxk = weight_ptr->mxk; - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - weight_data_t *lhs_ptr = weight_ptr->wdata; - float *rhs_ptr = in_mat_.data + j * bn_; - - for (int l = 0; l < mxk; l++) - { - _sparse_sgemm_kernel(bn, lhs_ptr->data, rhs_ptr + lhs_ptr->k * n_, - &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); - - lhs_ptr++; - } - } - } - } - } - else - { - if (need_im2col_) - { - for (int i = 0; i < nch_; i++) - { - const sparse_weight_t *weight_ptr = weights_ + i; - const int mxk = weight_ptr->mxk; - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - int k = -1; - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - weight_data_t *lhs_ptr = weight_ptr->wdata; - float prhs_ptr[bn]; - - for (int l = 0; l < mxk; l++) - { - if (k != lhs_ptr->k) - { - k = lhs_ptr->k; - _sparse_pack_rowmajor_image(bn, k, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), - prhs_ptr); - } - - _sparse_sgemm_kernel(bn, lhs_ptr->data, prhs_ptr, - &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); - - lhs_ptr++; - } - } - } - } - else - { - for (int i = 0; i < nch_; i++) - { - const sparse_weight_t *weight_ptr = weights_ + i; - const int mxk = weight_ptr->mxk; - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - weight_data_t *lhs_ptr = weight_ptr->wdata; - float *rhs_ptr = in_mat_.data + j * bn_; - - for (int l = 0; l < mxk; l++) - { - _sparse_sgemm_kernel(bn, lhs_ptr->data, rhs_ptr + lhs_ptr->k * n_, - &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); - - lhs_ptr++; - } - } - } - } - } -} - -void conv_sparse::run() -{ - if (num_threads_ == 1) - compute_singlethread(); - else if (num_threads_ > 1) - compute_multithreads(); - else - throw std::runtime_error{"Invalid thread number."}; -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_sparse.h b/compute/ncnn/src/srcn/conv_sparse.h deleted file mode 100644 index 7ac358fd8..000000000 --- a/compute/ncnn/src/srcn/conv_sparse.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __NNFW_SRCN_CONV_SPARSE_H__ -#define __NNFW_SRCN_CONV_SPARSE_H__ - -#include "ncnn/srcn/conv_type.h" -#include "common.h" - -namespace nnfw -{ -namespace srcn -{ - -#define BCH 128 - -typedef struct -{ - short m; - short k; - float data; -} weight_data_t; - -typedef struct -{ - int mxk; - weight_data_t *wdata; -} sparse_weight_t; - -class conv_sparse -{ -public: - conv_sparse(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, - const sparse_weight_t *weights, int num_threads, convType_t conv_type); - ~conv_sparse(); - - void run(); - -private: - void param_init(); - void compute_singlethread(); - void compute_multithreads(); - - const convMat_t in_mat_; - convMat_t out_mat_; - const convParams_t in_param_; - const sparse_weight_t *weights_; - int num_threads_; - convType_t conv_type_; - - uint32_t n_; - uint32_t bn_; - int rn_; - int nn_; - - int bch_; - int rch_; - int nch_; - - int need_im2col_; -}; - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_CONV_SPARSE_H__ diff --git a/compute/ncnn/src/srcn/conv_winograd.cc b/compute/ncnn/src/srcn/conv_winograd.cc deleted file mode 100644 index 69649ea2a..000000000 --- a/compute/ncnn/src/srcn/conv_winograd.cc +++ /dev/null @@ -1,341 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "common.h" -#include "conv_winograd.h" - -namespace std -{ -template <typename Dtype> static inline Dtype max(Dtype a, Dtype b) -{ - if (a > b) - return a; - else - return b; -} -} - -namespace nnfw -{ -namespace srcn -{ - -void conv_winograd::param_init() -{ - if ((in_param_.kernel_w != in_param_.kernel_h) || (in_param_.stride_w != in_param_.stride_h) || - (in_param_.kernel_w != 3 && in_param_.kernel_w != 5) || (in_param_.stride_w != 1) || - (!winograd_weight_)) - { - error_ = 1; - return; - } - - int M, N; - const int w = in_mat_.w; - const int h = in_mat_.h; - const int outw = out_mat_.w; - const int outh = out_mat_.h; - const int pad_w = in_param_.pad_w; - const int pad_h = in_param_.pad_h; - - if (in_param_.kernel_w == 3) - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - } - - tile_h_in_ = tile_w_in_ = M; - tile_h_out_ = tile_h_in_ - N + 1; - tile_w_out_ = tile_w_in_ - N + 1; - ntiles_h_ = (std::max(h + pad_h - tile_h_in_ + 1, outh) + tile_h_out_ - 1) / tile_h_out_; - ntiles_w_ = (std::max(w + pad_w - tile_w_in_ + 1, outw) + tile_w_out_ - 1) / tile_w_out_; - - error_ = 0; -} - -conv_winograd::conv_winograd(const convMat_t &in_mat, convMat_t &out_mat, - const convParams_t &in_param, convType_t conv_type, - const float *winograd_weight, int num_threads, int inc_stride, - int outc_stride, int c_stride) - : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), conv_type_(conv_type), - winograd_weight_(winograd_weight), num_threads_(num_threads), inc_stride_(inc_stride), - outc_stride_(outc_stride), c_stride_(c_stride) - -{ - param_init(); -} - -conv_winograd::~conv_winograd() {} - -void conv_winograd::compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, - const int m, const int n, const int k, const float *lhs_data, - const float *rhs_data, float *res_data) -{ - class sgemm_singlethread sgemm(major_type, ltrans, rtrans, m, n, k, lhs_data, rhs_data, res_data, - num_threads_); - - sgemm.run(); -} - -void conv_winograd::winograd_input_im2col(float *col_buff) -{ - const int w = in_mat_.w; - const int h = in_mat_.h; - const float *data = in_mat_.data; - const int channels = in_mat_.c; - const int pad_w = in_param_.pad_w; - const int pad_h = in_param_.pad_h; - - if (conv_type_ == row_major) - { -#ifdef NCNN - const int n = alignSize(inc_stride_, 16 / sizeof(float)); -#else // NCNN - const int n = inc_stride_; -#endif // NCNN - for (int c = 0; c < channels; ++c) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_in_; ++y) - { - for (int x = 0; x < tile_w_in_; ++x) - { - int in_y = tile_h * tile_h_out_ + y - pad_h; - int in_x = tile_w * tile_w_out_ + x - pad_w; - - if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w) - { - col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * - tile_w_in_ + - x] = 0; - } - else - { - col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * - tile_w_in_ + - x] = data[c * n + in_y * w + in_x]; - } - } - } - } - } - } - } - else if (conv_type_ == col_major) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_in_; ++y) - { - for (int x = 0; x < tile_w_in_; ++x) - { - for (int c = 0; c < channels; ++c) - { - int in_y = tile_h * tile_h_out_ + y - pad_h; - int in_x = tile_w * tile_w_out_ + x - pad_w; - - 
if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w) - { - col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * - tile_w_in_ + - x] = 0; - } - else - { - col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * - tile_w_in_ + - x] = data[c + (in_y * w + in_x) * channels]; - } - } - } - } - } - } - } -} - -void conv_winograd::winograd_output_col2im(const float *col_buff) -{ - int outh = out_mat_.h; - int outw = out_mat_.w; - float *data = out_mat_.data; - int channels = out_mat_.c; - - if (conv_type_ == row_major) - { -#ifdef NCNN - const int n = alignSize(outc_stride_, 16 / sizeof(float)); -#else // NCNN - const int n = outc_stride_; -#endif // NCNN - for (int c = 0; c < channels; ++c) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_out_; ++y) - { - for (int x = 0; x < tile_w_out_; ++x) - { - int out_y = tile_h * tile_h_out_ + y; - int out_x = tile_w * tile_w_out_ + x; - if (out_y < outh && out_x < outw) - { - data[c * n + out_y * outw + out_x] = - col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_out_ + y) * - tile_w_out_ + - x]; - } - } - } - } - } - } - } - else if (conv_type_ == col_major) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_out_; ++y) - { - for (int x = 0; x < tile_w_out_; ++x) - { - for (int c = 0; c < channels; ++c) - { - int out_y = tile_h * tile_h_out_ + y; - int out_x = tile_w * tile_w_out_ + x; - if (out_y < outh && out_x < outw) - { - data[c + (out_y * outw + out_x) * c_stride_] = - col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_out_ + y) * - tile_w_out_ + - x]; - } - } - } - } - } - } - } -} - -void conv_winograd::compute_winograd() -{ - // const int w = in_mat_.w; - // const int h = in_mat_.h; - const int inch = in_mat_.c; - // const int outw = out_mat_.w; - // const int outh = out_mat_.h; - const int outch = out_mat_.c; - const int kernel_size = in_param_.kernel_w; - - int M, N; - const double *A; - const double *B; - - if (kernel_size == 3) - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - B = winograd_para_3x3s1::getB(); - A = winograd_para_3x3s1::getA(); - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - B = winograd_para_5x5s1::getB(); - A = winograd_para_5x5s1::getA(); - } - - /*Step 2: transfer image to winograd domain*/ - float *col_buff = - new float[std::max(outch, inch) * ntiles_h_ * ntiles_w_ * tile_h_in_ * tile_w_in_]; - - int temp1_n = inch * ntiles_h_ * ntiles_w_; - float *temp1_ = - new float[tile_h_in_ * tile_w_in_ * std::max(outch, inch) * ntiles_h_ * ntiles_w_]; - - float *winograd_b = new float[M * M * M * M]; - - if ((NULL == col_buff) || (NULL == temp1_) || (NULL == winograd_b)) - { - delete[] col_buff; - delete[] temp1_; - delete[] winograd_b; - return; - } - - winograd_input_im2col(col_buff); - - kronecker_product(winograd_b, B, B, M, M, M, M); - - compute_sgemm(rowMajor, trans, trans, tile_h_in_ * tile_w_in_, temp1_n, tile_h_in_ * tile_w_in_, - winograd_b, col_buff, temp1_); - - delete[] winograd_b; - - /*Step 3: convolution in winograd domain*/ - for (int j = 0; j < tile_h_in_ * tile_w_in_; ++j) - { - compute_sgemm(rowMajor, notrans, notrans, outch, ntiles_h_ * ntiles_w_, inch, - winograd_weight_ + j * c_stride_ * inch, - temp1_ + j * inch * ntiles_h_ * ntiles_w_, - col_buff + j * outch * ntiles_h_ * ntiles_w_); - } - - 
/*Step 4: transfer back to time domain*/ - float *winograd_a = new float[M * (M - N + 1) * M * (M - N + 1)]; - if (NULL == winograd_a) - { - delete[] col_buff; - delete[] temp1_; - return; - } - kronecker_product(winograd_a, A, A, M, M - N + 1, M, M - N + 1); - compute_sgemm(rowMajor, trans, notrans, outch * ntiles_h_ * ntiles_w_, tile_h_out_ * tile_w_out_, - tile_h_in_ * tile_w_in_, col_buff, winograd_a, temp1_); - delete[] winograd_a; - delete[] col_buff; - - winograd_output_col2im(temp1_); - - delete[] temp1_; -} - -void conv_winograd::run() -{ - if (error_) - return; - - compute_winograd(); -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_winograd.h b/compute/ncnn/src/srcn/conv_winograd.h deleted file mode 100644 index 76c2601f2..000000000 --- a/compute/ncnn/src/srcn/conv_winograd.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_SRCN_CONV_WINOGRAD_H__ -#define __NNFW_SRCN_CONV_WINOGRAD_H__ - -#include "ncnn/srcn/conv_type.h" -#include "winograd.h" -#include "sgemm_singlethread.h" - -namespace nnfw -{ -namespace srcn -{ - -class conv_winograd -{ -public: - conv_winograd(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, - convType_t conv_type, const float *winograd_weight, int num_threads, int inc_stride, - int outc_stride, int c_stride); - ~conv_winograd(); - - void run(); - -private: - void param_init(); - void compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, const int m, - const int n, const int k, const float *lhs_data, const float *rhs_data, - float *res_data); - void winograd_input_im2col(float *col_buff); - void winograd_output_col2im(const float *col_buff); - void compute_winograd(); - - const convMat_t in_mat_; - convMat_t out_mat_; - const convParams_t in_param_; - convType_t conv_type_; - const float *winograd_weight_; - const int num_threads_; - - int tile_w_in_; - int tile_h_in_; - int tile_w_out_; - int tile_h_out_; - int ntiles_w_; - int ntiles_h_; - - int inc_stride_; - int outc_stride_; - int c_stride_; - - int error_; -}; - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_CONV_WINOGRAD_H__ diff --git a/compute/ncnn/src/srcn/conv_winograd_batch.cc b/compute/ncnn/src/srcn/conv_winograd_batch.cc deleted file mode 100644 index cba45c648..000000000 --- a/compute/ncnn/src/srcn/conv_winograd_batch.cc +++ /dev/null @@ -1,304 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common.h" -#include "conv_winograd_batch.h" - -namespace std -{ -template <typename Dtype> static inline Dtype max(Dtype a, Dtype b) -{ - if (a > b) - return a; - else - return b; -} -} - -namespace nnfw -{ -namespace srcn -{ - -void conv_winograd_batch::param_init() -{ - if ((in_param_.kernel_w != in_param_.kernel_h) || (in_param_.stride_w != in_param_.stride_h) || - (in_param_.kernel_w != 3 && in_param_.kernel_w != 5) || (in_param_.stride_w != 1) || - (!winograd_weight_)) - { - error_ = 1; - return; - } - - int M, N; - const int w = in_mat_.w; - const int h = in_mat_.h; - const int outw = out_mat_.w; - const int outh = out_mat_.h; - const int pad_w = in_param_.pad_w; - const int pad_h = in_param_.pad_h; - - if (in_param_.kernel_w == 3) - { - if (w == 4) - { - M = winograd_para_3x3s1_2::M; - N = winograd_para_3x3s1_2::N; - } - else - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - } - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - } - - tile_h_in_ = tile_w_in_ = M; - tile_h_out_ = tile_h_in_ - N + 1; - tile_w_out_ = tile_w_in_ - N + 1; - ntiles_h_ = (std::max(h + pad_h - tile_h_in_ + 1, outh) + tile_h_out_ - 1) / tile_h_out_; - ntiles_w_ = (std::max(w + pad_w - tile_w_in_ + 1, outw) + tile_w_out_ - 1) / tile_w_out_; - - error_ = 0; -} - -conv_winograd_batch::conv_winograd_batch(const convMat_t &in_mat, convMat_t &out_mat, - const convParams_t &in_param, convType_t conv_type, - const float *winograd_weight, int num_threads) - : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), conv_type_(conv_type), - winograd_weight_(winograd_weight), num_threads_(num_threads) -{ - param_init(); -} - -conv_winograd_batch::~conv_winograd_batch() {} - -void conv_winograd_batch::compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, - sgemmTrans_t rtrans, const int m, const int n, const int k, - const float *lhs_data, const float *rhs_data, - float *res_data) -{ - class sgemm_singlethread sgemm(major_type, ltrans, rtrans, m, n, k, lhs_data, rhs_data, res_data, - num_threads_); - - sgemm.run(); -} - -void conv_winograd_batch::winograd_input_im2col(float *col_buff) -{ - const int w = in_mat_.w; - const int h = in_mat_.h; - const float *data = in_mat_.data; - const int channels = in_mat_.c; - const int batch = in_mat_.n; - const int pad_w = in_param_.pad_w; - const int pad_h = in_param_.pad_h; - - // TODO: row_major - if (conv_type_ == col_major) - { - for (int n = 0; n < batch; n++) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_in_; ++y) - { - for (int x = 0; x < tile_w_in_; ++x) - { - for (int c = 0; c < channels; ++c) - { - int in_y = tile_h * tile_h_out_ + y - pad_h; - int in_x = tile_w * tile_w_out_ + x - pad_w; - - if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w) - { - col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * - tile_h_in_ + - y) * - tile_w_in_ + - x] = 0; - } - else - { - col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * - tile_h_in_ + - y) * - tile_w_in_ + - 
x] = data[((n * h + in_y) * w + in_x) * channels + c]; - } - } - } - } - } - } - } - } -} - -void conv_winograd_batch::winograd_output_col2im(const float *col_buff) -{ - int outh = out_mat_.h; - int outw = out_mat_.w; - float *data = out_mat_.data; - int channels = out_mat_.c; - int batch = out_mat_.n; - - // TODO: row_major - if (conv_type_ == col_major) - { - for (int n = 0; n < batch; n++) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_out_; ++y) - { - for (int x = 0; x < tile_w_out_; ++x) - { - for (int c = 0; c < channels; ++c) - { - int out_y = tile_h * tile_h_out_ + y; - int out_x = tile_w * tile_w_out_ + x; - if (out_y < outh && out_x < outw) - { - data[((n * outh + out_y) * outw + out_x) * channels + c] = - col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * - tile_h_out_ + - y) * - tile_w_out_ + - x]; - } - } - } - } - } - } - } - } -} - -void conv_winograd_batch::compute_winograd() -{ - const int w = in_mat_.w; - // const int h = in_mat_.h; - const int inch = in_mat_.c; - // const int outw = out_mat_.w; - // const int outh = out_mat_.h; - const int outch = out_mat_.c; - const int kernel_size = in_param_.kernel_w; - const int batch = in_mat_.n; - - int M, N; - const double *A; - const double *B; - - if (kernel_size == 3) - { - if (w == 4) - { - M = winograd_para_3x3s1_2::M; - N = winograd_para_3x3s1_2::N; - B = winograd_para_3x3s1_2::getB(); - A = winograd_para_3x3s1_2::getA(); - } - else - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - B = winograd_para_3x3s1::getB(); - A = winograd_para_3x3s1::getA(); - } - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - B = winograd_para_5x5s1::getB(); - A = winograd_para_5x5s1::getA(); - } - - /*Step 2: transfer image to winograd domain*/ - float *col_buff = - new float[std::max(outch, inch) * batch * ntiles_h_ * ntiles_w_ * tile_h_in_ * tile_w_in_]; - - int temp1_n = batch * inch * ntiles_h_ * ntiles_w_; - float *temp1_ = - new float[batch * tile_h_in_ * tile_w_in_ * std::max(outch, inch) * ntiles_h_ * ntiles_w_]; - - float *winograd_b = new float[M * M * M * M]; - - if ((NULL == col_buff) || (NULL == temp1_) || (NULL == winograd_b)) - { - delete[] col_buff; - delete[] temp1_; - delete[] winograd_b; - return; - } - - winograd_input_im2col(col_buff); - - kronecker_product(winograd_b, B, B, M, M, M, M); - - compute_sgemm(rowMajor, trans, trans, tile_h_in_ * tile_w_in_, temp1_n, tile_h_in_ * tile_w_in_, - winograd_b, col_buff, temp1_); - delete[] winograd_b; - - /*Step 3: convolution in winograd domain*/ - for (int j = 0; j < tile_h_in_ * tile_w_in_; ++j) - { - compute_sgemm(rowMajor, notrans, notrans, outch, batch * ntiles_h_ * ntiles_w_, inch, - winograd_weight_ + j * outch * inch, - temp1_ + j * batch * inch * ntiles_h_ * ntiles_w_, - col_buff + j * batch * outch * ntiles_h_ * ntiles_w_); - } - - /*Step 4: transfer back to time domain*/ - float *winograd_a = new float[M * (M - N + 1) * M * (M - N + 1)]; - if (NULL == winograd_a) - { - delete[] col_buff; - delete[] temp1_; - return; - } - - kronecker_product(winograd_a, A, A, M, M - N + 1, M, M - N + 1); - compute_sgemm(rowMajor, trans, notrans, batch * outch * ntiles_h_ * ntiles_w_, - tile_h_out_ * tile_w_out_, tile_h_in_ * tile_w_in_, col_buff, winograd_a, temp1_); - delete[] winograd_a; - delete[] col_buff; - - winograd_output_col2im(temp1_); - - delete[] temp1_; -} - -void conv_winograd_batch::run() -{ - if (error_) - return; 
- - compute_winograd(); -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_winograd_batch.h b/compute/ncnn/src/srcn/conv_winograd_batch.h deleted file mode 100644 index a022d9c52..000000000 --- a/compute/ncnn/src/srcn/conv_winograd_batch.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_SRCN_CONV_WINOGRAD_BATCH_H__ -#define __NNFW_SRCN_CONV_WINOGRAD_BATCH_H__ - -#include "ncnn/srcn/conv_type.h" -#include "winograd.h" -#include "sgemm_singlethread.h" - -namespace nnfw -{ -namespace srcn -{ - -class conv_winograd_batch -{ -public: - conv_winograd_batch(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, - convType_t conv_type, const float *winograd_weight, int num_threads); - ~conv_winograd_batch(); - - void run(); - -private: - void param_init(); - void compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, const int m, - const int n, const int k, const float *lhs_data, const float *rhs_data, - float *res_data); - void winograd_input_im2col(float *col_buff); - void winograd_output_col2im(const float *col_buff); - void compute_winograd(); - - const convMat_t in_mat_; - convMat_t out_mat_; - const convParams_t in_param_; - convType_t conv_type_; - const float *winograd_weight_; - const int num_threads_; - - int tile_w_in_; - int tile_h_in_; - int tile_w_out_; - int tile_h_out_; - int ntiles_w_; - int ntiles_h_; - - int error_; -}; - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_CONV_WINOGRAD_BATCH_H__ diff --git a/compute/ncnn/src/srcn/deconv_sgemm_multithreads.cc b/compute/ncnn/src/srcn/deconv_sgemm_multithreads.cc deleted file mode 100644 index f3ccf13e5..000000000 --- a/compute/ncnn/src/srcn/deconv_sgemm_multithreads.cc +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifdef _OPENMP -#include <omp.h> -#endif - -#include "common.h" -#include "sgemm_kernel.h" -#include "sgemm_pack.h" -#include "deconv_sgemm_multithreads.h" - -namespace nnfw -{ -namespace srcn -{ - -void deconv_sgemm_multithreads::param_init() -{ -#if __aarch64__ - if (conv_type_ == row_major) - { - mr_ = 8; - nr_ = 12; - } - else if (conv_type_ == col_major) - { - - mr_ = 12; - nr_ = 8; - } -#else // __aarch64__ - if (conv_type_ == row_major) - { - mr_ = 6; - nr_ = 8; - } - else if (conv_type_ == col_major) - { - mr_ = 8; - nr_ = 6; - } -#endif // __aarch64__ - - int col = n_; - - if (m_ > n_) - { - shard_type_ = shardByRow; - col = m_; - } - else - { - shard_type_ = shardByCol; - } - - int th_base = divup(col, num_threads_); - - th_base = MIN(MAX(th_base, MIN_COL), MAX_COL); - - int k_div = (nr_ * sizeof_RhsScalar); - int k_sub = (mr_ * nr_ * sizeof_ResScalar); - - const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div * 2), MAX_K); - bk_ = MIN(k_cache, k_); - - if (shard_type_ == shardByCol) - { - int m_sub = (bk_ * nr_ * sizeof_RhsScalar); - int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - if (L3_CACHE_SIZE) - m_div = (sizeof_LhsScalar * bk_ * 2); - int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); - bm_ = MIN(m_cache, m_); - - bn_ = MIN(th_base, n_); - if (L3_CACHE_SIZE) - { - int n_sub = (bk_ * bm_ * sizeof_RhsScalar); - int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - int n_cache = divup((L3_CACHE_SIZE - n_sub), n_div); - bn_ = MIN(n_cache, bn_); - } - } - else - { - int n_sub = (bk_ * mr_ * sizeof_LhsScalar); - int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - if (L3_CACHE_SIZE) - n_div = (sizeof_LhsScalar * bk_ * 2); - int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); - bn_ = MIN(n_cache, n_); - - bm_ = MIN(th_base, m_); - if (L3_CACHE_SIZE) - { - int m_sub = (bk_ * bn_ * sizeof_RhsScalar); - int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); - int m_cache = divup((L3_CACHE_SIZE - m_sub), m_div); - bm_ = MIN(m_cache, bm_); - } - } - - nm_ = divup(m_, bm_); - nn_ = divup(n_, bn_); - nk_ = divup(k_, bk_); - - rm_ = m_ % bm_; - rn_ = n_ % bn_; - rk_ = k_ % bk_; -} - -deconv_sgemm_multithreads::deconv_sgemm_multithreads(const convMat_t &in_mat, - const convMat_t &weights_mat, - convMat_t &out_mat, - const convParams_t &in_param, int num_threads, - convType_t conv_type) - - : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), - conv_type_(conv_type), num_threads_(num_threads) -{ - m_ = in_param_.kernel_h * in_param_.kernel_w * out_mat_.c; -#ifdef NCNN - n_ = alignSize(in_mat_.h * in_mat_.w, 16 / sizeof(float)); -#else // NCNN - n_ = in_mat_.w * in_mat_.h; -#endif // NCNN - k_ = in_mat.c; - - param_init(); - - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - if (shard_type_ == shardByCol) - { - plhs_buffer_ = new float[lhs_stride * 1 * nm_]; - prhs_buffer_ = new float[rhs_stride * num_threads_]; - } - else - { - plhs_buffer_ = new float[lhs_stride * num_threads_]; - prhs_buffer_ = new float[rhs_stride * 1 * nn_]; - } - - pres_buffer_ = new float[bm_ * bn_ * num_threads_]; - - if (plhs_buffer_ == NULL || prhs_buffer_ == NULL || pres_buffer_ == NULL) - { - error_ = 1; - } - - if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || - in_param_.stride_h != 1 || in_param_.padding != 0) - { - need_col2im_ = 1; - } - else - { - need_col2im_ = 0; - } - - omp_set_num_threads(num_threads_); - - error_ = 0; -} - 
-deconv_sgemm_multithreads::~deconv_sgemm_multithreads() -{ - if (plhs_buffer_) - delete[] plhs_buffer_; - if (prhs_buffer_) - delete[] prhs_buffer_; - if (pres_buffer_) - delete[] pres_buffer_; -} - -void deconv_sgemm_multithreads::run() -{ - if (error_) - return; - - if (shard_type_ == shardByCol && conv_type_ == col_major) - { - compute_colmajor_colshard(); - } - else if (shard_type_ == shardByRow && conv_type_ == col_major) - { - compute_colmajor_rowshard(); - } - else if (shard_type_ == shardByCol && conv_type_ == row_major) - { - compute_rowmajor_colshard(); - } - else if (shard_type_ == shardByRow && conv_type_ == row_major) - { - compute_rowmajor_rowshard(); - } -} - -void deconv_sgemm_multithreads::compute_rowmajor_colshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], - &plhs_buffer_[i * lhs_stride]); - } - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - int thread_num = omp_get_thread_num(); - float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; - float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; - - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], - prhs_ptr, pres_ptr, 0, bn, bk); - - if (need_col2im_) - _unpack_rowmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); - } - } - } -} - -void deconv_sgemm_multithreads::compute_rowmajor_rowshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], - &prhs_buffer_[j * rhs_stride]); - } - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - int thread_num = omp_get_thread_num(); - float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; - float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; - - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], - plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, - &prhs_buffer_[j * rhs_stride], pres_ptr, 0, bn, bk); - if (need_col2im_) - _unpack_rowmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); - } - } - } -} - -void deconv_sgemm_multithreads::compute_colmajor_colshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], - &plhs_buffer_[i * lhs_stride]); - } - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - int thread_num = omp_get_thread_num(); - float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; - float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; - - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], - prhs_ptr, pres_ptr, 0, bm, bk); - - // Need to add lock? - if (need_col2im_) - _unpack_colmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); - } - } - } -} - -void deconv_sgemm_multithreads::compute_colmajor_rowshard() -{ - int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; - int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - -#pragma omp parallel for - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], - &prhs_buffer_[j * rhs_stride]); - } - -#pragma omp parallel for - for (int i = 0; i < nm_; i++) - { - int thread_num = omp_get_thread_num(); - float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; - float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; - - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], - plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, - &prhs_buffer_[j * rhs_stride], pres_ptr, 0, bm, bk); - - if (need_col2im_) - _unpack_colmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), - &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); - } - } - } -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/deconv_sgemm_multithreads.h b/compute/ncnn/src/srcn/deconv_sgemm_multithreads.h deleted file mode 100644 index 762f20380..000000000 --- a/compute/ncnn/src/srcn/deconv_sgemm_multithreads.h +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_SRCN_DECONV_SGEMM_MULTITHREADS_H__ -#define __NNFW_SRCN_DECONV_SGEMM_MULTITHREADS_H__ - -#include "ncnn/srcn/conv_type.h" -#include "common.h" - -namespace nnfw -{ -namespace srcn -{ - -class deconv_sgemm_multithreads -{ -public: - deconv_sgemm_multithreads(const convMat_t &in_mat, const convMat_t &weights_mat, - convMat_t &out_mat, const convParams_t &in_param, int num_threads, - convType_t conv_type); - ~deconv_sgemm_multithreads(); - - void run(); - -private: - void param_init(); - - void compute_rowmajor_colshard(); - void compute_rowmajor_rowshard(); - void compute_colmajor_colshard(); - void compute_colmajor_rowshard(); - - const convMat_t in_mat_; - const convMat_t weights_mat_; - convMat_t out_mat_; - const convParams_t in_param_; - convType_t conv_type_; - const int num_threads_; - - int m_; - int n_; - int k_; - - int bm_; - int bn_; - int bk_; - - int rm_; - int rn_; - int rk_; - - int nm_; - int nn_; - int nk_; - - int mr_; - int nr_; - - int need_col2im_; - shardType_t shard_type_; - - float *prhs_buffer_; - float *plhs_buffer_; - float *pres_buffer_; - - int error_; -}; - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_DECONV_SGEMM_MULTITHREADS_H__ diff --git a/compute/ncnn/src/srcn/depthwise_conv.cc b/compute/ncnn/src/srcn/depthwise_conv.cc deleted file mode 100644 index cd092d5ac..000000000 --- a/compute/ncnn/src/srcn/depthwise_conv.cc +++ /dev/null @@ -1,2684 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef _OPENMP -#include <omp.h> -#endif - -#include <arm_neon.h> -#include <stdlib.h> -#include <string.h> - -#include "common.h" -#include "ncnn/srcn/conv_type.h" - -namespace nnfw -{ -namespace srcn -{ - -static void depthwise_conv3x3S1_nopad(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convMat_t &bias) -{ -#if !__aarch64__ - int w = in_mat.w; - int h = in_mat.h; - int outw = out_mat.w; - int outh = out_mat.h; - int channels = in_mat.c; - -#pragma omp parallel for - for (int c = 0; c < channels; c++) - { - const float *filter = kernel.data + c * 9; -#ifdef NCNN - float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); - float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *inbuf = in_mat.data + c * w * h; - float *outbuf = out_mat.data + c * outw * outh; -#endif // NCNN - float bias0 = bias.data ? 
bias.data[c] : 0.0f; - - register float32x4_t weight012 asm("q4") = vld1q_f32(filter); - register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); - register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); - register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); - - float *in_ptr0 = inbuf + 0 * w; - float *in_ptr1 = inbuf + 1 * w; - float *in_ptr2 = inbuf + 2 * w; - float *in_ptr3 = inbuf + 3 * w; - - float *out_ptr0 = outbuf + 0 * outw; - float *out_ptr1 = outbuf + 1 * outw; - - int i; - for (i = 0; i + 1 < outh; i += 2) - { - int nn = (outw >> 2) - 1; - int remain = outw & 0x03; - - if (nn > 0) - { - __asm __volatile("pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr0], %[in_ptr0], #16\n" - - "1:\n" - "add %[in_ptr0], %[in_ptr0], #16\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q2, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr1], %[in_ptr1], #16\n" - - "vand q15, %q[qbias0], %q[qbias0]\n" - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q2, %e[weight345][1]\n" - "vmul.f32 q12, q0, %e[weight012][0]\n" - "vmul.f32 q13, q2, %e[weight012][1]\n" - - "pld [%[in_ptr2], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vmla.f32 q15, q3, %f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr2], %[in_ptr2], #16\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q2, %e[weight678][1]\n" - "vmla.f32 q12, q0, %e[weight345][0]\n" - "vmla.f32 q13, q2, %e[weight345][1]\n" - - "pld [%[in_ptr3], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vmla.f32 q15, q3, %f[weight345][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr3], %[in_ptr3], #16\n" - - "vmla.f32 q12, q0, %e[weight678][0]\n" - "vmla.f32 q13, q2, %e[weight678][1]\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vmla.f32 q15, q3, %f[weight678][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - "vadd.f32 q15, q15, q12\n" - "vadd.f32 q15, q15, q13\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" - - "bne 1b\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [in_ptr3] "+r"(in_ptr3), - - [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - - for (; remain > 0; remain--) - { - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - float32x4_t input3 = vld1q_f32(in_ptr3); - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - float32x4_t out1 = vmulq_f32(input1, weight012); - out1 = vmlaq_f32(out1, input2, weight345); - out1 = vmlaq_f32(out1, input3, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - out1 = vsetq_lane_f32(bias0, out1, 3); - - float32x2_t out00 = 
vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); - - float32x2_t out01 = vpadd_f32(out00, out11); - - *out_ptr0 = vget_lane_f32(out01, 0); - *out_ptr1 = vget_lane_f32(out01, 1); - - in_ptr0++; - in_ptr1++; - in_ptr2++; - in_ptr3++; - out_ptr0++; - out_ptr1++; - } - - in_ptr0 += w + 2; - in_ptr1 += w + 2; - in_ptr2 += w + 2; - in_ptr3 += w + 2; - - out_ptr0 += outw; - out_ptr1 += outw; - } - - for (; i < outh; i++) - { - int nn = outw >> 2; - int remain = outw & 0x03; - - if (nn > 0) - { - __asm __volatile("1:\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr0], %[in_ptr0], #16\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmla.f32 q14, q0, %e[weight012][0]\n" - "vmla.f32 q14, q2, %e[weight012][1]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr1], %[in_ptr1], #16\n" - - "vmla.f32 q14, q0, %e[weight345][0]\n" - "vmla.f32 q14, q2, %e[weight345][1]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr2], %[in_ptr2], #16\n" - - "vmla.f32 q14, q0, %e[weight678][0]\n" - "vmla.f32 q14, q2, %e[weight678][1]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - - "bne 1b\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - - for (; remain > 0; remain--) - { - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0++; - in_ptr1++; - in_ptr2++; - out_ptr0++; - } - - in_ptr0 += 2; - in_ptr1 += 2; - in_ptr2 += 2; - } - } -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)bias; -#endif // !__aarch64__ -} - -static void depthwise_conv3x3S1_padding(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convMat_t &bias) -{ -#if !__aarch64__ - int w = in_mat.w; - int h = in_mat.h; - int outw = out_mat.w; - int outh = out_mat.h; - int channels = in_mat.c; - -#pragma omp parallel for - for (int c = 0; c < channels; c++) - { - const float *filter = kernel.data + c * 9; -#ifdef NCNN - float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); - float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *inbuf = in_mat.data + c * w * h; - float *outbuf = out_mat.data + c * outw * outh; -#endif // NCNN - float bias0 = bias.data ? 
bias.data[c] : 0.0f; - - register float32x4_t weight012 asm("q4") = vld1q_f32(filter); - register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); - register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); - register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); - - float *in_ptr0 = inbuf + 0 * w; - float *in_ptr1 = inbuf + 1 * w; - float *in_ptr2 = inbuf + 2 * w; - float *in_ptr3 = inbuf + 3 * w; - - float *out_ptr0 = outbuf + 0 * outw; - float *out_ptr1 = outbuf + 1 * outw; - - int i; - for (i = 0; i + 1 < outh; i += 2) - { - int nn = (outw >> 2) - 1; - int remain = (outw & 0x03) + 4; - if (i == 0) - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q8, #0\n" - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr0], %[in_ptr0], #12\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vand q15, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q2, %e[weight345][0]\n" - "vmul.f32 q11, q0, %e[weight345][1]\n" - "vmul.f32 q12, q2, %e[weight012][0]\n" - "vmul.f32 q13, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vmla.f32 q15, q3, %f[weight012][0]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr1], %[in_ptr1], #12\n" - - "vmla.f32 q10, q2, %e[weight678][0]\n" - "vmla.f32 q11, q0, %e[weight678][1]\n" - "vmla.f32 q12, q2, %e[weight345][0]\n" - "vmla.f32 q13, q0, %e[weight345][1]\n" - - "pld [%[in_ptr2], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vmla.f32 q15, q3, %f[weight345][0]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr2], %[in_ptr2], #12\n" - - "vmla.f32 q12, q2, %e[weight678][0]\n" - "vmla.f32 q13, q0, %e[weight678][1]\n" - "vmla.f32 q15, q3, %f[weight678][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - "vadd.f32 q15, q15, q12\n" - "vadd.f32 q15, q15, q13\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "1:\n" - "add %[in_ptr0], %[in_ptr0], #16\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vand q15, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight345][0]\n" - "vmul.f32 q11, q2, %e[weight345][1]\n" - "vmul.f32 q12, q0, %e[weight012][0]\n" - "vmul.f32 q13, q2, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vmla.f32 q15, q3, %f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr1], %[in_ptr1], #16\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q2, %e[weight678][1]\n" - "vmla.f32 q12, q0, %e[weight345][0]\n" - "vmla.f32 q13, q2, %e[weight345][1]\n" - - "pld [%[in_ptr2], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vmla.f32 q15, q3, %f[weight345][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr2], %[in_ptr2], #16\n" - - "vmla.f32 q12, q0, %e[weight678][0]\n" - "vmla.f32 q13, q2, %e[weight678][1]\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vmla.f32 q15, q3, %f[weight678][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - "vadd.f32 q15, q15, q12\n" - "vadd.f32 q15, 
q15, q13\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" - "bne 1b\n" - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), - [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - - for (; remain > 0; remain--) - { - // TODO: when nn == 0, pad_left comes here. - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - input2 = vsetq_lane_f32(0.0f, input2, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight345); - out0 = vmlaq_f32(out0, input1, weight678); - - float32x4_t out1 = vmulq_f32(input0, weight012); - out1 = vmlaq_f32(out1, input1, weight345); - out1 = vmlaq_f32(out1, input2, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - out1 = vsetq_lane_f32(bias0, out1, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); - - float32x2_t out01 = vpadd_f32(out00, out11); - - *out_ptr0 = vget_lane_f32(out01, 0); - *out_ptr1 = vget_lane_f32(out01, 1); - - in_ptr0++; - in_ptr1++; - in_ptr2++; - out_ptr0++; - out_ptr1++; - } - - in_ptr0 += 1; - in_ptr1 += 1; - in_ptr2 += 1; - in_ptr3 += w; - } - else if (i == outh - 2) - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q8, #0\n" - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr0], %[in_ptr0], #12\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q2, %e[weight012][0]\n" - "vmul.f32 q11, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr1], %[in_ptr1], #12\n" - - "vand q15, %q[qbias0], %q[qbias0]\n" - "vmla.f32 q10, q2, %e[weight345][0]\n" - "vmla.f32 q11, q0, %e[weight345][1]\n" - "vmul.f32 q12, q2, %e[weight012][0]\n" - "vmul.f32 q13, q0, %e[weight012][1]\n" - - "pld [%[in_ptr2], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vmla.f32 q15, q3, %f[weight012][0]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr2], %[in_ptr2], #12\n" - - "vmla.f32 q10, q2, %e[weight678][0]\n" - "vmla.f32 q11, q0, %e[weight678][1]\n" - "vmla.f32 q12, q2, %e[weight345][0]\n" - "vmla.f32 q13, q0, %e[weight345][1]\n" - - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vmla.f32 q15, q3, %f[weight345][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - "vadd.f32 q15, q15, q12\n" - "vadd.f32 q15, q15, q13\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "1:\n" - "add %[in_ptr0], %[in_ptr0], #16\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q2, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, 
%f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr1], %[in_ptr1], #16\n" - - "vand q15, %q[qbias0], %q[qbias0]\n" - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q2, %e[weight345][1]\n" - "vmul.f32 q12, q0, %e[weight012][0]\n" - "vmul.f32 q13, q2, %e[weight012][1]\n" - - "pld [%[in_ptr2], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vmla.f32 q15, q3, %f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr2], %[in_ptr2], #16\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q2, %e[weight678][1]\n" - "vmla.f32 q12, q0, %e[weight345][0]\n" - "vmla.f32 q13, q2, %e[weight345][1]\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vmla.f32 q15, q3, %f[weight345][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - "vadd.f32 q15, q15, q12\n" - "vadd.f32 q15, q15, q13\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" - "bne 1b\n" - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), - [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: when nn == 0, pad_left comes here. - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - input2 = vsetq_lane_f32(0.0f, input2, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - float32x4_t out1 = vmulq_f32(input1, weight012); - out1 = vmlaq_f32(out1, input2, weight345); - - out0 = vsetq_lane_f32(bias0, out0, 3); - out1 = vsetq_lane_f32(bias0, out1, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); - - float32x2_t out01 = vpadd_f32(out00, out11); - - *out_ptr0 = vget_lane_f32(out01, 0); - *out_ptr1 = vget_lane_f32(out01, 1); - - in_ptr0++; - in_ptr1++; - in_ptr2++; - out_ptr0++; - out_ptr1++; - } - } - else - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q8, #0\n" - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr0], %[in_ptr0], #12\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q2, %e[weight012][0]\n" - "vmul.f32 q11, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr1], %[in_ptr1], #12\n" - - "vand q15, %q[qbias0], %q[qbias0]\n" - "vmla.f32 q10, q2, %e[weight345][0]\n" - "vmla.f32 q11, q0, %e[weight345][1]\n" - "vmul.f32 q12, q2, %e[weight012][0]\n" - "vmul.f32 q13, q0, %e[weight012][1]\n" - - "pld [%[in_ptr2], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vmla.f32 q15, q3, %f[weight012][0]\n" - "vext.32 q2, q8, q0, #3\n" - 
"vext.32 q3, q0, q1, #1\n" - "add %[in_ptr2], %[in_ptr2], #12\n" - - "vmla.f32 q10, q2, %e[weight678][0]\n" - "vmla.f32 q11, q0, %e[weight678][1]\n" - "vmla.f32 q12, q2, %e[weight345][0]\n" - "vmla.f32 q13, q0, %e[weight345][1]\n" - - "pld [%[in_ptr3], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vmla.f32 q15, q3, %f[weight345][0]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr3], %[in_ptr3], #12\n" - - "vmla.f32 q15, q2, %e[weight678][0]\n" - "vmla.f32 q15, q0, %e[weight678][1]\n" - "vmla.f32 q15, q3, %f[weight678][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - "vadd.f32 q15, q15, q12\n" - "vadd.f32 q15, q15, q13\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "1:\n" - "add %[in_ptr0], %[in_ptr0], #16\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q2, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr1], %[in_ptr1], #16\n" - - "vand q15, %q[qbias0], %q[qbias0]\n" - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q2, %e[weight345][1]\n" - "vmul.f32 q12, q0, %e[weight012][0]\n" - "vmul.f32 q13, q2, %e[weight012][1]\n" - - "pld [%[in_ptr2], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vmla.f32 q15, q3, %f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr2], %[in_ptr2], #16\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q2, %e[weight678][1]\n" - "vmla.f32 q12, q0, %e[weight345][0]\n" - "vmla.f32 q13, q2, %e[weight345][1]\n" - - "pld [%[in_ptr3], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vmla.f32 q15, q3, %f[weight345][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr3], %[in_ptr3], #16\n" - - "vmla.f32 q15, q0, %e[weight678][0]\n" - "vmla.f32 q15, q2, %e[weight678][1]\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vmla.f32 q15, q3, %f[weight678][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q15, q15, q12\n" - "vadd.f32 q14, q14, q11\n" - "vadd.f32 q15, q15, q13\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" - "bne 1b\n" - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [in_ptr3] "+r"(in_ptr3), - - [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: when nn == 0, pad_left comes here. 
- float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - float32x4_t input3 = vld1q_f32(in_ptr3); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - input2 = vsetq_lane_f32(0.0f, input2, 2); - input3 = vsetq_lane_f32(0.0f, input3, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - float32x4_t out1 = vmulq_f32(input1, weight012); - out1 = vmlaq_f32(out1, input2, weight345); - out1 = vmlaq_f32(out1, input3, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - out1 = vsetq_lane_f32(bias0, out1, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); - - float32x2_t out01 = vpadd_f32(out00, out11); - - *out_ptr0 = vget_lane_f32(out01, 0); - *out_ptr1 = vget_lane_f32(out01, 1); - - in_ptr0++; - in_ptr1++; - in_ptr2++; - in_ptr3++; - out_ptr0++; - out_ptr1++; - } - in_ptr0 += w + 1; - in_ptr1 += w + 1; - in_ptr2 += w + 1; - in_ptr3 += w + 1; - } - - out_ptr0 += outw; - out_ptr1 += outw; - } - - for (; i < outh; i++) - { - // TODO:if i == 0, pad_top comes here. - int nn = (outw >> 2) - 1; - int remain = (outw & 0x03) + 4; - - if (nn > 0) - { - __asm __volatile("vmov.i32 q8, #0\n" - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr0], %[in_ptr0], #12\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q2, %e[weight012][0]\n" - "vmul.f32 q11, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q2, q8, q0, #3\n" - "vext.32 q3, q0, q1, #1\n" - "add %[in_ptr1], %[in_ptr1], #12\n" - - "vmla.f32 q10, q2, %e[weight345][0]\n" - "vmla.f32 q11, q0, %e[weight345][1]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "1:\n" - "add %[in_ptr0], %[in_ptr0], #16\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q2, %e[weight012][1]\n" - - "pld [%[in_ptr1], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - "add %[in_ptr1], %[in_ptr1], #16\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q2, %e[weight345][1]\n" - - "pld [%[in_ptr0], #192]\n" - "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q2, q0, q1, #1\n" - "vext.32 q3, q0, q1, #2\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: when nn == 0, pad_left comes here. 
- float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0++; - in_ptr1++; - out_ptr0++; - out_ptr1++; - } - } - } -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)bias; -#endif // __aarch64__ -} - -static void depthwise_conv3x3S2_nopad(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convMat_t &bias) -{ -#if !__aarch64__ - int w = in_mat.w; - int h = in_mat.h; - int outw = out_mat.w; - int outh = out_mat.h; - int channels = in_mat.c; - - const int tailstep = w - 2 * outw + w; - -#pragma omp parallel for - for (int c = 0; c < channels; c++) - { - const float *filter = kernel.data + c * 9; -#ifdef NCNN - float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); - float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *inbuf = in_mat.data + c * w * h; - float *outbuf = out_mat.data + c * outw * outh; -#endif // NCNN - float bias0 = bias.data ? bias.data[c] : 0.0f; - - register float32x4_t weight012 asm("q4") = vld1q_f32(filter); - register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); - register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); - register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); - - float *in_ptr0 = inbuf + 0 * w; - float *in_ptr1 = inbuf + 1 * w; - float *in_ptr2 = inbuf + 2 * w; - - float *out_ptr0 = outbuf + 0 * outw; - - int i; - for (i = 0; i < outh; i++) - { - int nn = outw >> 2; - int remain = outw & 0x03; - - if (nn > 0) - { - __asm __volatile("pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr2], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q1, %e[weight678][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - - for (; remain > 0; remain--) - { - float32x4_t input0 = 
vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - in_ptr2 += 2; - out_ptr0++; - } - - in_ptr0 += tailstep; - in_ptr1 += tailstep; - in_ptr2 += tailstep; - } - } - -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)bias; -#endif // __aarch64__ -} - -static void depthwise_conv3x3S2_padding00(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convMat_t &bias) -{ -#if !__aarch64__ - int w = in_mat.w; - int h = in_mat.h; - int outw = out_mat.w; - int outh = out_mat.h; - int channels = in_mat.c; - -#pragma omp parallel for - for (int c = 0; c < channels; c++) - { - const float *filter = kernel.data + c * 9; -#ifdef NCNN - float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); - float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *inbuf = in_mat.data + c * w * h; - float *outbuf = out_mat.data + c * outw * outh; -#endif // NCNN - float bias0 = bias.data ? bias.data[c] : 0.0f; - - register float32x4_t weight012 asm("q4") = vld1q_f32(filter); - register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); - register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); - register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); - - float *in_ptr0 = inbuf + 0 * w; - float *in_ptr1 = inbuf + 1 * w; - float *in_ptr2 = inbuf + 2 * w; - - float *out_ptr0 = outbuf + 0 * outw; - - int i; - for (i = 0; i < outh; i++) - { - int nn = (outw >> 2) - 1; - int remain = (outw & 0x03) + 4; - - if (i == outh - 1) - { - if (nn > 0) - { - __asm __volatile("pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); 
- - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - out_ptr0++; - } - } - else - { - if (nn > 0) - { - __asm __volatile("pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr2], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q1, %e[weight678][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - input2 = vsetq_lane_f32(0.0f, input2, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - in_ptr2 += 2; - out_ptr0++; - } - - in_ptr0 += w; - in_ptr1 += w; - in_ptr2 += w; - } - } - } -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)bias; -#endif // !__aarch64__ -} - -static void depthwise_conv3x3S2_padding01(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convMat_t &bias) -{ -#if !__aarch64__ - int w = in_mat.w; - int h = in_mat.h; - int outw = out_mat.w; - int outh = out_mat.h; - int channels = in_mat.c; - -#pragma omp parallel for - for (int c = 0; c < channels; c++) - { - const float *filter = kernel.data + c * 9; -#ifdef NCNN - float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); - float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *inbuf = in_mat.data + c * w * h; - float *outbuf = out_mat.data + c * outw * outh; -#endif // NCNN - float bias0 = bias.data ? 
bias.data[c] : 0.0f; - - register float32x4_t weight012 asm("q4") = vld1q_f32(filter); - register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); - register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); - register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); - - float *in_ptr0 = inbuf + 0 * w; - float *in_ptr1 = inbuf + 1 * w; - float *in_ptr2 = inbuf + 2 * w; - - float *out_ptr0 = outbuf + 0 * outw; - - int i; - for (i = 0; i < outh; i++) - { - int nn = (outw >> 2) - 1; - int remain = (outw & 0x03) + 4; - - if (i == outh - 1) - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q2, #0\n" - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr0], %[in_ptr0], #28\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q3, %e[weight012][0]\n" - "vmul.f32 q11, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" - "vmla.f32 q14, q1, %f[weight012][0]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr1], %[in_ptr1], #28\n" - - "vmla.f32 q10, q3, %e[weight345][0]\n" - "vmla.f32 q11, q0, %e[weight345][1]\n" - "vmla.f32 q14, q1, %f[weight345][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: if nn == 0, pad_left comes here. 
- float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - out_ptr0++; - } - } - else - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q2, #0\n" - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr0], %[in_ptr0], #28\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q3, %e[weight012][0]\n" - "vmul.f32 q11, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" - "vmla.f32 q14, q1, %f[weight012][0]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr1], %[in_ptr1], #28\n" - - "vmla.f32 q10, q3, %e[weight345][0]\n" - "vmla.f32 q11, q0, %e[weight345][1]\n" - - "pld [%[in_ptr2], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr2]]\n" - "vmla.f32 q14, q1, %f[weight345][0]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr2], %[in_ptr2], #28\n" - - "vmla.f32 q10, q3, %e[weight678][0]\n" - "vmla.f32 q11, q0, %e[weight678][1]\n" - "vmla.f32 q14, q1, %f[weight678][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr2], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q1, %e[weight678][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: if nn == 0, pad_left comes here. 
- float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - input2 = vsetq_lane_f32(0.0f, input2, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - in_ptr2 += 2; - out_ptr0++; - } - - in_ptr0 += w; - in_ptr1 += w; - in_ptr2 += w; - } - } - } - -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)bias; -#endif // __aarch64__ -} - -static void depthwise_conv3x3S2_padding10(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convMat_t &bias) -{ -#if !__aarch64__ - int w = in_mat.w; - int h = in_mat.h; - int outw = out_mat.w; - int outh = out_mat.h; - int channels = in_mat.c; - -#pragma omp parallel for - for (int c = 0; c < channels; c++) - { - const float *filter = kernel.data + c * 9; -#ifdef NCNN - float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); - float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *inbuf = in_mat.data + c * w * h; - float *outbuf = out_mat.data + c * outw * outh; -#endif // NCNN - float bias0 = bias.data ? bias.data[c] : 0.0f; - - register float32x4_t weight012 asm("q4") = vld1q_f32(filter); - register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); - register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); - register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); - - float *in_ptr0 = inbuf + 0 * w; - float *in_ptr1 = inbuf + 1 * w; - float *in_ptr2 = inbuf + 2 * w; - - float *out_ptr0 = outbuf + 0 * outw; - - int i; - for (i = 0; i < outh; i++) - { - int nn = (outw >> 2) - 1; - int remain = (outw & 0x03) + 4; - - // TODO: i == 0 && i == outh -1 - if (i == 0) - { - if (nn > 0) - { - __asm __volatile("pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight345][0]\n" - "vmul.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q1, %e[weight678][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - - if (remain == 1) 
- { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight345); - out0 = vmlaq_f32(out0, input1, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - out_ptr0++; - } - - in_ptr2 += w; - } - else if (i == outh - 1) - { - if (nn > 0) - { - __asm __volatile("pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - out_ptr0++; - } - } - else - { - if (nn > 0) - { - __asm __volatile("pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr2], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q1, %e[weight678][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - : [in_ptr0] "+r"(in_ptr0), 
[in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - input2 = vsetq_lane_f32(0.0f, input2, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - in_ptr2 += 2; - out_ptr0++; - } - - in_ptr0 += w; - in_ptr1 += w; - in_ptr2 += w; - } - } - } - -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)bias; -#endif // __aarch64__ -} - -static void depthwise_conv3x3S2_padding11(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convMat_t &bias) -{ -#if !__aarch64__ - int w = in_mat.w; - int h = in_mat.h; - int outw = out_mat.w; - int outh = out_mat.h; - int channels = in_mat.c; - -#pragma omp parallel for - for (int c = 0; c < channels; c++) - { - const float *filter = kernel.data + c * 9; -#ifdef NCNN - float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); - float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *inbuf = in_mat.data + c * w * h; - float *outbuf = out_mat.data + c * outw * outh; -#endif // NCNN - float bias0 = bias.data ? 
bias.data[c] : 0.0f; - - register float32x4_t weight012 asm("q4") = vld1q_f32(filter); - register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); - register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); - register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); - - float *in_ptr0 = inbuf + 0 * w; - float *in_ptr1 = inbuf + 1 * w; - float *in_ptr2 = inbuf + 2 * w; - - float *out_ptr0 = outbuf + 0 * outw; - - int i; - for (i = 0; i < outh; i++) - { - int nn = (outw >> 2) - 1; - int remain = (outw & 0x03) + 4; - - // TODO: i == 0 && i == outh - 1 - if (i == 0) - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q2, #0\n" - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr0], %[in_ptr0], #28\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q3, %e[weight345][0]\n" - "vmul.f32 q11, q0, %e[weight345][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" - "vmla.f32 q14, q1, %f[weight345][0]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr1], %[in_ptr1], #28\n" - - "vmla.f32 q10, q3, %e[weight678][0]\n" - "vmla.f32 q11, q0, %e[weight678][1]\n" - "vmla.f32 q14, q1, %f[weight678][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight345][0]\n" - "vmul.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q1, %e[weight678][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: if nn == 0, pad_left comes here. 
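-        // Scalar tail, top output row (pad_h == 1): the kernel's first row lies in top padding,
-        // so only weight345 (on in_ptr0) and weight678 (on in_ptr1) contribute. Lane 3 is
-        // overwritten with the bias so the horizontal add below yields dot-product + bias.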
- float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight345); - out0 = vmlaq_f32(out0, input1, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - out_ptr0++; - } - - in_ptr2 += w; - } - else if (i == outh - 1) - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q2, #0\n" - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr0], %[in_ptr0], #28\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q3, %e[weight012][0]\n" - "vmul.f32 q11, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" - "vmla.f32 q14, q1, %f[weight012][0]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr1], %[in_ptr1], #28\n" - - "vmla.f32 q10, q3, %e[weight345][0]\n" - "vmla.f32 q11, q0, %e[weight345][1]\n" - "vmla.f32 q14, q1, %f[weight345][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: if nn == 0, pad_left comes here. 
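-        // Scalar tail, bottom output row: the kernel's last row falls into bottom padding,
-        // so only weight012 (on in_ptr0) and weight345 (on in_ptr1) are applied.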
- float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - out_ptr0++; - } - } - else - { - if (nn > 0) - { - __asm __volatile("vmov.i32 q2, #0\n" - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr0], %[in_ptr0], #28\n" - - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q3, %e[weight012][0]\n" - "vmul.f32 q11, q0, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" - "vmla.f32 q14, q1, %f[weight012][0]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr1], %[in_ptr1], #28\n" - - "vmla.f32 q10, q3, %e[weight345][0]\n" - "vmla.f32 q11, q0, %e[weight345][1]\n" - - "pld [%[in_ptr2], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr2]]\n" - "vmla.f32 q14, q1, %f[weight345][0]\n" - "vext.32 q3, q2, q0, #3\n" - "add %[in_ptr2], %[in_ptr2], #28\n" - - "vmla.f32 q10, q3, %e[weight678][0]\n" - "vmla.f32 q11, q0, %e[weight678][1]\n" - "vmla.f32 q14, q1, %f[weight678][0]\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "beq 2f\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vext.32 q3, q0, q2, #1\n" - - "1:\n" - "vand q14, %q[qbias0], %q[qbias0]\n" - "vmul.f32 q10, q0, %e[weight012][0]\n" - "vmul.f32 q11, q1, %e[weight012][1]\n" - - "pld [%[in_ptr1], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" - "vmla.f32 q14, q3, %f[weight012][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight345][0]\n" - "vmla.f32 q11, q1, %e[weight345][1]\n" - - "pld [%[in_ptr2], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" - "vmla.f32 q14, q3, %f[weight345][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vmla.f32 q10, q0, %e[weight678][0]\n" - "vmla.f32 q11, q1, %e[weight678][1]\n" - - "pld [%[in_ptr0], #256]\n" - "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" - "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" - "vmla.f32 q14, q3, %f[weight678][0]\n" - "vext.32 q3, q0, q2, #1\n" - - "vadd.f32 q14, q14, q10\n" - "vadd.f32 q14, q14, q11\n" - - "subs %[nn], %[nn], #1\n" - "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" - "bne 1b\n" - "sub %[in_ptr0], %[in_ptr0], #32\n" - "2:\n" - : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), - [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) - : [weight012] "w"(weight012), [weight345] "w"(weight345), - [weight678] "w"(weight678), [qbias0] "w"(qbias0) - : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", - "q15", "cc", "memory"); - } - for (; remain > 0; remain--) - { - // TODO: if nn == 0, pad_left comes here. 
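-        // Scalar tail, interior rows: all three kernel rows contribute. Zeroing lane 2 when
-        // remain == 1 masks the sample that would come from the right-edge padding.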
- float32x4_t input0 = vld1q_f32(in_ptr0); - float32x4_t input1 = vld1q_f32(in_ptr1); - float32x4_t input2 = vld1q_f32(in_ptr2); - - if (remain == 1) - { - input0 = vsetq_lane_f32(0.0f, input0, 2); - input1 = vsetq_lane_f32(0.0f, input1, 2); - input2 = vsetq_lane_f32(0.0f, input2, 2); - } - - float32x4_t out0 = vmulq_f32(input0, weight012); - out0 = vmlaq_f32(out0, input1, weight345); - out0 = vmlaq_f32(out0, input2, weight678); - - out0 = vsetq_lane_f32(bias0, out0, 3); - - float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); - - float32x2_t out01 = vpadd_f32(out00, out00); - - *out_ptr0 = vget_lane_f32(out01, 0); - - in_ptr0 += 2; - in_ptr1 += 2; - in_ptr2 += 2; - out_ptr0++; - } - - in_ptr0 += w; - in_ptr1 += w; - in_ptr2 += w; - } - } - } -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)bias; -#endif // __aarch64__ -} - -static void depthwise_conv_colmajor(const convMat_t &in_mat, convMat_t &out_mat, - const convMat_t &kernel, const convParams_t &in_param) -{ -#if __aarch64__ - const int w = in_mat.w; - const int h = in_mat.h; - const int outw = out_mat.w; - const int outh = out_mat.h; - const int channels = out_mat.c; - const int stridew = in_param.stride_w; - const int strideh = in_param.stride_h; - const int padding = in_param.padding; - const int padw = in_param.pad_w; - const int padh = in_param.pad_h; - -#pragma omp parallel for - for (int oh = 0; oh < outh; oh++) - { - const float *input_data0 = in_mat.data + (oh * strideh - padh) * w * channels; - - memset(out_mat.data + oh * outw * channels, 0x00, outw * channels * sizeof(float)); - - for (int kh = 0; kh < in_param.kernel_h; kh++) - { - for (int kw = 0; kw < in_param.kernel_w; kw++) - { - const float *kernel_data = kernel.data + (kh * in_param.kernel_w + kw) * channels; - const float *input_data1 = input_data0 + (kh * w + kw) * channels; - - if (padding && ((oh * strideh + kh < padh) || (oh * strideh + kh >= padh + h))) - { - continue; - } - - int ow = 0; - for (; ow + 3 < outw; /*ow += 4*/) - { - if (((ow + 3) * stridew + kw < padw) || (ow * stridew + kw >= padw + w)) - { - ow += 4; - continue; - } - else if ((ow + 3) * stridew + kw >= padw + w) - { - break; - } - else if (ow * stridew + kw < padw) - { - int delta = (padw - kw) / stridew - ow; - delta += (padw - kw) % stridew ? 
1 : 0; - ow += delta; - continue; - } - - int nn = channels >> 2; - int remain = channels & 0x03; - - const float *input_r0 = input_data1 + (ow * stridew - padw) * channels; - - const float *input_r1 = input_r0 + stridew * channels; - const float *input_r2 = input_r1 + stridew * channels; - const float *input_r3 = input_r2 + stridew * channels; - const float *weights_data = kernel_data; - float *output_r0 = out_mat.data + (oh * outw + ow) * channels; - float *output_r1 = output_r0 + channels; - float *output_r2 = output_r1 + channels; - float *output_r3 = output_r2 + channels; - - if (nn > 0) - { - int _n = (nn + 1) >> 1; - int oddn = nn & 1; - - asm volatile("subs %[_n], %[_n], #1\n" - "ld1 {v4.4s}, [%[weights_data]], #16\n" - "ld1 {v5.4s}, [%[input_r0]], #16\n" - "ld1 {v6.4s}, [%[input_r1]], #16\n" - "ld1 {v7.4s}, [%[input_r2]], #16\n" - "ld1 {v8.4s}, [%[input_r3]], #16\n" - "beq 1f\n" - - "0:\n" - "ld1 {v24.4s, v25.4s}, [%[output_r0]]\n" - "ld1 {v26.4s, v27.4s}, [%[output_r1]]\n" - "ld1 {v28.4s, v29.4s}, [%[output_r2]]\n" - "ld1 {v30.4s, v31.4s}, [%[output_r3]]\n" - - "ld1 {v9.4s}, [%[weights_data]], #16\n" - "ld1 {v10.4s}, [%[input_r0]], #16\n" - "ld1 {v11.4s}, [%[input_r1]], #16\n" - "ld1 {v12.4s}, [%[input_r2]], #16\n" - "ld1 {v13.4s}, [%[input_r3]], #16\n" - - "fmla v24.4s, v4.4s, v5.4s\n" - "fmla v26.4s, v4.4s, v6.4s\n" - - "fmla v28.4s, v4.4s, v7.4s\n" - "fmla v30.4s, v4.4s, v8.4s\n" - - "ld1 {v4.4s}, [%[weights_data]], #16\n" - "ld1 {v5.4s}, [%[input_r0]], #16\n" - "ld1 {v6.4s}, [%[input_r1]], #16\n" - "ld1 {v7.4s}, [%[input_r2]], #16\n" - "ld1 {v8.4s}, [%[input_r3]], #16\n" - - "fmla v25.4s, v9.4s, v10.4s\n" - "fmla v27.4s, v9.4s, v11.4s\n" - - "fmla v29.4s, v9.4s, v12.4s\n" - "fmla v31.4s, v9.4s, v13.4s\n" - - "st1 {v24.4s, v25.4s}, [%[output_r0]], #32\n" - "st1 {v26.4s, v27.4s}, [%[output_r1]], #32\n" - "st1 {v28.4s, v29.4s}, [%[output_r2]], #32\n" - "st1 {v30.4s, v31.4s}, [%[output_r3]], #32\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v24.4s}, [%[output_r0]]\n" - "ld1 {v26.4s}, [%[output_r1]]\n" - "ld1 {v28.4s}, [%[output_r2]]\n" - "ld1 {v30.4s}, [%[output_r3]]\n" - "cmp %[oddn], #1\n" - - "fmla v24.4s, v4.4s, v5.4s\n" - "fmla v26.4s, v4.4s, v6.4s\n" - - "fmla v28.4s, v4.4s, v7.4s\n" - "fmla v30.4s, v4.4s, v8.4s\n" - - "st1 {v24.4s}, [%[output_r0]], #16\n" - "st1 {v26.4s}, [%[output_r1]], #16\n" - "st1 {v28.4s}, [%[output_r2]], #16\n" - "st1 {v30.4s}, [%[output_r3]], #16\n" - - "beq 2f\n" - "ld1 {v25.4s}, [%[output_r0]]\n" - "ld1 {v27.4s}, [%[output_r1]]\n" - "ld1 {v29.4s}, [%[output_r2]]\n" - "ld1 {v31.4s}, [%[output_r3]]\n" - - "ld1 {v9.4s}, [%[weights_data]], #16\n" - "ld1 {v10.4s}, [%[input_r0]], #16\n" - "ld1 {v11.4s}, [%[input_r1]], #16\n" - "ld1 {v12.4s}, [%[input_r2]], #16\n" - "ld1 {v13.4s}, [%[input_r3]], #16\n" - - "fmla v25.4s, v9.4s, v10.4s\n" - "fmla v27.4s, v9.4s, v11.4s\n" - - "fmla v29.4s, v9.4s, v12.4s\n" - "fmla v31.4s, v9.4s, v13.4s\n" - - "st1 {v25.4s}, [%[output_r0]], #16\n" - "st1 {v27.4s}, [%[output_r1]], #16\n" - "st1 {v29.4s}, [%[output_r2]], #16\n" - "st1 {v31.4s}, [%[output_r3]], #16\n" - "2:\n" - : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), - [input_r1] "+r"(input_r1), [input_r2] "+r"(input_r2), - [input_r3] "+r"(input_r3), [output_r0] "+r"(output_r0), - [output_r1] "+r"(output_r1), [output_r2] "+r"(output_r2), - [output_r3] "+r"(output_r3), [_n] "+r"(_n) - : [oddn] "r"(oddn) - : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v24", "v25", "v26", "v27", "v28", "v29", "v30", 
"v31"); - } - if (remain >= 2) - { - asm volatile( - "ld1 {v24.2s}, [%[output_r0]]\n" - "ld1 {v26.2s}, [%[output_r1]]\n" - "ld1 {v28.2s}, [%[output_r2]]\n" - "ld1 {v30.2s}, [%[output_r3]]\n" - "ld1 {v4.2s}, [%[weights_data]], #8\n" - "ld1 {v5.2s}, [%[input_r0]], #8\n" - - "ld1 {v6.2s}, [%[input_r1]], #8\n" - "ld1 {v7.2s}, [%[input_r2]], #8\n" - "ld1 {v8.2s}, [%[input_r3]], #8\n" - - "fmla v24.2s, v4.2s, v5.2s\n" - "fmla v26.2s, v4.2s, v6.2s\n" - - "fmla v28.2s, v4.2s, v7.2s\n" - "fmla v30.2s, v4.2s, v8.2s\n" - - "st1 {v24.2s}, [%[output_r0]], #8\n" - "st1 {v26.2s}, [%[output_r1]], #8\n" - "st1 {v28.2s}, [%[output_r2]], #8\n" - "st1 {v30.2s}, [%[output_r3]], #8\n" - : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), - [input_r1] "+r"(input_r1), [input_r2] "+r"(input_r2), [input_r3] "+r"(input_r3), - [output_r0] "+r"(output_r0), [output_r1] "+r"(output_r1), - [output_r2] "+r"(output_r2), [output_r3] "+r"(output_r3) - : - : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v24", "v26", "v28", "v30"); - remain -= 2; - } - - if (remain > 0) - { - *output_r0++ += (*weights_data) * (*input_r0++); - *output_r1++ += (*weights_data++) * (*input_r1++); - *output_r2++ += (*weights_data) * (*input_r2++); - *output_r3++ += (*weights_data++) * (*input_r3++); - } - ow += 4; - } - - for (; ow + 1 < outw; /*ow += 2*/) - { - if (padding) - { - if (((ow + 1) * stridew + kw < padw) || (ow * stridew + kw >= padw + w)) - { - ow += 2; - continue; - } - else if ((ow + 1) * stridew + kw >= padw + w) - { - break; - } - else if (ow * stridew + kw < padw) - { - ow++; - continue; - } - } - - int nn = channels >> 2; - int remain = channels & 0x03; - - const float *input_r0 = input_data1 + (ow * stridew - padw) * channels; - - const float *input_r1 = input_r0 + stridew * channels; - const float *weights_data = kernel_data; - float *output_r0 = out_mat.data + (oh * outw + ow) * channels; - float *output_r1 = output_r0 + channels; - - if (nn > 0) - { - int _n = (nn + 1) >> 1; - int oddn = nn & 1; - - asm volatile("subs %[_n], %[_n], #1\n" - "ld1 {v4.4s}, [%[weights_data]], #16\n" - "ld1 {v5.4s}, [%[input_r0]], #16\n" - "ld1 {v6.4s}, [%[input_r1]], #16\n" - "beq 1f\n" - - "0:\n" - "ld1 {v24.4s, v25.4s}, [%[output_r0]]\n" - "ld1 {v26.4s, v27.4s}, [%[output_r1]]\n" - - "ld1 {v9.4s}, [%[weights_data]], #16\n" - "ld1 {v10.4s}, [%[input_r0]], #16\n" - "ld1 {v11.4s}, [%[input_r1]], #16\n" - - "fmla v24.4s, v4.4s, v5.4s\n" - "fmla v26.4s, v4.4s, v6.4s\n" - - "ld1 {v4.4s}, [%[weights_data]], #16\n" - "ld1 {v5.4s}, [%[input_r0]], #16\n" - "ld1 {v6.4s}, [%[input_r1]], #16\n" - - "fmla v25.4s, v9.4s, v10.4s\n" - "fmla v27.4s, v9.4s, v11.4s\n" - - "st1 {v24.4s, v25.4s}, [%[output_r0]], #32\n" - "st1 {v26.4s, v27.4s}, [%[output_r1]], #32\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v24.4s}, [%[output_r0]]\n" - "ld1 {v26.4s}, [%[output_r1]]\n" - "cmp %[oddn], #1\n" - - "fmla v24.4s, v4.4s, v5.4s\n" - "fmla v26.4s, v4.4s, v6.4s\n" - - "st1 {v24.4s}, [%[output_r0]], #16\n" - "st1 {v26.4s}, [%[output_r1]], #16\n" - - "beq 2f\n" - "ld1 {v25.4s}, [%[output_r0]]\n" - "ld1 {v27.4s}, [%[output_r1]]\n" - - "ld1 {v9.4s}, [%[weights_data]], #16\n" - "ld1 {v10.4s}, [%[input_r0]], #16\n" - "ld1 {v11.4s}, [%[input_r1]], #16\n" - - "fmla v25.4s, v9.4s, v10.4s\n" - "fmla v27.4s, v9.4s, v11.4s\n" - - "st1 {v25.4s}, [%[output_r0]], #16\n" - "st1 {v27.4s}, [%[output_r1]], #16\n" - "2:\n" - : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), - [input_r1] "+r"(input_r1), [output_r0] "+r"(output_r0), - [output_r1] 
"+r"(output_r1), [_n] "+r"(_n) - : [oddn] "r"(oddn) - : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); - } - if (remain >= 2) - { - asm volatile("ld1 {v24.2s}, [%[output_r0]]\n" - "ld1 {v26.2s}, [%[output_r1]]\n" - "ld1 {v4.2s}, [%[weights_data]], #8\n" - "ld1 {v5.2s}, [%[input_r0]], #8\n" - - "ld1 {v6.2s}, [%[input_r1]], #8\n" - - "fmla v24.2s, v4.2s, v5.2s\n" - "fmla v26.2s, v4.2s, v6.2s\n" - - "st1 {v24.2s}, [%[output_r0]], #8\n" - "st1 {v26.2s}, [%[output_r1]], #8\n" - : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), - [input_r1] "+r"(input_r1), [output_r0] "+r"(output_r0), - [output_r1] "+r"(output_r1) - : - : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v24", "v26", "v28", - "v30"); - remain -= 2; - } - - if (remain > 0) - { - *output_r0++ += (*weights_data) * (*input_r0++); - *output_r1++ += (*weights_data++) * (*input_r1++); - } - ow += 2; - } - - for (; ow < outw; ow++) - { - const float *input_data = input_data1 + (ow * stridew - padw) * channels; - - if (padding && ((ow * stridew + kw < padw) || (ow * strideh + kw >= padw + w))) - { - continue; - } - - int nn = channels >> 2; - int remain = channels & 0x03; - - const float *weights_data = kernel_data; - float *output_data = out_mat.data + (oh * outw + ow) * channels; - - if (nn > 0) - { - int _n = (nn + 1) >> 1; - int oddn = nn & 1; - - asm volatile("subs %[_n], %[_n], #1\n" - "ld1 {v4.4s}, [%[weights_data]], #16\n" - "ld1 {v5.4s}, [%[input_data]], #16\n" - "beq 1f\n" - - "0:\n" - "ld1 {v30.4s, v31.4s}, [%[output_data]]\n" - "ld1 {v6.4s}, [%[weights_data]], #16\n" - "ld1 {v7.4s}, [%[input_data]], #16\n" - "fmla v30.4s, v4.4s, v5.4s\n" - - "ld1 {v4.4s}, [%[weights_data]], #16\n" - "ld1 {v5.4s}, [%[input_data]], #16\n" - "fmla v31.4s, v6.4s, v7.4s\n" - - "st1 {v30.4s, v31.4s}, [%[output_data]], #32\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v30.4s}, [%[output_data]]\n" - "cmp %[oddn], #1\n" - "fmla v30.4s, v4.4s, v5.4s\n" - "st1 {v30.4s}, [%[output_data]], #16\n" - "beq 2f\n" - "ld1 {v31.4s}, [%[output_data]]\n" - "ld1 {v6.4s}, [%[weights_data]], #16\n" - "ld1 {v7.4s}, [%[input_data]], #16\n" - "fmla v31.4s, v6.4s, v7.4s\n" - - "st1 {v31.4s}, [%[output_data]], #16\n" - "2:\n" - : [weights_data] "+r"(weights_data), [input_data] "+r"(input_data), - [output_data] "+r"(output_data), [_n] "+r"(_n) - : [oddn] "r"(oddn) - : "cc", "memory", "v4", "v5", "v30", "v31"); - } - if (remain >= 2) - { - asm volatile("ld1 {v30.2s}, [%[output_data]]\n" - "ld1 {v4.2s}, [%[weights_data]], #8\n" - "ld1 {v5.2s}, [%[input_data]], #8\n" - - "fmla v30.2s, v4.2s, v5.2s\n" - - "st1 {v30.2s}, [%[output_data]], #8\n" - : [weights_data] "+r"(weights_data), [input_data] "+r"(input_data), - [output_data] "+r"(output_data) - : - : "cc", "memory", "v4", "v5", "v30"); - remain -= 2; - } - - if (remain > 0) - { - *output_data++ += (*weights_data++) * (*input_data++); - } - } - } - } - } -#else // __aarch64__ - (void)in_mat; - (void)out_mat; - (void)kernel; - (void)in_param; -#endif // __aarch64__ -} - -void srcn_depthwise_conv(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, - const convMat_t &bias, const convParams_t &in_param, int num_threads, - convType_t conv_type) -{ - omp_set_num_threads(num_threads); - - if (conv_type == col_major) - { - depthwise_conv_colmajor(in_mat, out_mat, weights_mat, in_param); - return; - } - - else if (conv_type == row_major) - { - if (in_param.kernel_w == 3 && in_param.kernel_h == 
3 && in_param.dilation_w == 1 && - in_param.dilation_h == 1) - { - if (in_param.stride_w == 1 && in_param.stride_h == 1) - { - if (in_param.padding == 0) - depthwise_conv3x3S1_nopad(in_mat, out_mat, weights_mat, bias); - else - depthwise_conv3x3S1_padding(in_mat, out_mat, weights_mat, bias); - } - else if (in_param.stride_w == 2 && in_param.stride_h == 2) - { - if (in_param.padding == 0) - depthwise_conv3x3S2_nopad(in_mat, out_mat, weights_mat, bias); - else - { - if (in_param.pad_w == 0 && in_param.pad_h == 0) - depthwise_conv3x3S2_padding00(in_mat, out_mat, weights_mat, bias); - else if (in_param.pad_w == 0 && in_param.pad_h == 1) - depthwise_conv3x3S2_padding10(in_mat, out_mat, weights_mat, bias); - else if (in_param.pad_w == 1 && in_param.pad_h == 0) - depthwise_conv3x3S2_padding01(in_mat, out_mat, weights_mat, bias); - else if (in_param.pad_w == 1 && in_param.pad_h == 1) - depthwise_conv3x3S2_padding11(in_mat, out_mat, weights_mat, bias); - } - } - } - } -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/direct_conv_colmajor.cc b/compute/ncnn/src/srcn/direct_conv_colmajor.cc deleted file mode 100644 index 300235222..000000000 --- a/compute/ncnn/src/srcn/direct_conv_colmajor.cc +++ /dev/null @@ -1,5872 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifdef _OPENMP -#include <omp.h> -#endif - -#include <stdlib.h> -#include <arm_neon.h> -#include "ncnn/srcn/conv_type.h" - -namespace nnfw -{ -namespace srcn -{ - -#if __aarch64__ -static void direct_conv_l(const convMat_t &bottom_blob, convMat_t &top_blob, - const convMat_t &_kernel, const int _stride, const int padding, - const int pad_top, const int pad_left) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - const int kernel_w = _kernel.w; - const int kernel_h = _kernel.h; - - for (int m = 0; m < kernel_w * kernel_h; m++) - { - const float *_kernel0 = _kernel.data + m * inch * outch; - const float *img0 = - bottom_blob.data + (m / kernel_w - pad_top) * w * inch + (m % kernel_w - pad_left) * inch; - -#ifdef _OPENMP -#pragma omp parallel for -#endif // _OPENMP - for (int p = 0; p < outh; p++) - { - float *out0 = top_blob.data + p * outw * outch; - - // clear output - if (m == 0) - { - for (int j = 0; j < outw * outch; j++) - { - *(out0 + j) = 0.f; - } - } - - if (padding) - { - if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) - { - continue; - } - } - - const float *img1 = img0 + p * w * inch * _stride; - - int q = 0; - for (; q + 3 < outw; /*q += 4*/) - { - if (padding) - { - if (((q + 3) * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w) >= pad_left + w) - { - out0 += outch * 4; - img1 += inch * _stride * 4; - q += 4; - continue; - } - else if ((q + 3) * _stride + m % kernel_w >= pad_left + w) - { - break; - } - else if (q * _stride + m % kernel_w < pad_left) - { - int delta = (pad_left - m % kernel_w) / _stride - q; - delta += (pad_left - m % kernel_w) % _stride ? 
1 : 0; - out0 += outch * delta; - img1 += inch * _stride * delta; - q += delta; - continue; - } - } - - const float *_x0 = img1; - const float *_x1 = img1 + inch * _stride; - const float *_x2 = img1 + inch * _stride * 2; - const float *_x3 = img1 + inch * _stride * 3; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); - register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); - register float32x4_t rx2 asm("v16") = vld1q_f32(_x2); - register float32x4_t rx3 asm("v17") = vld1q_f32(_x3); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - float *outptr2 = out0 + outch * 2; - float *outptr3 = out0 + outch * 3; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v30.4s, v8.4s, %[rx2].s[2]\n" - "fmla v31.4s, v8.4s, %[rx3].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - "fmla v30.4s, v9.4s, %[rx2].s[3]\n" - "fmla v31.4s, v9.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v30.4s, v12.4s, %[rx2].s[2]\n" - "fmla v31.4s, v12.4s, %[rx3].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - "fmla v30.4s, v13.4s, %[rx2].s[3]\n" - "fmla v31.4s, v13.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - 
"1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v30.4s, v8.4s, %[rx2].s[2]\n" - "fmla v31.4s, v8.4s, %[rx3].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - "fmla v30.4s, v9.4s, %[rx2].s[3]\n" - "fmla v31.4s, v9.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v30.4s, v12.4s, %[rx2].s[2]\n" - "fmla v31.4s, v12.4s, %[rx3].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - "fmla v30.4s, v13.4s, %[rx2].s[3]\n" - "fmla v31.4s, v13.4s, %[rx3].s[3]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v30.4s, v8.4s, %[rx2].s[2]\n" - "fmla v31.4s, v8.4s, %[rx3].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - "fmla v30.4s, v9.4s, %[rx2].s[3]\n" - "fmla v31.4s, v9.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n), [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", 
"v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "v30", "v31"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - "ld1 {v30.2s}, [%[outptr2]]\n" - "ld1 {v31.2s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v30.2s, v6.2s, %[rx2].s[0]\n" - "fmla v31.2s, v6.2s, %[rx3].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - "fmla v30.2s, v7.2s, %[rx2].s[1]\n" - "fmla v31.2s, v7.2s, %[rx3].s[1]\n" - "fmla v14.2s, v8.2s, %[rx0].s[2]\n" - "fmla v15.2s, v8.2s, %[rx1].s[2]\n" - "fmla v30.2s, v8.2s, %[rx2].s[2]\n" - "fmla v31.2s, v8.2s, %[rx3].s[2]\n" - "fmla v14.2s, v9.2s, %[rx0].s[3]\n" - "fmla v15.2s, v9.2s, %[rx1].s[3]\n" - "fmla v30.2s, v9.2s, %[rx2].s[3]\n" - "fmla v31.2s, v9.2s, %[rx3].s[3]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - "st1 {v30.2s}, [%[outptr2]], #8\n" - "st1 {v31.2s}, [%[outptr3]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), - - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15", "v30", - "v31"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x1 + 3)); - - *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x2 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x2 + 3)); - - *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x3 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x3 + 3)); - - kernel0++; - outptr0++; - outptr1++; - outptr2++; - outptr3++; - } - - kernel0 += outch * 3; - _x0 += 4; - _x1 += 4; - _x2 += 4; - _x3 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_f32(_x1); - register float32x2_t rx2 asm("v16") = vld1_f32(_x2); - register float32x2_t rx3 asm("v17") = vld1_f32(_x3); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - float *outptr2 = out0 + outch * 2; - float *outptr3 = out0 + outch * 3; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile( - "cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, 
v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", 
"memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15", "v30", "v31"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - "ld1 {v30.2s}, [%[outptr2]]\n" - "ld1 {v31.2s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v30.2s, v6.2s, %[rx2].s[0]\n" - "fmla v31.2s, v6.2s, %[rx3].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - "fmla v30.2s, v7.2s, %[rx2].s[1]\n" - "fmla v31.2s, v7.2s, %[rx3].s[1]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - "st1 {v30.2s}, [%[outptr2]], #8\n" - "st1 {v31.2s}, [%[outptr3]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), - - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v7", "v14", "v15", "v30", "v31"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); - *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)); - *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)); - - kernel0++; - outptr0++; - outptr1++; - outptr2++; - outptr3++; - } - - kernel0 += outch; - _x0 += 2; - _x1 += 2; - _x2 += 2; - _x3 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); - register float32x2_t rx2 asm("v16") = vld1_dup_f32(_x2); - register float32x2_t rx3 asm("v17") = vld1_dup_f32(_x3); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - float *outptr2 = out0 + outch * 2; - float *outptr3 = out0 + outch * 3; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile( - "cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, 
[%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v10", "v14", "v15", "v30", "v31"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - "ld1 {v30.2s}, [%[outptr2]]\n" - "ld1 {v31.2s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v30.2s, v6.2s, %[rx2].s[0]\n" - "fmla v31.2s, v6.2s, %[rx3].s[0]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - "st1 {v30.2s}, [%[outptr2]], #8\n" - "st1 {v31.2s}, [%[outptr3]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [rx0] "w"(rx0), [rx1] "w"(rx1), - - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v14", "v15", "v30", "v31"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - *outptr1 += (*kernel0) * (*_x1); - *outptr2 += (*kernel0) * (*_x2); - *outptr3 += (*kernel0) * (*_x3); - - kernel0++; - outptr0++; - outptr1++; - outptr2++; - outptr3++; - } - - _x0 += 1; - _x1 += 1; - _x2 += 1; - _x3 += 1; - } - - img1 += inch * 4 * _stride; - out0 += outch * 4; - q += 4; - } - - for (; q + 1 < outw; /*q += 2*/) - { - if (padding) - { - if (((q + 1) * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w) >= pad_left + w) - { - out0 += outch * 2; - img1 += inch * _stride * 2; - q += 2; - continue; - } - else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) - { - break; - } - else if (q * _stride + m % kernel_w < pad_left) - { - out0 += outch; - img1 += inch * _stride; - q++; - continue; - } - } - - const float *_x0 = img1; - const float *_x1 = img1 + inch * _stride; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - 
{ - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); - register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 
{v9.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - "fmla v14.2s, v8.2s, %[rx0].s[2]\n" - "fmla v15.2s, v8.2s, %[rx1].s[2]\n" - "fmla v14.2s, v9.2s, %[rx0].s[3]\n" - "fmla v15.2s, v9.2s, %[rx1].s[3]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x1 + 3)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch * 3; - _x0 += 4; - _x1 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - - "st1 
{v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) - : "cc", "memory", "x0", "v6", "v7", "v14", "v15"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch; - _x0 += 2; - _x1 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "fmla v14.4s, 
v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v10", "v14", "v15"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [rx0] "w"(rx0), [rx1] "w"(rx1) - : "cc", "memory", "x0", "v6", "v14", "v15"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - *outptr1 += (*kernel0) * (*_x1); - - kernel0++; - outptr0++; - outptr1++; - } - - _x0 += 1; - _x1 += 1; - } - - img1 += inch * 2 * _stride; - out0 += outch * 2; - q += 2; - } - - for (; q < outw; q++) - { - if (padding) - { - if ((q * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w >= pad_left + w)) - { - img1 += inch * _stride; - out0 += outch; - continue; - } - } - - const float *_x0 = img1; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); - - float *outptr0 = out0; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 
{v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v14.2s, v8.2s, %[rx0].s[2]\n" - "fmla v14.2s, v9.2s, %[rx0].s[3]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - kernel0++; - outptr0++; - } - - kernel0 += outch * 3; - _x0 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_f32(_x0); - - float *outptr0 = out0; - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla 
v14.4s, v7.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) - : "cc", "memory", "x0", "v6", "v7", "v14"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - - kernel0++; - outptr0++; - } - - kernel0 += outch; - _x0 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); - - float *outptr0 = out0; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), 
[outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [rx0] "w"(rx0), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v10", "v14"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [rx0] "w"(rx0) - : "cc", "memory", "x0", "v6", "v14"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - - kernel0++; - outptr0++; - } - - _x0 += 1; - } - - img1 += inch * _stride; - out0 += outch; - } - } - } -} - -static void direct_conv_s(const convMat_t &bottom_blob, convMat_t &top_blob, - const convMat_t &_kernel, const int _stride, const int padding, - const int pad_top, const int pad_left) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - const int kernel_w = _kernel.w; - const int kernel_h = _kernel.h; - -#ifdef _OPENMP -#pragma omp parallel for -#endif - for (int p = 0; p < outh; p++) - { - const float *img0 = bottom_blob.data + (p * _stride - pad_top) * w * inch; - float *out = top_blob.data + p * outw * outch; - - // clear output - for (int j = 0; j < outw * outch; j++) - { - *(out + j) = 0.f; - } - - for (int m = 0; m < kernel_w * kernel_h; m++) - { - if (padding) - { - if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) - { - continue; - } - } - - float *out0 = out; - const float *_kernel0 = _kernel.data + m * inch * outch; - const float *img1 = img0 + (m / kernel_w) * w * inch + (m % kernel_w - pad_left) * inch; - - int q = 0; - for (; q + 3 < outw; /*q += 4*/) - { - if (padding) - { - if (((q + 3) * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w) >= pad_left + w) - { - out0 += outch * 4; - img1 += inch * _stride * 4; - q += 4; - continue; - } - else if ((q + 3) * _stride + m % kernel_w >= pad_left + w) - { - break; - } - else if (q * _stride + m % kernel_w < pad_left) - { - int delta = (pad_left - m % kernel_w) / _stride - q; - delta += (pad_left - m % kernel_w) % _stride ? 
1 : 0; - out0 += outch * delta; - img1 += inch * _stride * delta; - q += delta; - continue; - } - } - - const float *_x0 = img1; - const float *_x1 = img1 + inch * _stride; - const float *_x2 = img1 + inch * _stride * 2; - const float *_x3 = img1 + inch * _stride * 3; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); - register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); - register float32x4_t rx2 asm("v16") = vld1q_f32(_x2); - register float32x4_t rx3 asm("v17") = vld1q_f32(_x3); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - float *outptr2 = out0 + outch * 2; - float *outptr3 = out0 + outch * 3; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v30.4s, v8.4s, %[rx2].s[2]\n" - "fmla v31.4s, v8.4s, %[rx3].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - "fmla v30.4s, v9.4s, %[rx2].s[3]\n" - "fmla v31.4s, v9.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v30.4s, v12.4s, %[rx2].s[2]\n" - "fmla v31.4s, v12.4s, %[rx3].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - "fmla v30.4s, v13.4s, %[rx2].s[3]\n" - "fmla v31.4s, v13.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - 
"1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v30.4s, v8.4s, %[rx2].s[2]\n" - "fmla v31.4s, v8.4s, %[rx3].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - "fmla v30.4s, v9.4s, %[rx2].s[3]\n" - "fmla v31.4s, v9.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v30.4s, v12.4s, %[rx2].s[2]\n" - "fmla v31.4s, v12.4s, %[rx3].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - "fmla v30.4s, v13.4s, %[rx2].s[3]\n" - "fmla v31.4s, v13.4s, %[rx3].s[3]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v30.4s, v8.4s, %[rx2].s[2]\n" - "fmla v31.4s, v8.4s, %[rx3].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - "fmla v30.4s, v9.4s, %[rx2].s[3]\n" - "fmla v31.4s, v9.4s, %[rx3].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n), [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", 
"v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "v30", "v31"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - "ld1 {v30.2s}, [%[outptr2]]\n" - "ld1 {v31.2s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v30.2s, v6.2s, %[rx2].s[0]\n" - "fmla v31.2s, v6.2s, %[rx3].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - "fmla v30.2s, v7.2s, %[rx2].s[1]\n" - "fmla v31.2s, v7.2s, %[rx3].s[1]\n" - "fmla v14.2s, v8.2s, %[rx0].s[2]\n" - "fmla v15.2s, v8.2s, %[rx1].s[2]\n" - "fmla v30.2s, v8.2s, %[rx2].s[2]\n" - "fmla v31.2s, v8.2s, %[rx3].s[2]\n" - "fmla v14.2s, v9.2s, %[rx0].s[3]\n" - "fmla v15.2s, v9.2s, %[rx1].s[3]\n" - "fmla v30.2s, v9.2s, %[rx2].s[3]\n" - "fmla v31.2s, v9.2s, %[rx3].s[3]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - "st1 {v30.2s}, [%[outptr2]], #8\n" - "st1 {v31.2s}, [%[outptr3]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), - - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15", "v30", - "v31"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x1 + 3)); - - *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x2 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x2 + 3)); - - *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x3 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x3 + 3)); - - kernel0++; - outptr0++; - outptr1++; - outptr2++; - outptr3++; - } - - kernel0 += outch * 3; - _x0 += 4; - _x1 += 4; - _x2 += 4; - _x3 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_f32(_x1); - register float32x2_t rx2 asm("v16") = vld1_f32(_x2); - register float32x2_t rx3 asm("v17") = vld1_f32(_x3); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - float *outptr2 = out0 + outch * 2; - float *outptr3 = out0 + outch * 3; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile( - "cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, 
v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v30.4s, v11.4s, %[rx2].s[1]\n" - "fmla v31.4s, v11.4s, %[rx3].s[1]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v30.4s, v7.4s, %[rx2].s[1]\n" - "fmla v31.4s, v7.4s, %[rx3].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", 
"memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15", "v30", "v31"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - "ld1 {v30.2s}, [%[outptr2]]\n" - "ld1 {v31.2s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v30.2s, v6.2s, %[rx2].s[0]\n" - "fmla v31.2s, v6.2s, %[rx3].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - "fmla v30.2s, v7.2s, %[rx2].s[1]\n" - "fmla v31.2s, v7.2s, %[rx3].s[1]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - "st1 {v30.2s}, [%[outptr2]], #8\n" - "st1 {v31.2s}, [%[outptr3]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), - - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v7", "v14", "v15", "v30", "v31"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); - *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)); - *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)); - - kernel0++; - outptr0++; - outptr1++; - outptr2++; - outptr3++; - } - - kernel0 += outch; - _x0 += 2; - _x1 += 2; - _x2 += 2; - _x3 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); - register float32x2_t rx2 asm("v16") = vld1_dup_f32(_x2); - register float32x2_t rx3 asm("v17") = vld1_dup_f32(_x3); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - float *outptr2 = out0 + outch * 2; - float *outptr3 = out0 + outch * 3; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile( - "cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, 
[%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v30.4s, v10.4s, %[rx2].s[0]\n" - "fmla v31.4s, v10.4s, %[rx3].s[0]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - "ld1 {v30.4s}, [%[outptr2]]\n" - "ld1 {v31.4s}, [%[outptr3]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v30.4s, v6.4s, %[rx2].s[0]\n" - "fmla v31.4s, v6.4s, %[rx3].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "st1 {v30.4s}, [%[outptr2]], #16\n" - "st1 {v31.4s}, [%[outptr3]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v10", "v14", "v15", "v30", "v31"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - "ld1 {v30.2s}, [%[outptr2]]\n" - "ld1 {v31.2s}, [%[outptr3]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v30.2s, v6.2s, %[rx2].s[0]\n" - "fmla v31.2s, v6.2s, %[rx3].s[0]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - "st1 {v30.2s}, [%[outptr2]], #8\n" - "st1 {v31.2s}, [%[outptr3]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : [rx0] "w"(rx0), [rx1] "w"(rx1), - - [rx2] "w"(rx2), [rx3] "w"(rx3) - : "cc", "memory", "x0", "v6", "v14", "v15", "v30", "v31"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - *outptr1 += (*kernel0) * (*_x1); - *outptr2 += (*kernel0) * (*_x2); - *outptr3 += (*kernel0) * (*_x3); - - kernel0++; - outptr0++; - outptr1++; - outptr2++; - outptr3++; - } - - _x0 += 1; - _x1 += 1; - _x2 += 1; - _x3 += 1; - } - - img1 += inch * 4 * _stride; - out0 += outch * 4; - q += 4; - } - - for (; q + 1 < outw; /*q += 2*/) - { - if (padding) - { - if (((q + 1) * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w) >= pad_left + w) - { - out0 += outch * 2; - img1 += inch * _stride * 2; - q += 2; - continue; - } - else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) - { - break; - } - else if (q * _stride + m % kernel_w < pad_left) - { - out0 += outch; - img1 += inch * _stride; - q++; - continue; - } - } - - const float *_x0 = img1; - const float *_x1 = img1 + inch * _stride; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - 
{ - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); - register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v15.4s, v12.4s, %[rx1].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - "fmla v15.4s, v13.4s, %[rx1].s[3]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 
{v9.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v15.4s, v8.4s, %[rx1].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - "fmla v15.4s, v9.4s, %[rx1].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - "fmla v14.2s, v8.2s, %[rx0].s[2]\n" - "fmla v15.2s, v8.2s, %[rx1].s[2]\n" - "fmla v14.2s, v9.2s, %[rx0].s[3]\n" - "fmla v15.2s, v9.2s, %[rx1].s[3]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x1 + 3)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch * 3; - _x0 += 4; - _x1 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - - "st1 
{v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v15.4s, v11.4s, %[rx1].s[1]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v15.4s, v7.4s, %[rx1].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v15.2s, v7.2s, %[rx1].s[1]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) - : "cc", "memory", "x0", "v6", "v7", "v14", "v15"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch; - _x0 += 2; - _x1 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); - register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "fmla v14.4s, 
v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v15.4s, v10.4s, %[rx1].s[0]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - "ld1 {v15.4s}, [%[outptr1]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v15.4s, v6.4s, %[rx1].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "st1 {v15.4s}, [%[outptr1]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v10", "v14", "v15"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - "ld1 {v15.2s}, [%[outptr1]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v15.2s, v6.2s, %[rx1].s[0]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - "st1 {v15.2s}, [%[outptr1]], #8\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [rx0] "w"(rx0), [rx1] "w"(rx1) - : "cc", "memory", "x0", "v6", "v14", "v15"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - *outptr1 += (*kernel0) * (*_x1); - - kernel0++; - outptr0++; - outptr1++; - } - - _x0 += 1; - _x1 += 1; - } - - img1 += inch * 2 * _stride; - out0 += outch * 2; - q += 2; - } - - for (; q < outw; q++) - { - if (padding) - { - if ((q * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w >= pad_left + w)) - { - img1 += inch * _stride; - out0 += outch; - continue; - } - } - - const float *_x0 = img1; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); - - float *outptr0 = out0; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 
{v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v13.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - "fmla v14.4s, v12.4s, %[rx0].s[2]\n" - "fmla v14.4s, v13.4s, %[rx0].s[3]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - "fmla v14.4s, v8.4s, %[rx0].s[2]\n" - "fmla v14.4s, v9.4s, %[rx0].s[3]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v8.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v9.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - "fmla v14.2s, v8.2s, %[rx0].s[2]\n" - "fmla v14.2s, v9.2s, %[rx0].s[3]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) - : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - kernel0++; - outptr0++; - } - - kernel0 += outch * 3; - _x0 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_f32(_x0); - - float *outptr0 = out0; - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla 
v14.4s, v7.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - "fmla v14.4s, v11.4s, %[rx0].s[1]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - "fmla v14.4s, v7.4s, %[rx0].s[1]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - "add x0, x0, %[stride]\n" - "ld1 {v7.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - "fmla v14.2s, v7.2s, %[rx0].s[1]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) - : "cc", "memory", "x0", "v6", "v7", "v14"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - - kernel0++; - outptr0++; - } - - kernel0 += outch; - _x0 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); - - float *outptr0 = out0; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - "beq 1f\n" - - "0:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v10.4s}, [x0]\n" - - "fmla v14.4s, v10.4s, %[rx0].s[0]\n" - - "cmp %[oddn], #1\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - - "bne 3f\n" - - "2:\n" - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "ld1 {v6.4s}, [x0]\n" - - "ld1 {v14.4s}, [%[outptr0]]\n" - - "fmla v14.4s, v6.4s, %[rx0].s[0]\n" - - "st1 {v14.4s}, [%[outptr0]], #16\n" - "3:\n" - : [kernel0] "+r"(kernel0), 
[outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [rx0] "w"(rx0), [oddn] "r"(oddn) - : "cc", "memory", "x0", "v6", "v10", "v14"); - } - - if (remain >= 2) - { - asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" - - "mov x0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "ld1 {v6.2s}, [x0]\n" - - "fmla v14.2s, v6.2s, %[rx0].s[0]\n" - - "st1 {v14.2s}, [%[outptr0]], #8\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [rx0] "w"(rx0) - : "cc", "memory", "x0", "v6", "v14"); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - - kernel0++; - outptr0++; - } - - _x0 += 1; - } - - img1 += inch * _stride; - out0 += outch; - } - } - } -} - -#else // __aarch64__ -static void direct_conv_l(const convMat_t &bottom_blob, convMat_t &top_blob, - const convMat_t &_kernel, const int _stride, const int padding, - const int pad_top, const int pad_left) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - const int kernel_w = _kernel.w; - const int kernel_h = _kernel.h; - - for (int m = 0; m < kernel_w * kernel_h; m++) - { - const float *_kernel0 = _kernel.data + m * inch * outch; - const float *img0 = - bottom_blob.data + (m / kernel_w - pad_top) * w * inch + (m % kernel_w - pad_left) * inch; - -#ifdef _OPENMP -#pragma omp parallel for -#endif // _OPENMP - for (int p = 0; p < outh; p++) - { - float *out0 = top_blob.data + p * outw * outch; - // clear output. - if (m == 0) - { - for (int j = 0; j < outw * outch; j++) - { - *(out0 + j) = 0.f; - } - } - - if (padding) - { - if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) - { - continue; - } - } - - const float *img1 = img0 + p * w * inch * _stride; - - int q = 0; - for (; q + 1 < outw; /*q += 2*/) - { - if (padding) - { - if (((q + 1) * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w) >= pad_left + w) - { - out0 += outch * 2; - img1 += inch * _stride * 2; - q += 2; - continue; - } - else if (q * _stride + m % kernel_w < pad_left) - { - out0 += outch; - img1 += inch * _stride; - q++; - continue; - } - else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) - { - break; - } - } - - const float *_x0 = img1; - const float *_x1 = img1 + inch * _stride; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); - register float32x4_t rx1 asm("q5") = vld1q_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q15, q6, %e[rx1][0]\n" - 
"vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q15, q7, %e[rx1][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q15, q8, %f[rx1][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - "vmla.f32 q15, q9, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q15, q10, %e[rx1][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q15, q11, %e[rx1][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q15, q12, %f[rx1][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - "vmla.f32 q15, q13, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q15, q6, %e[rx1][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q15, q7, %e[rx1][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q15, q8, %f[rx1][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - "vmla.f32 q15, q9, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q15, q10, %e[rx1][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q15, q11, %e[rx1][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q15, q12, %f[rx1][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - "vmla.f32 q15, q13, %f[rx1][1]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q15, q6, %e[rx1][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q15, q7, %e[rx1][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q15, q8, %f[rx1][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - "vmla.f32 q15, q9, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15"); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - "vld1.f32 {d30}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - "add r0, r0, 
%[stride]\n" - "vld1.f32 {d16}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18}, [r0]\n" - - "vmla.f32 d28, d12, %e[rx0][0]\n" - "vmla.f32 d30, d12, %e[rx1][0]\n" - "vmla.f32 d28, d14, %e[rx0][1]\n" - "vmla.f32 d30, d14, %e[rx1][1]\n" - "vmla.f32 d28, d16, %f[rx0][0]\n" - "vmla.f32 d30, d16, %f[rx1][0]\n" - "vmla.f32 d28, d18, %f[rx0][1]\n" - "vmla.f32 d30, d18, %f[rx1][1]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - "vst1.f32 {d30}, [%[outptr1]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) -#ifndef _OPENMP - - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x1 + 3)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch * 3; - _x0 += 4; - _x1 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = vld1_f32(_x0); - register float32x2_t rx1 asm("d10") = vld1_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - "vmla.f32 q15, q7, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - "vmla.f32 q15, q11, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - "vmla.f32 q15, q7, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - "vmla.f32 q15, q11, %P[rx1][1]\n" - - "cmp 
%[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - "vmla.f32 q15, q7, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - "vld1.f32 {d30}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - "vmla.f32 d30, d12, %P[rx1][0]\n" - "vmla.f32 d28, d14, %P[rx0][1]\n" - "vmla.f32 d30, d14, %P[rx1][1]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - "vst1.f32 {d30}, [%[outptr1]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch; - _x0 += 2; - _x1 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); - register float32x2_t rx1 asm("d10") = vld1_dup_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - 
"vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q10", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - "vld1.f32 {d30}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - "vmla.f32 d30, d12, %P[rx1][0]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - "vst1.f32 {d30}, [%[outptr1]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [rx0] "w"(rx0), [rx1] "w"(rx1) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - *outptr1 += (*kernel0) * (*_x1); - - kernel0++; - outptr0++; - outptr1++; - } - - _x0 += 1; - _x1 += 1; - } - - img1 += inch * 2 * _stride; - out0 += outch * 2; - q += 2; - } - - for (; q < outw; q++) - { - if (padding) - { - if ((q * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w) >= pad_left + bottom_blob.w) - { - img1 += inch * _stride; - out0 += outch; - continue; - } - } - - const float *_x0 = img1; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); - - float *outptr0 = out0; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, 
[r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18}, [r0]\n" - - "vmla.f32 d28, d12, %e[rx0][0]\n" - "vmla.f32 d28, d14, %e[rx0][1]\n" - "vmla.f32 d28, d16, %f[rx0][0]\n" - "vmla.f32 d28, d18, %f[rx0][1]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - kernel0++; - outptr0++; - } - - kernel0 += outch * 3; - _x0 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = vld1_f32(_x0); - - float *outptr0 = out0; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add 
%[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - "vmla.f32 d28, d14, %P[rx0][1]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - - kernel0++; - outptr0++; - } - - kernel0 += outch; - _x0 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); - - float *outptr0 = out0; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], 
#16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [rx0] "w"(rx0), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q10", "q14" - -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [rx0] "w"(rx0) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - - kernel0++; - outptr0++; - } - - _x0 += 1; - } - - img1 += inch * _stride; - out0 += outch; - } - } - } -} - -static void direct_conv_s(const convMat_t &bottom_blob, convMat_t &top_blob, - const convMat_t &_kernel, const int _stride, const int padding, - const int pad_top, const int pad_left) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - const int kernel_w = _kernel.w; - const int kernel_h = _kernel.h; - -#ifdef _OPENMP -#pragma omp parallel for -#endif // _OPENMP - for (int p = 0; p < outh; p++) - { - const float *img0 = bottom_blob.data + (p * _stride - pad_top) * w * inch; - float *out = top_blob.data + p * outw * outch; - - // clear output. 
- for (int j = 0; j < outw * outch; j++) - { - *(out + j) = 0.f; - } - - for (int m = 0; m < kernel_w * kernel_h; m++) - { - if (padding) - { - if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) - { - continue; - } - } - - float *out0 = out; - const float *_kernel0 = _kernel.data + m * inch * outch; - const float *img1 = img0 + (m / kernel_w) * w * inch + (m % kernel_w - pad_left) * inch; - - int q = 0; - for (; q + 1 < outw; /*q += 2*/) - { - if (padding) - { - if (((q + 1) * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w >= pad_left + w)) - { - out0 += outch * 2; - img1 += inch * _stride * 2; - q += 2; - continue; - } - else if (q * _stride + m % kernel_w < pad_left) - { - out0 += outch; - img1 += inch * _stride; - q++; - continue; - } - else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) - { - break; - } - } - - const float *_x0 = img1; - const float *_x1 = img1 + inch * _stride; - - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); - register float32x4_t rx1 asm("q5") = vld1q_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q15, q6, %e[rx1][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q15, q7, %e[rx1][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q15, q8, %f[rx1][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - "vmla.f32 q15, q9, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q15, q10, %e[rx1][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q15, q11, %e[rx1][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q15, q12, %f[rx1][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - "vmla.f32 q15, q13, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q15, q6, %e[rx1][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q15, q7, %e[rx1][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q15, q8, %f[rx1][0]\n" - "vmla.f32 q14, q9, 
%f[rx0][1]\n" - "vmla.f32 q15, q9, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q15, q10, %e[rx1][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q15, q11, %e[rx1][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q15, q12, %f[rx1][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - "vmla.f32 q15, q13, %f[rx1][1]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q15, q6, %e[rx1][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q15, q7, %e[rx1][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q15, q8, %f[rx1][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - "vmla.f32 q15, q9, %f[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15"); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - "vld1.f32 {d30}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18}, [r0]\n" - - "vmla.f32 d28, d12, %e[rx0][0]\n" - "vmla.f32 d30, d12, %e[rx1][0]\n" - "vmla.f32 d28, d14, %e[rx0][1]\n" - "vmla.f32 d30, d14, %e[rx1][1]\n" - "vmla.f32 d28, d16, %f[rx0][0]\n" - "vmla.f32 d30, d16, %f[rx1][0]\n" - "vmla.f32 d28, d18, %f[rx0][1]\n" - "vmla.f32 d30, d18, %f[rx1][1]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - "vst1.f32 {d30}, [%[outptr1]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" -#else - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x1 + 3)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch * 3; - _x0 += 4; - _x1 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = 
vld1_f32(_x0); - register float32x2_t rx1 asm("d10") = vld1_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - "vmla.f32 q15, q7, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - "vmla.f32 q15, q11, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - "vmla.f32 q15, q7, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - "vmla.f32 q15, q11, %P[rx1][1]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - "vmla.f32 q15, q7, %P[rx1][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - "vld1.f32 {d30}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - "vmla.f32 d30, d12, %P[rx1][0]\n" - "vmla.f32 d28, d14, %P[rx0][1]\n" - "vmla.f32 d30, d14, %P[rx1][1]\n" - - "vst1.f32 {d28}, 
[%[outptr0]]!\n" - "vst1.f32 {d30}, [%[outptr1]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); - - kernel0++; - outptr0++; - outptr1++; - } - - kernel0 += outch; - _x0 += 2; - _x1 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); - register float32x2_t rx1 asm("d10") = vld1_dup_f32(_x1); - - float *outptr0 = out0; - float *outptr1 = out0 + outch; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q15, q10, %P[rx1][0]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - "vld1.f32 {d30-d31}, [%[outptr1]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q15, q6, %P[rx1][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vst1.f32 {d30-d31}, [%[outptr1]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [_n] "+r"(_n) - : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q10", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - "vld1.f32 {d30}, [%[outptr1]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - "vmla.f32 d30, d12, %P[rx1][0]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - "vst1.f32 {d30}, 
[%[outptr1]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) - : [rx0] "w"(rx0), [rx1] "w"(rx1) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - *outptr1 += (*kernel0) * (*_x1); - - kernel0++; - outptr0++; - outptr1++; - } - - _x0 += 1; - _x1 += 1; - } - - img1 += inch * 2 * _stride; - out0 += outch * 2; - q += 2; - } - - for (; q < outw; q++) - { - if (padding) - { - if ((q * _stride + m % kernel_w < pad_left) || - (q * _stride + m % kernel_w >= pad_left + w)) - { - img1 += inch * _stride; - out0 += outch; - continue; - } - } - - const float *_x0 = img1; - const float *kernel0 = _kernel0; - - int i = 0; - for (; i + 3 < inch; i += 4) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); - - float *outptr0 = out0; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - - "vmla.f32 q14, q10, %e[rx0][0]\n" - "vmla.f32 q14, q11, %e[rx0][1]\n" - "vmla.f32 q14, q12, %f[rx0][0]\n" - "vmla.f32 q14, q13, %f[rx0][1]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add 
r0, r0, %[stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %e[rx0][0]\n" - "vmla.f32 q14, q7, %e[rx0][1]\n" - "vmla.f32 q14, q8, %f[rx0][0]\n" - "vmla.f32 q14, q9, %f[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d16}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d18}, [r0]\n" - - "vmla.f32 d28, d12, %e[rx0][0]\n" - "vmla.f32 d28, d14, %e[rx0][1]\n" - "vmla.f32 d28, d16, %f[rx0][0]\n" - "vmla.f32 d28, d18, %f[rx0][1]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + - (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + - (*(kernel0 + outch * 3)) * (*(_x0 + 3)); - - kernel0++; - outptr0++; - } - - kernel0 += outch * 3; - _x0 += 4; - } - - for (; i + 1 < inch; i += 2) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = vld1_f32(_x0); - - float *outptr0 = out0; - - int stride = outch << 2; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - "vmla.f32 q14, q11, %P[rx0][1]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, 
[r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - "vmla.f32 q14, q7, %P[rx0][1]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - "add r0, r0, %[stride]\n" - "vld1.f32 {d14}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - "vmla.f32 d28, d14, %P[rx0][1]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [stride] "r"(stride), [rx0] "w"(rx0) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); - - kernel0++; - outptr0++; - } - - kernel0 += outch; - _x0 += 2; - } - - for (; i < inch; i++) - { - int nn = outch >> 2; - int remain = outch & 0x03; - - register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); - - float *outptr0 = out0; - - if (nn > 0) - { - int _n = nn >> 1; - int oddn = nn & 1; - - asm volatile("cmp %[_n], #0\n" - "beq 2f\n" - "subs %[_n], %[_n], #1\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "beq 1f\n" - - "0:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "subs %[_n], %[_n], #1\n" - "bne 0b\n" - - "1:\n" - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d20-d21}, [r0]\n" - - "vmla.f32 q14, q10, %P[rx0][0]\n" - - "cmp %[oddn], #1\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - - "bne 3f\n" - - "2:\n" - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #16\n" - "vld1.f32 {d12-d13}, [r0]\n" - - "vld1.f32 {d28-d29}, [%[outptr0]]\n" - - "vmla.f32 q14, q6, %P[rx0][0]\n" - - "vst1.f32 {d28-d29}, [%[outptr0]]!\n" - "3:\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) - : [rx0] "w"(rx0), [oddn] "r"(oddn) -#ifndef _OPENMP - : "cc", "memory", "r0", "q6", "q10", "q14" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - } - - if (remain >= 2) - { - asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" - - "mov r0, %[kernel0]\n" - "add %[kernel0], %[kernel0], #8\n" - "vld1.f32 {d12}, [r0]\n" - - "vmla.f32 d28, d12, %P[rx0][0]\n" - - "vst1.f32 {d28}, [%[outptr0]]!\n" - : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) - : [rx0] "w"(rx0) -#ifndef _OPENMP - : "cc", 
"memory", "r0", "q6", "q14", "q15" -#else // _OPENMP - : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", - "q14", "q15" -#endif // _OPENMP - ); - remain -= 2; - } - - if (remain == 1) - { - *outptr0 += (*kernel0) * (*_x0); - - kernel0++; - outptr0++; - } - - _x0 += 1; - } - - img1 += inch * _stride; - out0 += outch; - } - } - } -} -#endif // __aarch64__ - -void direct_conv_colmajor(const convMat_t &bottom_blob, convMat_t &top_blob, - const convMat_t &kernel, const convParams_t ¶ms, int num_threads) -{ - omp_set_num_threads(num_threads); - - if (bottom_blob.c * top_blob.c < 256 * 256) - { - direct_conv_s(bottom_blob, top_blob, kernel, params.stride_w, params.padding, params.pad_h, - params.pad_w); - return; - } - - direct_conv_l(bottom_blob, top_blob, kernel, params.stride_w, params.padding, params.pad_h, - params.pad_w); -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/direct_conv_colmajor.h b/compute/ncnn/src/srcn/direct_conv_colmajor.h deleted file mode 100644 index 5e15192c9..000000000 --- a/compute/ncnn/src/srcn/direct_conv_colmajor.h +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_SRCN_DIRECT_CONV_COLMAJOR_H__ -#define __NNFW_SRCN_DIRECT_CONV_COLMAJOR_H__ - -#include "ncnn/srcn/conv_type.h" - -namespace nnfw -{ -namespace srcn -{ - -void direct_conv_colmajor(const convMat_t &, convMat_t &, const convMat_t &, const convParams_t &, - int); - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_DIRECT_CONV_COLMAJOR_H__ diff --git a/compute/ncnn/src/srcn/sgemm_kernel.cc b/compute/ncnn/src/srcn/sgemm_kernel.cc deleted file mode 100644 index 90c3641db..000000000 --- a/compute/ncnn/src/srcn/sgemm_kernel.cc +++ /dev/null @@ -1,2508 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <arm_neon.h> - -namespace nnfw -{ -namespace srcn -{ - -#if __aarch64__ -static void sgemm_rowmajor_micro_kernel_8x12(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int oddk = (k & 1); - int nk = ((k + 1) / 2) - 1; - - const int nstride = stride << 2; - - __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov x0, %[res_ptr]\n" - "ld1 {v8.4s, v9.4s, v10.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v11.4s, v12.4s, v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v14.4s, v15.4s, v16.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v17.4s, v18.4s, v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v20.4s, v21.4s, v22.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v23.4s, v24.4s, v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v26.4s, v27.4s, v28.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v29.4s, v30.4s, v31.4s}, [x0]\n" - "cbz %w[nk], 4f\n" - "b 1f\n" - - "0:\n" - "movi v8.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v11.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v14.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v24.4s, #0x0\n" - "movi v25.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v31.4s, #0x0\n" - "cbz %w[nk], 4f\n" - - "1:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - "fmla v21.4s, v3.4s, v1.s[0]\n" - "fmla v24.4s, v3.4s, v1.s[1]\n" - "fmla v27.4s, v3.4s, v1.s[2]\n" - "fmla v30.4s, v3.4s, v1.s[3]\n" - "fmla v22.4s, v4.4s, v1.s[0]\n" - "fmla v25.4s, v4.4s, v1.s[1]\n" - "fmla v28.4s, v4.4s, v1.s[2]\n" - "fmla v31.4s, v4.4s, v1.s[3]\n" - - "fmla v8.4s, v5.4s, v0.s[0]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v14.4s, v5.4s, v0.s[2]\n" - "fmla v17.4s, v5.4s, v0.s[3]\n" - "fmla v9.4s, v6.4s, v0.s[0]\n" - "fmla v12.4s, v6.4s, v0.s[1]\n" - "fmla v15.4s, v6.4s, v0.s[2]\n" - "fmla v18.4s, v6.4s, v0.s[3]\n" - "fmla v10.4s, v7.4s, v0.s[0]\n" - "fmla v13.4s, v7.4s, v0.s[1]\n" - "fmla v16.4s, v7.4s, v0.s[2]\n" - "fmla v19.4s, v7.4s, v0.s[3]\n" - - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - - "fmla v20.4s, v5.4s, v1.s[0]\n" - "fmla v23.4s, v5.4s, v1.s[1]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v26.4s, v5.4s, v1.s[2]\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v21.4s, v6.4s, v1.s[0]\n" - "fmla v24.4s, v6.4s, v1.s[1]\n" - "fmla v27.4s, v6.4s, v1.s[2]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "fmla v22.4s, v7.4s, v1.s[0]\n" - "fmla v25.4s, v7.4s, v1.s[1]\n" - "subs %w[nk], %w[nk], #1\n" - "fmla v28.4s, 
v7.4s, v1.s[2]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - "bne 1b\n" - - "4:\n" - "mov x0, %[res_ptr]\n" - "cbnz %[oddk], 2f\n" - - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v21.4s, v3.4s, v1.s[0]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v22.4s, v4.4s, v1.s[0]\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "fmla v24.4s, v3.4s, v1.s[1]\n" - "fmla v25.4s, v4.4s, v1.s[1]\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "fmla v27.4s, v3.4s, v1.s[2]\n" - "fmla v28.4s, v4.4s, v1.s[2]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - "fmla v30.4s, v3.4s, v1.s[3]\n" - "fmla v31.4s, v4.4s, v1.s[3]\n" - - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - - "fmla v8.4s, v5.4s, v0.s[0]\n" - "fmla v9.4s, v6.4s, v0.s[0]\n" - "fmla v10.4s, v7.4s, v0.s[0]\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "fmla v12.4s, v6.4s, v0.s[1]\n" - "fmla v13.4s, v7.4s, v0.s[1]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v14.4s, v5.4s, v0.s[2]\n" - "fmla v15.4s, v6.4s, v0.s[2]\n" - "fmla v16.4s, v7.4s, v0.s[2]\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v17.4s, v5.4s, v0.s[3]\n" - "fmla v18.4s, v6.4s, v0.s[3]\n" - "fmla v19.4s, v7.4s, v0.s[3]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - - "fmla v20.4s, v5.4s, v1.s[0]\n" - "fmla v21.4s, v6.4s, v1.s[0]\n" - "fmla v22.4s, v7.4s, v1.s[0]\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v23.4s, v5.4s, v1.s[1]\n" - "fmla v24.4s, v6.4s, v1.s[1]\n" - "fmla v25.4s, v7.4s, v1.s[1]\n" - "st1 {v23.4s, v24.4s, v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v26.4s, v5.4s, v1.s[2]\n" - "fmla v27.4s, v6.4s, v1.s[2]\n" - "fmla v28.4s, v7.4s, v1.s[2]\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - "b 3f\n" - - "2:\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v11.4s, v2.4s, v0.s[1]\n" - "fmla v12.4s, v3.4s, v0.s[1]\n" - "fmla v13.4s, v4.4s, v0.s[1]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v14.4s, v2.4s, v0.s[2]\n" - "fmla v15.4s, v3.4s, v0.s[2]\n" - "fmla v16.4s, v4.4s, v0.s[2]\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v17.4s, v2.4s, v0.s[3]\n" - "fmla v18.4s, v3.4s, v0.s[3]\n" - "fmla v19.4s, v4.4s, v0.s[3]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - - "fmla v20.4s, v2.4s, v1.s[0]\n" - "fmla v21.4s, v3.4s, v1.s[0]\n" - "fmla v22.4s, v4.4s, v1.s[0]\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v23.4s, v2.4s, v1.s[1]\n" - "fmla v24.4s, v3.4s, v1.s[1]\n" - "fmla v25.4s, v4.4s, v1.s[1]\n" - "st1 {v23.4s, 
v24.4s, v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v26.4s, v2.4s, v1.s[2]\n" - "fmla v27.4s, v3.4s, v1.s[2]\n" - "fmla v28.4s, v4.4s, v1.s[2]\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v29.4s, v2.4s, v1.s[3]\n" - "fmla v30.4s, v3.4s, v1.s[3]\n" - "fmla v31.4s, v4.4s, v1.s[3]\n" - - "3:\n" - "st1 {v29.4s, v30.4s, v31.4s}, [x0]\n" - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk) - : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) - : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} - -static void sgemm_rowmajor_micro_kernel_12x8(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int oddk = (k & 1); - int nk = ((k + 1) / 2) - 1; - - const int nstride = stride << 2; - - __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v4.4s, v5.4s}, [%[rhs_ptr]], #32\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov x0, %[res_ptr]\n" - "ld1 {v8.4s, v9.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v10.4s, v11.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v12.4s, v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v14.4s, v15.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v16.4s, v17.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v18.4s, v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v20.4s, v21.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v22.4s, v23.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v24.4s, v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v26.4s, v27.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v28.4s, v29.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v30.4s, v31.4s}, [x0]\n" - "cbz %w[nk], 4f\n" - "b 1f\n" - - "0:\n" - "movi v8.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v11.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v14.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v24.4s, #0x0\n" - "movi v25.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v31.4s, #0x0\n" - "cbz %w[nk], 4f\n" - - "1:\n" - "fmla v8.4s, v4.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v12.4s, v4.4s, v0.s[2]\n" - "fmla v14.4s, v4.4s, v0.s[3]\n" - "fmla v9.4s, v5.4s, v0.s[0]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "fmla v13.4s, v5.4s, v0.s[2]\n" - "fmla v15.4s, v5.4s, v0.s[3]\n" - - "fmla v16.4s, v4.4s, v1.s[0]\n" - "fmla v18.4s, v4.4s, v1.s[1]\n" - "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" - "fmla v20.4s, v4.4s, v1.s[2]\n" - "fmla v22.4s, v4.4s, v1.s[3]\n" - "fmla v17.4s, v5.4s, v1.s[0]\n" - "fmla v19.4s, v5.4s, v1.s[1]\n" - "fmla v21.4s, v5.4s, v1.s[2]\n" - "fmla v23.4s, v5.4s, v1.s[3]\n" - - "ld1 {v6.4s, v7.4s}, [%[rhs_ptr]], #32\n" - - "fmla v24.4s, v4.4s, v2.s[0]\n" - "fmla v26.4s, v4.4s, v2.s[1]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v28.4s, v4.4s, v2.s[2]\n" - "fmla v30.4s, v4.4s, v2.s[3]\n" - "fmla v25.4s, v5.4s, v2.s[0]\n" - "fmla v27.4s, v5.4s, v2.s[1]\n" - "fmla v29.4s, v5.4s, v2.s[2]\n" - "fmla v31.4s, v5.4s, v2.s[3]\n" - - "fmla v8.4s, v6.4s, v0.s[0]\n" - 
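For readers following the assembly, a hedged scalar sketch of what sgemm_rowmajor_micro_kernel_8x12 computes may help: the load pattern implies that the packed lhs supplies 8 values per k step and the packed rhs 12 values per k step, the result is a row-major 8x12 tile with `stride` floats between rows, and k0 == 0 selects overwrite (the movi prologue) while k0 != 0 selects accumulate (the ld1 prologue). The generic reference below is only a sketch of that contract, not the kernel itself.

// Scalar reference for an MR x NR row-major micro-kernel (MR = 8, NR = 12 here).
// lhs_ptr: k blocks of MR packed values; rhs_ptr: k blocks of NR packed values.
static void sgemm_rowmajor_micro_kernel_ref(const float *lhs_ptr, const float *rhs_ptr,
                                            float *res_ptr, int k, int k0, int stride,
                                            int MR, int NR)
{
  for (int i = 0; i < MR; ++i)
  {
    for (int j = 0; j < NR; ++j)
    {
      // k0 == 0: start the tile from zero; k0 != 0: accumulate onto existing values.
      float acc = (k0 == 0) ? 0.0f : res_ptr[i * stride + j];
      for (int p = 0; p < k; ++p)
        acc += lhs_ptr[p * MR + i] * rhs_ptr[p * NR + j];
      res_ptr[i * stride + j] = acc;
    }
  }
}

The same contract appears to hold for the other micro-kernels in this file, with MR and NR taken from their names.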
"fmla v10.4s, v6.4s, v0.s[1]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v12.4s, v6.4s, v0.s[2]\n" - "fmla v14.4s, v6.4s, v0.s[3]\n" - "fmla v9.4s, v7.4s, v0.s[0]\n" - "fmla v11.4s, v7.4s, v0.s[1]\n" - "fmla v13.4s, v7.4s, v0.s[2]\n" - "fmla v15.4s, v7.4s, v0.s[3]\n" - - "fmla v16.4s, v6.4s, v1.s[0]\n" - "fmla v18.4s, v6.4s, v1.s[1]\n" - "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" - "fmla v20.4s, v6.4s, v1.s[2]\n" - "fmla v22.4s, v6.4s, v1.s[3]\n" - "fmla v17.4s, v7.4s, v1.s[0]\n" - "fmla v19.4s, v7.4s, v1.s[1]\n" - "fmla v21.4s, v7.4s, v1.s[2]\n" - "fmla v23.4s, v7.4s, v1.s[3]\n" - - "ld1 {v4.4s, v5.4s}, [%[rhs_ptr]], #32\n" - - "fmla v24.4s, v6.4s, v2.s[0]\n" - "fmla v26.4s, v6.4s, v2.s[1]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v28.4s, v6.4s, v2.s[2]\n" - "fmla v30.4s, v6.4s, v2.s[3]\n" - "fmla v25.4s, v7.4s, v2.s[0]\n" - "fmla v27.4s, v7.4s, v2.s[1]\n" - "subs %w[nk], %w[nk], #1\n" - "fmla v29.4s, v7.4s, v2.s[2]\n" - "fmla v31.4s, v7.4s, v2.s[3]\n" - "bne 1b\n" - - "4:\n" - "mov x0, %[res_ptr]\n" - "cbnz %[oddk], 2f\n" - - "fmla v8.4s, v4.4s, v0.s[0]\n" - "fmla v9.4s, v5.4s, v0.s[0]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v10.4s, v4.4s, v0.s[1]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "fmla v12.4s, v4.4s, v0.s[2]\n" - "fmla v13.4s, v5.4s, v0.s[2]\n" - "fmla v14.4s, v4.4s, v0.s[3]\n" - "fmla v15.4s, v5.4s, v0.s[3]\n" - - "fmla v16.4s, v4.4s, v1.s[0]\n" - "fmla v17.4s, v5.4s, v1.s[0]\n" - "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" - "fmla v18.4s, v4.4s, v1.s[1]\n" - "fmla v19.4s, v5.4s, v1.s[1]\n" - "fmla v20.4s, v4.4s, v1.s[2]\n" - "fmla v21.4s, v5.4s, v1.s[2]\n" - "fmla v22.4s, v4.4s, v1.s[3]\n" - "fmla v23.4s, v5.4s, v1.s[3]\n" - - "ld1 {v6.4s, v7.4s}, [%[rhs_ptr]], #32\n" - - "fmla v24.4s, v4.4s, v2.s[0]\n" - "fmla v25.4s, v5.4s, v2.s[0]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v26.4s, v4.4s, v2.s[1]\n" - "fmla v27.4s, v5.4s, v2.s[1]\n" - "fmla v28.4s, v4.4s, v2.s[2]\n" - "fmla v29.4s, v5.4s, v2.s[2]\n" - "fmla v30.4s, v4.4s, v2.s[3]\n" - "fmla v31.4s, v5.4s, v2.s[3]\n" - - "fmla v8.4s, v6.4s, v0.s[0]\n" - "fmla v9.4s, v7.4s, v0.s[0]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "st1 {v8.4s, v9.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v10.4s, v6.4s, v0.s[1]\n" - "fmla v11.4s, v7.4s, v0.s[1]\n" - "st1 {v10.4s, v11.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v12.4s, v6.4s, v0.s[2]\n" - "fmla v13.4s, v7.4s, v0.s[2]\n" - "st1 {v12.4s, v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v14.4s, v6.4s, v0.s[3]\n" - "fmla v15.4s, v7.4s, v0.s[3]\n" - "st1 {v14.4s, v15.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - - "fmla v16.4s, v6.4s, v1.s[0]\n" - "fmla v17.4s, v7.4s, v1.s[0]\n" - "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" - "st1 {v16.4s, v17.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v18.4s, v6.4s, v1.s[1]\n" - "fmla v19.4s, v7.4s, v1.s[1]\n" - "st1 {v18.4s, v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v20.4s, v6.4s, v1.s[2]\n" - "fmla v21.4s, v7.4s, v1.s[2]\n" - "st1 {v20.4s, v21.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v22.4s, v6.4s, v1.s[3]\n" - "fmla v23.4s, v7.4s, v1.s[3]\n" - "st1 {v22.4s, v23.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - - "fmla v24.4s, v6.4s, v2.s[0]\n" - "fmla v25.4s, v7.4s, v2.s[0]\n" - "st1 {v24.4s, v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v26.4s, v6.4s, v2.s[1]\n" - "fmla v27.4s, v7.4s, v2.s[1]\n" - "st1 {v26.4s, v27.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v28.4s, v6.4s, v2.s[2]\n" - "fmla v29.4s, v7.4s, v2.s[2]\n" - "st1 {v28.4s, v29.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v30.4s, v6.4s, 
v2.s[3]\n" - "fmla v31.4s, v7.4s, v2.s[3]\n" - "b 3f\n" - - "2:\n" - "fmla v8.4s, v4.4s, v0.s[0]\n" - "fmla v9.4s, v5.4s, v0.s[0]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "st1 {v8.4s, v9.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v10.4s, v4.4s, v0.s[1]\n" - "fmla v11.4s, v5.4s, v0.s[1]\n" - "st1 {v10.4s, v11.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v12.4s, v4.4s, v0.s[2]\n" - "fmla v13.4s, v5.4s, v0.s[2]\n" - "st1 {v12.4s, v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v14.4s, v4.4s, v0.s[3]\n" - "fmla v15.4s, v5.4s, v0.s[3]\n" - "st1 {v14.4s, v15.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - - "fmla v16.4s, v4.4s, v1.s[0]\n" - "fmla v17.4s, v5.4s, v1.s[0]\n" - "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" - "st1 {v16.4s, v17.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v18.4s, v4.4s, v1.s[1]\n" - "fmla v19.4s, v5.4s, v1.s[1]\n" - "st1 {v18.4s, v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v20.4s, v4.4s, v1.s[2]\n" - "fmla v21.4s, v5.4s, v1.s[2]\n" - "st1 {v20.4s, v21.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v22.4s, v4.4s, v1.s[3]\n" - "fmla v23.4s, v5.4s, v1.s[3]\n" - "st1 {v22.4s, v23.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - - "fmla v24.4s, v4.4s, v2.s[0]\n" - "fmla v25.4s, v5.4s, v2.s[0]\n" - "st1 {v24.4s, v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v26.4s, v4.4s, v2.s[1]\n" - "fmla v27.4s, v5.4s, v2.s[1]\n" - "st1 {v26.4s, v27.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v28.4s, v4.4s, v2.s[2]\n" - "fmla v29.4s, v5.4s, v2.s[2]\n" - "st1 {v28.4s, v29.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v30.4s, v4.4s, v2.s[3]\n" - "fmla v31.4s, v5.4s, v2.s[3]\n" - - "3:\n" - "st1 {v30.4s, v31.4s}, [x0]\n" - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk) - : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) - : "x0", "v0", "v1", "v2", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} - -#ifdef BATCH_DILATION_FIX -static void sgemm_rowmajor_micro_kernel_4x24(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int oddk = (k & 1); - int nk = ((k + 1) / 2) - 1; - - const int nstride = stride << 2; - - __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov x0, %[res_ptr]\n" - "mov x1, x0\n" - "ld1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" - "ld1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "ld1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" - "ld1 {v17.4s, v18.4s, v19.4s}, [x1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "ld1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" - "ld1 {v23.4s, v24.4s, v25.4s}, [x1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "ld1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" - "ld1 {v29.4s, v30.4s, v31.4s}, [x1]\n" - "cbz %w[nk], 4f\n" - "b 1f\n" - - "0:\n" - "movi v8.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v11.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v14.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v24.4s, #0x0\n" - "movi v25.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v30.4s, #0x0\n" - 
"movi v31.4s, #0x0\n" - "cbz %w[nk], 4f\n" - - "1:\n" - "mov x0, v0.d[0]\n" - "cmp x0, #0\n" - "bne 5f\n" - "mov x0, v0.d[1]\n" - "cmp x0, #0\n" - "bne 5f\n" - "add %[rhs_ptr], %[rhs_ptr], #96\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "b 6f\n" - "5:\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v14.4s, v2.4s, v0.s[1]\n" - "fmla v20.4s, v2.4s, v0.s[2]\n" - "fmla v26.4s, v2.4s, v0.s[3]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v15.4s, v3.4s, v0.s[1]\n" - "fmla v21.4s, v3.4s, v0.s[2]\n" - "fmla v27.4s, v3.4s, v0.s[3]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v16.4s, v4.4s, v0.s[1]\n" - "fmla v22.4s, v4.4s, v0.s[2]\n" - "fmla v28.4s, v4.4s, v0.s[3]\n" - - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - - "fmla v11.4s, v5.4s, v0.s[0]\n" - "fmla v17.4s, v5.4s, v0.s[1]\n" - "fmla v23.4s, v5.4s, v0.s[2]\n" - "fmla v29.4s, v5.4s, v0.s[3]\n" - "fmla v12.4s, v6.4s, v0.s[0]\n" - "fmla v18.4s, v6.4s, v0.s[1]\n" - "fmla v24.4s, v6.4s, v0.s[2]\n" - "fmla v30.4s, v6.4s, v0.s[3]\n" - "fmla v13.4s, v7.4s, v0.s[0]\n" - "fmla v19.4s, v7.4s, v0.s[1]\n" - "fmla v25.4s, v7.4s, v0.s[2]\n" - "fmla v31.4s, v7.4s, v0.s[3]\n" - - "6:\n" - "mov x0, v1.d[0]\n" - "cmp x0, #0\n" - "bne 7f\n" - "mov x0, v1.d[1]\n" - "cmp x0, #0\n" - "bne 7f\n" - "add %[rhs_ptr], %[rhs_ptr], #96\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "b 8f\n" - "7:\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - "fmla v8.4s, v2.4s, v1.s[0]\n" - "fmla v14.4s, v2.4s, v1.s[1]\n" - "fmla v20.4s, v2.4s, v1.s[2]\n" - "fmla v26.4s, v2.4s, v1.s[3]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - "fmla v9.4s, v3.4s, v1.s[0]\n" - "fmla v15.4s, v3.4s, v1.s[1]\n" - "fmla v21.4s, v3.4s, v1.s[2]\n" - "fmla v27.4s, v3.4s, v1.s[3]\n" - "fmla v10.4s, v4.4s, v1.s[0]\n" - "fmla v16.4s, v4.4s, v1.s[1]\n" - "fmla v22.4s, v4.4s, v1.s[2]\n" - "fmla v28.4s, v4.4s, v1.s[3]\n" - - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - - "fmla v11.4s, v5.4s, v1.s[0]\n" - "fmla v17.4s, v5.4s, v1.s[1]\n" - "fmla v23.4s, v5.4s, v1.s[2]\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v12.4s, v6.4s, v1.s[0]\n" - "fmla v18.4s, v6.4s, v1.s[1]\n" - "fmla v24.4s, v6.4s, v1.s[2]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "fmla v13.4s, v7.4s, v1.s[0]\n" - "fmla v19.4s, v7.4s, v1.s[1]\n" - "fmla v25.4s, v7.4s, v1.s[2]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - - "8:\n" - "subs %w[nk], %w[nk], #1\n" - "bne 1b\n" - - "4:\n" - "mov x0, %[res_ptr]\n" - "cbnz %[oddk], 2f\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v14.4s, v2.4s, v0.s[1]\n" - "fmla v15.4s, v3.4s, v0.s[1]\n" - "fmla v16.4s, v4.4s, v0.s[1]\n" - "fmla v20.4s, v2.4s, v0.s[2]\n" - "fmla v21.4s, v3.4s, v0.s[2]\n" - "fmla v22.4s, v4.4s, v0.s[2]\n" - "fmla v26.4s, v2.4s, v0.s[3]\n" - "fmla v27.4s, v3.4s, v0.s[3]\n" - "fmla v28.4s, v4.4s, v0.s[3]\n" - - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - - "fmla v11.4s, v5.4s, v0.s[0]\n" - "fmla v12.4s, v6.4s, v0.s[0]\n" - "fmla v13.4s, v7.4s, v0.s[0]\n" - "fmla v17.4s, v5.4s, v0.s[1]\n" - "fmla v18.4s, v6.4s, v0.s[1]\n" - "fmla v19.4s, v7.4s, v0.s[1]\n" - "fmla v23.4s, v5.4s, v0.s[2]\n" - "fmla v24.4s, v6.4s, v0.s[2]\n" - "fmla v25.4s, v7.4s, v0.s[2]\n" - "fmla v29.4s, v5.4s, v0.s[3]\n" - "fmla v30.4s, v6.4s, v0.s[3]\n" - "fmla v31.4s, v7.4s, v0.s[3]\n" - - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - 
- "fmla v8.4s, v2.4s, v1.s[0]\n" - "fmla v9.4s, v3.4s, v1.s[0]\n" - "fmla v10.4s, v4.4s, v1.s[0]\n" - "mov x1, x0\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" - "fmla v11.4s, v5.4s, v1.s[0]\n" - "fmla v12.4s, v6.4s, v1.s[0]\n" - "fmla v13.4s, v7.4s, v1.s[0]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "fmla v14.4s, v2.4s, v1.s[1]\n" - "fmla v15.4s, v3.4s, v1.s[1]\n" - "fmla v16.4s, v4.4s, v1.s[1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" - "fmla v17.4s, v5.4s, v1.s[1]\n" - "fmla v18.4s, v6.4s, v1.s[1]\n" - "fmla v19.4s, v7.4s, v1.s[1]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" - "fmla v20.4s, v2.4s, v1.s[2]\n" - "fmla v21.4s, v3.4s, v1.s[2]\n" - "fmla v22.4s, v4.4s, v1.s[2]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" - "fmla v23.4s, v5.4s, v1.s[2]\n" - "fmla v24.4s, v6.4s, v1.s[2]\n" - "fmla v25.4s, v7.4s, v1.s[2]\n" - "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" - "fmla v26.4s, v2.4s, v1.s[3]\n" - "fmla v27.4s, v3.4s, v1.s[3]\n" - "fmla v28.4s, v4.4s, v1.s[3]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - "b 3f\n" - - "2:\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "mov x1, x0\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" - "fmla v11.4s, v5.4s, v0.s[0]\n" - "fmla v12.4s, v6.4s, v0.s[0]\n" - "fmla v13.4s, v7.4s, v0.s[0]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "fmla v14.4s, v2.4s, v0.s[1]\n" - "fmla v15.4s, v3.4s, v0.s[1]\n" - "fmla v16.4s, v4.4s, v0.s[1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" - "fmla v17.4s, v5.4s, v0.s[1]\n" - "fmla v18.4s, v6.4s, v0.s[1]\n" - "fmla v19.4s, v7.4s, v0.s[1]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" - "fmla v20.4s, v2.4s, v0.s[2]\n" - "fmla v21.4s, v3.4s, v0.s[2]\n" - "fmla v22.4s, v4.4s, v0.s[2]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" - "fmla v23.4s, v5.4s, v0.s[2]\n" - "fmla v24.4s, v6.4s, v0.s[2]\n" - "fmla v25.4s, v7.4s, v0.s[2]\n" - "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" - "fmla v26.4s, v2.4s, v0.s[3]\n" - "fmla v27.4s, v3.4s, v0.s[3]\n" - "fmla v28.4s, v4.4s, v0.s[3]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" - "fmla v29.4s, v5.4s, v0.s[3]\n" - "fmla v30.4s, v6.4s, v0.s[3]\n" - "fmla v31.4s, v7.4s, v0.s[3]\n" - "3:\n" - "st1 {v29.4s, v30.4s, v31.4s}, [x1]\n" - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk) - : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) - : "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} -#else // BATCH_DILATION_FIX -static void sgemm_rowmajor_micro_kernel_4x24(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int oddk = (k & 1); - int nk = ((k + 1) / 2) - 1; - - const int nstride = stride << 2; - - __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v4.4s}, [%[rhs_ptr]], 
#16\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov x0, %[res_ptr]\n" - "mov x1, x0\n" - "ld1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" - "ld1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "ld1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" - "ld1 {v17.4s, v18.4s, v19.4s}, [x1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "ld1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" - "ld1 {v23.4s, v24.4s, v25.4s}, [x1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "ld1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" - "ld1 {v29.4s, v30.4s, v31.4s}, [x1]\n" - "cbz %w[nk], 4f\n" - "b 1f\n" - - "0:\n" - "movi v8.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v11.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v14.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v24.4s, #0x0\n" - "movi v25.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v31.4s, #0x0\n" - "cbz %w[nk], 4f\n" - - "1:\n" - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v14.4s, v2.4s, v0.s[1]\n" - "fmla v20.4s, v2.4s, v0.s[2]\n" - "fmla v26.4s, v2.4s, v0.s[3]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v15.4s, v3.4s, v0.s[1]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - "fmla v21.4s, v3.4s, v0.s[2]\n" - "fmla v27.4s, v3.4s, v0.s[3]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v16.4s, v4.4s, v0.s[1]\n" - "fmla v22.4s, v4.4s, v0.s[2]\n" - "fmla v28.4s, v4.4s, v0.s[3]\n" - - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - - "fmla v11.4s, v5.4s, v0.s[0]\n" - "fmla v17.4s, v5.4s, v0.s[1]\n" - "fmla v23.4s, v5.4s, v0.s[2]\n" - "fmla v29.4s, v5.4s, v0.s[3]\n" - "fmla v12.4s, v6.4s, v0.s[0]\n" - "fmla v18.4s, v6.4s, v0.s[1]\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - "fmla v24.4s, v6.4s, v0.s[2]\n" - "fmla v30.4s, v6.4s, v0.s[3]\n" - "fmla v13.4s, v7.4s, v0.s[0]\n" - "fmla v19.4s, v7.4s, v0.s[1]\n" - "fmla v25.4s, v7.4s, v0.s[2]\n" - "fmla v31.4s, v7.4s, v0.s[3]\n" - - "fmla v8.4s, v2.4s, v1.s[0]\n" - "fmla v14.4s, v2.4s, v1.s[1]\n" - "fmla v20.4s, v2.4s, v1.s[2]\n" - "fmla v26.4s, v2.4s, v1.s[3]\n" - "fmla v9.4s, v3.4s, v1.s[0]\n" - "fmla v15.4s, v3.4s, v1.s[1]\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - "fmla v21.4s, v3.4s, v1.s[2]\n" - "fmla v27.4s, v3.4s, v1.s[3]\n" - "fmla v10.4s, v4.4s, v1.s[0]\n" - "fmla v16.4s, v4.4s, v1.s[1]\n" - "fmla v22.4s, v4.4s, v1.s[2]\n" - "fmla v28.4s, v4.4s, v1.s[3]\n" - - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - - "fmla v11.4s, v5.4s, v1.s[0]\n" - "fmla v17.4s, v5.4s, v1.s[1]\n" - "fmla v23.4s, v5.4s, v1.s[2]\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v12.4s, v6.4s, v1.s[0]\n" - "fmla v18.4s, v6.4s, v1.s[1]\n" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - "fmla v24.4s, v6.4s, v1.s[2]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "fmla v13.4s, v7.4s, v1.s[0]\n" - "fmla v19.4s, v7.4s, v1.s[1]\n" - "subs %w[nk], %w[nk], #1\n" - "fmla v25.4s, v7.4s, v1.s[2]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - "bne 1b\n" - - "4:\n" - "mov x0, %[res_ptr]\n" - "cbnz %[oddk], 2f\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "fmla v14.4s, v2.4s, v0.s[1]\n" - "fmla v15.4s, v3.4s, v0.s[1]\n" - "fmla v16.4s, v4.4s, v0.s[1]\n" - "fmla v20.4s, v2.4s, 
v0.s[2]\n" - "fmla v21.4s, v3.4s, v0.s[2]\n" - "fmla v22.4s, v4.4s, v0.s[2]\n" - "fmla v26.4s, v2.4s, v0.s[3]\n" - "fmla v27.4s, v3.4s, v0.s[3]\n" - "fmla v28.4s, v4.4s, v0.s[3]\n" - - "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" - - "fmla v11.4s, v5.4s, v0.s[0]\n" - "fmla v12.4s, v6.4s, v0.s[0]\n" - "fmla v13.4s, v7.4s, v0.s[0]\n" - "fmla v17.4s, v5.4s, v0.s[1]\n" - "fmla v18.4s, v6.4s, v0.s[1]\n" - "fmla v19.4s, v7.4s, v0.s[1]\n" - "fmla v23.4s, v5.4s, v0.s[2]\n" - "fmla v24.4s, v6.4s, v0.s[2]\n" - "fmla v25.4s, v7.4s, v0.s[2]\n" - "fmla v29.4s, v5.4s, v0.s[3]\n" - "fmla v30.4s, v6.4s, v0.s[3]\n" - "fmla v31.4s, v7.4s, v0.s[3]\n" - - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - - "fmla v8.4s, v2.4s, v1.s[0]\n" - "fmla v9.4s, v3.4s, v1.s[0]\n" - "fmla v10.4s, v4.4s, v1.s[0]\n" - "mov x1, x0\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" - "fmla v11.4s, v5.4s, v1.s[0]\n" - "fmla v12.4s, v6.4s, v1.s[0]\n" - "fmla v13.4s, v7.4s, v1.s[0]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "fmla v14.4s, v2.4s, v1.s[1]\n" - "fmla v15.4s, v3.4s, v1.s[1]\n" - "fmla v16.4s, v4.4s, v1.s[1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" - "fmla v17.4s, v5.4s, v1.s[1]\n" - "fmla v18.4s, v6.4s, v1.s[1]\n" - "fmla v19.4s, v7.4s, v1.s[1]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" - "fmla v20.4s, v2.4s, v1.s[2]\n" - "fmla v21.4s, v3.4s, v1.s[2]\n" - "fmla v22.4s, v4.4s, v1.s[2]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" - "fmla v23.4s, v5.4s, v1.s[2]\n" - "fmla v24.4s, v6.4s, v1.s[2]\n" - "fmla v25.4s, v7.4s, v1.s[2]\n" - "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" - "fmla v26.4s, v2.4s, v1.s[3]\n" - "fmla v27.4s, v3.4s, v1.s[3]\n" - "fmla v28.4s, v4.4s, v1.s[3]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" - "fmla v29.4s, v5.4s, v1.s[3]\n" - "fmla v30.4s, v6.4s, v1.s[3]\n" - "fmla v31.4s, v7.4s, v1.s[3]\n" - "b 3f\n" - - "2:\n" - "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" - - "fmla v8.4s, v2.4s, v0.s[0]\n" - "fmla v9.4s, v3.4s, v0.s[0]\n" - "fmla v10.4s, v4.4s, v0.s[0]\n" - "mov x1, x0\n" - "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" - "fmla v11.4s, v5.4s, v0.s[0]\n" - "fmla v12.4s, v6.4s, v0.s[0]\n" - "fmla v13.4s, v7.4s, v0.s[0]\n" - "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" - "fmla v14.4s, v2.4s, v0.s[1]\n" - "fmla v15.4s, v3.4s, v0.s[1]\n" - "fmla v16.4s, v4.4s, v0.s[1]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" - "fmla v17.4s, v5.4s, v0.s[1]\n" - "fmla v18.4s, v6.4s, v0.s[1]\n" - "fmla v19.4s, v7.4s, v0.s[1]\n" - "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" - "fmla v20.4s, v2.4s, v0.s[2]\n" - "fmla v21.4s, v3.4s, v0.s[2]\n" - "fmla v22.4s, v4.4s, v0.s[2]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" - "fmla v23.4s, v5.4s, v0.s[2]\n" - "fmla v24.4s, v6.4s, v0.s[2]\n" - "fmla v25.4s, v7.4s, v0.s[2]\n" - "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" - "fmla v26.4s, v2.4s, v0.s[3]\n" - "fmla v27.4s, v3.4s, v0.s[3]\n" - "fmla v28.4s, v4.4s, v0.s[3]\n" - "add x0, x0, %[nstride]\n" - "mov x1, x0\n" - "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" - "fmla v29.4s, v5.4s, v0.s[3]\n" - "fmla v30.4s, v6.4s, v0.s[3]\n" - "fmla v31.4s, v7.4s, v0.s[3]\n" - "3:\n" - "st1 {v29.4s, v30.4s, v31.4s}, [x1]\n" - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk) - : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) - : "x0", "x1", "v0", "v1", 
"v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} -#endif // BATCH_DILATION_FIX - -static void sgemm_rowmajor_micro_kernel_24x4(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int oddk = (k & 1); - int nk = ((k + 1) / 2) - 1; - - const int nstride = stride << 2; - - __asm __volatile("ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" - "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov x0, %[res_ptr]\n" - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v9.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v14.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v15.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v16.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v17.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v18.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v20.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v21.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v22.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v23.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v24.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v26.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v27.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v28.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v29.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v30.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "ld1 {v31.4s}, [x0]\n" - "cbz %w[nk], 4f\n" - "b 1f\n" - - "0:\n" - "movi v8.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "movi v11.4s, #0x0\n" - "movi v12.4s, #0x0\n" - "movi v13.4s, #0x0\n" - "movi v14.4s, #0x0\n" - "movi v15.4s, #0x0\n" - "movi v16.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v24.4s, #0x0\n" - "movi v25.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v31.4s, #0x0\n" - "cbz %w[nk], 4f\n" - - "1:\n" - "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" - "fmla v8.4s, v6.4s, v0.s[0]\n" - "fmla v9.4s, v6.4s, v0.s[1]\n" - "fmla v10.4s, v6.4s, v0.s[2]\n" - "fmla v11.4s, v6.4s, v0.s[3]\n" - "fmla v12.4s, v6.4s, v1.s[0]\n" - "fmla v13.4s, v6.4s, v1.s[1]\n" - "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" - "fmla v14.4s, v6.4s, v1.s[2]\n" - "fmla v15.4s, v6.4s, v1.s[3]\n" - "fmla v16.4s, v6.4s, v2.s[0]\n" - "fmla v17.4s, v6.4s, v2.s[1]\n" - "fmla v18.4s, v6.4s, v2.s[2]\n" - "fmla v19.4s, v6.4s, v2.s[3]\n" - "ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" - "fmla v20.4s, v6.4s, v3.s[0]\n" - "fmla v21.4s, v6.4s, v3.s[1]\n" - "fmla v22.4s, v6.4s, v3.s[2]\n" - "fmla v23.4s, v6.4s, v3.s[3]\n" - "fmla v24.4s, v6.4s, v4.s[0]\n" - "fmla v25.4s, v6.4s, v4.s[1]\n" - "fmla v26.4s, v6.4s, v4.s[2]\n" - "fmla v27.4s, v6.4s, v4.s[3]\n" - "fmla v28.4s, v6.4s, v5.s[0]\n" - "fmla v29.4s, v6.4s, v5.s[1]\n" - "fmla v30.4s, v6.4s, v5.s[2]\n" - "fmla v31.4s, v6.4s, v5.s[3]\n" - - "ld1 {v3.4s, v4.4s, 
v5.4s}, [%[lhs_ptr]], #48\n" - "fmla v8.4s, v7.4s, v0.s[0]\n" - "fmla v9.4s, v7.4s, v0.s[1]\n" - "fmla v10.4s, v7.4s, v0.s[2]\n" - "fmla v11.4s, v7.4s, v0.s[3]\n" - "fmla v12.4s, v7.4s, v1.s[0]\n" - "fmla v13.4s, v7.4s, v1.s[1]\n" - "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" - "fmla v14.4s, v7.4s, v1.s[2]\n" - "fmla v15.4s, v7.4s, v1.s[3]\n" - "fmla v16.4s, v7.4s, v2.s[0]\n" - "fmla v17.4s, v7.4s, v2.s[1]\n" - "fmla v18.4s, v7.4s, v2.s[2]\n" - "fmla v19.4s, v7.4s, v2.s[3]\n" - "ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" - "fmla v20.4s, v7.4s, v3.s[0]\n" - "fmla v21.4s, v7.4s, v3.s[1]\n" - "fmla v22.4s, v7.4s, v3.s[2]\n" - "fmla v23.4s, v7.4s, v3.s[3]\n" - "fmla v24.4s, v7.4s, v4.s[0]\n" - "fmla v25.4s, v7.4s, v4.s[1]\n" - "fmla v26.4s, v7.4s, v4.s[2]\n" - "fmla v27.4s, v7.4s, v4.s[3]\n" - "fmla v28.4s, v7.4s, v5.s[0]\n" - "fmla v29.4s, v7.4s, v5.s[1]\n" - "subs %w[nk], %w[nk], #1\n" - "fmla v30.4s, v7.4s, v5.s[2]\n" - "fmla v31.4s, v7.4s, v5.s[3]\n" - "bne 1b\n" - - "4:\n" - "mov x0, %[res_ptr]\n" - "cbnz %[oddk], 2f\n" - - "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" - "fmla v8.4s, v6.4s, v0.s[0]\n" - "fmla v9.4s, v6.4s, v0.s[1]\n" - "fmla v10.4s, v6.4s, v0.s[2]\n" - "fmla v11.4s, v6.4s, v0.s[3]\n" - "fmla v12.4s, v6.4s, v1.s[0]\n" - "fmla v13.4s, v6.4s, v1.s[1]\n" - "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" - "fmla v14.4s, v6.4s, v1.s[2]\n" - "fmla v15.4s, v6.4s, v1.s[3]\n" - "fmla v16.4s, v6.4s, v2.s[0]\n" - "fmla v17.4s, v6.4s, v2.s[1]\n" - "fmla v18.4s, v6.4s, v2.s[2]\n" - "fmla v19.4s, v6.4s, v2.s[3]\n" - "ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" - "fmla v20.4s, v6.4s, v3.s[0]\n" - "fmla v21.4s, v6.4s, v3.s[1]\n" - "fmla v22.4s, v6.4s, v3.s[2]\n" - "fmla v23.4s, v6.4s, v3.s[3]\n" - "fmla v24.4s, v6.4s, v4.s[0]\n" - "fmla v25.4s, v6.4s, v4.s[1]\n" - "fmla v26.4s, v6.4s, v4.s[2]\n" - "fmla v27.4s, v6.4s, v4.s[3]\n" - "fmla v28.4s, v6.4s, v5.s[0]\n" - "fmla v29.4s, v6.4s, v5.s[1]\n" - "fmla v30.4s, v6.4s, v5.s[2]\n" - "fmla v31.4s, v6.4s, v5.s[3]\n" - - "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" - "fmla v8.4s, v7.4s, v0.s[0]\n" - "fmla v9.4s, v7.4s, v0.s[1]\n" - "st1 {v8.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v10.4s, v7.4s, v0.s[2]\n" - "st1 {v9.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v11.4s, v7.4s, v0.s[3]\n" - "st1 {v10.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v12.4s, v7.4s, v1.s[0]\n" - "st1 {v11.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v13.4s, v7.4s, v1.s[1]\n" - "st1 {v12.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v14.4s, v7.4s, v1.s[2]\n" - "st1 {v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v15.4s, v7.4s, v1.s[3]\n" - "st1 {v14.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v16.4s, v7.4s, v2.s[0]\n" - "st1 {v15.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v17.4s, v7.4s, v2.s[1]\n" - "st1 {v16.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v18.4s, v7.4s, v2.s[2]\n" - "st1 {v17.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v19.4s, v7.4s, v2.s[3]\n" - "st1 {v18.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v20.4s, v7.4s, v3.s[0]\n" - "st1 {v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v21.4s, v7.4s, v3.s[1]\n" - "st1 {v20.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v22.4s, v7.4s, v3.s[2]\n" - "st1 {v21.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v23.4s, v7.4s, v3.s[3]\n" - "st1 {v22.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v24.4s, v7.4s, v4.s[0]\n" - "st1 {v23.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v25.4s, v7.4s, v4.s[1]\n" - "st1 {v24.4s}, [x0]\n" - "add x0, x0, 
%[nstride]\n" - "fmla v26.4s, v7.4s, v4.s[2]\n" - "st1 {v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v27.4s, v7.4s, v4.s[3]\n" - "st1 {v26.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v28.4s, v7.4s, v5.s[0]\n" - "st1 {v27.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v29.4s, v7.4s, v5.s[1]\n" - "st1 {v28.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v30.4s, v7.4s, v5.s[2]\n" - "st1 {v29.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v31.4s, v7.4s, v5.s[3]\n" - "st1 {v30.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "b 3f\n" - - "2:\n" - "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" - "fmla v8.4s, v6.4s, v0.s[0]\n" - "fmla v9.4s, v6.4s, v0.s[1]\n" - "st1 {v8.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v10.4s, v6.4s, v0.s[2]\n" - "st1 {v9.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v11.4s, v6.4s, v0.s[3]\n" - "st1 {v10.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v12.4s, v6.4s, v1.s[0]\n" - "st1 {v11.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v13.4s, v6.4s, v1.s[1]\n" - "st1 {v12.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v14.4s, v6.4s, v1.s[2]\n" - "st1 {v13.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v15.4s, v6.4s, v1.s[3]\n" - "st1 {v14.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v16.4s, v6.4s, v2.s[0]\n" - "st1 {v15.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v17.4s, v6.4s, v2.s[1]\n" - "st1 {v16.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v18.4s, v6.4s, v2.s[2]\n" - "st1 {v17.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v19.4s, v6.4s, v2.s[3]\n" - "st1 {v18.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v20.4s, v6.4s, v3.s[0]\n" - "st1 {v19.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v21.4s, v6.4s, v3.s[1]\n" - "st1 {v20.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v22.4s, v6.4s, v3.s[2]\n" - "st1 {v21.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v23.4s, v6.4s, v3.s[3]\n" - "st1 {v22.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v24.4s, v6.4s, v4.s[0]\n" - "st1 {v23.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v25.4s, v6.4s, v4.s[1]\n" - "st1 {v24.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v26.4s, v6.4s, v4.s[2]\n" - "st1 {v25.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v27.4s, v6.4s, v4.s[3]\n" - "st1 {v26.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v28.4s, v6.4s, v5.s[0]\n" - "st1 {v27.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v29.4s, v6.4s, v5.s[1]\n" - "st1 {v28.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v30.4s, v6.4s, v5.s[2]\n" - "st1 {v29.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "fmla v31.4s, v6.4s, v5.s[3]\n" - "st1 {v30.4s}, [x0]\n" - "add x0, x0, %[nstride]\n" - "3:\n" - "st1 {v31.4s}, [x0]\n" - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk) - : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) - : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} - -#else // __aarch64__ -static void sgemm_rowmajor_micro_kernel_6x8(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int nk = k >> 2; - int rk = k & 3; - - const int nstride = stride << 2; - - if (rk == 0) - { - nk--; - rk = 4; - } - - __asm __volatile("vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov r0, %[res_ptr]\n" - - 
"vld1.f32 {d8-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d12-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d16-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d20-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d24-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d28-d31}, [r0]\n" - "b 1f\n" - - "0:\n" - "vmov.i32 q4, #0\n" - "vmov.i32 q5, #0\n" - "vmov.i32 q6, #0\n" - "pld [%[lhs_ptr], #48]\n" - "vmov.i32 q7, #0\n" - "pld [%[rhs_ptr], #48]\n" - "vmov.i32 q8, #0\n" - "pld [%[lhs_ptr], #112]\n" - "vmov.i32 q9, #0\n" - "pld [%[rhs_ptr], #112]\n" - "vmov.i32 q10, #0\n" - "vmov.i32 q11, #0\n" - "vmov.i32 q12, #0\n" - "vmov.i32 q13, #0\n" - "pld [%[lhs_ptr], #176]\n" - "vmov.i32 q14, #0\n" - "pld [%[rhs_ptr], #176]\n" - "vmov.i32 q15, #0\n" - - "1:\n" - "cmp %[nk], #0\n" - "beq 6f\n" - "vmla.f32 q4, q2, d0[0]\n" - "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q6, q2, d0[1]\n" - "vmla.f32 q8, q2, d1[0]\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q10, q2, d1[1]\n" - "vmla.f32 q12, q2, d2[0]\n" - "vmla.f32 q14, q2, d2[1]\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "vmla.f32 q5, q3, d0[0]\n" - "vmla.f32 q7, q3, d0[1]\n" - "vmla.f32 q9, q3, d1[0]\n" - "vmla.f32 q11, q3, d1[1]\n" - "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q13, q3, d2[0]\n" - "vmla.f32 q15, q3, d2[1]\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - - "vmla.f32 q4, q2, d3[0]\n" - "subs %[nk], %[nk], #1\n" - "vmla.f32 q6, q2, d3[1]\n" - "pld [%[lhs_ptr], #208]\n" - "vmla.f32 q8, q2, d0[0]\n" - "vmla.f32 q10, q2, d0[1]\n" - "pld [%[rhs_ptr], #192]\n" - "vmla.f32 q12, q2, d1[0]\n" - "vmla.f32 q14, q2, d1[1]\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "vmla.f32 q5, q3, d3[0]\n" - "vmla.f32 q7, q3, d3[1]\n" - "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q9, q3, d0[0]\n" - "vmla.f32 q11, q3, d0[1]\n" - "vmla.f32 q13, q3, d1[0]\n" - "vmla.f32 q15, q3, d1[1]\n" - "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" - - "vmla.f32 q4, q2, d2[0]\n" - "vmla.f32 q6, q2, d2[1]\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q8, q2, d3[0]\n" - "vmla.f32 q10, q2, d3[1]\n" - "pld [%[lhs_ptr], #240]\n" - "vmla.f32 q12, q2, d0[0]\n" - "vmla.f32 q14, q2, d0[1]\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "vmla.f32 q5, q3, d2[0]\n" - "vmla.f32 q7, q3, d2[1]\n" - "pld [%[rhs_ptr], #208]\n" - "vmla.f32 q9, q3, d3[0]\n" - "vmla.f32 q11, q3, d3[1]\n" - "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q13, q3, d0[0]\n" - "vmla.f32 q15, q3, d0[1]\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - - "vmla.f32 q4, q2, d1[0]\n" - "vmla.f32 q6, q2, d1[1]\n" - "vmla.f32 q8, q2, d2[0]\n" - "vmla.f32 q10, q2, d2[1]\n" - "vmla.f32 q12, q2, d3[0]\n" - "vmla.f32 q14, q2, d3[1]\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "vmla.f32 q5, q3, d1[0]\n" - "vmla.f32 q7, q3, d1[1]\n" - "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q9, q3, d2[0]\n" - "vmla.f32 q11, q3, d2[1]\n" - "vmla.f32 q13, q3, d3[0]\n" - "vmla.f32 q15, q3, d3[1]\n" - "bne 1b\n" - - "6:\n" - "mov r0, %[res_ptr]\n" - "subs %[rk], %[rk], #1\n" - "beq 3f\n" - - "vmla.f32 q4, q2, d0[0]\n" - "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q6, q2, d0[1]\n" - "vmla.f32 q8, q2, d1[0]\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q10, q2, d1[1]\n" - "vmla.f32 q12, q2, d2[0]\n" - "subs %[rk], %[rk], #1\n" - "vmla.f32 q14, q2, d2[1]\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "vmla.f32 q5, q3, d0[0]\n" - "vmla.f32 q7, q3, d0[1]\n" - "vmla.f32 q9, q3, d1[0]\n" - "vmla.f32 q11, q3, d1[1]\n" - "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q13, q3, d2[0]\n" - "vmla.f32 q15, q3, d2[1]\n" - 
"vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - "beq 4f\n" - - "vmla.f32 q4, q2, d3[0]\n" - "vmla.f32 q6, q2, d3[1]\n" - "subs %[rk], %[rk], #1\n" - "vmla.f32 q8, q2, d0[0]\n" - "vmla.f32 q10, q2, d0[1]\n" - "vmla.f32 q12, q2, d1[0]\n" - "vmla.f32 q14, q2, d1[1]\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "vmla.f32 q5, q3, d3[0]\n" - "vmla.f32 q7, q3, d3[1]\n" - "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q9, q3, d0[0]\n" - "vmla.f32 q11, q3, d0[1]\n" - "vmla.f32 q13, q3, d1[0]\n" - "vmla.f32 q15, q3, d1[1]\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - "beq 5f\n" - - "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q4, q2, d2[0]\n" - "vmla.f32 q6, q2, d2[1]\n" - "vmla.f32 q8, q2, d3[0]\n" - "vmla.f32 q10, q2, d3[1]\n" - "vmla.f32 q12, q2, d0[0]\n" - "vmla.f32 q14, q2, d0[1]\n" - "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" - - "vmla.f32 q5, q3, d2[0]\n" - "vmla.f32 q7, q3, d2[1]\n" - "vmla.f32 q9, q3, d3[0]\n" - "vmla.f32 q11, q3, d3[1]\n" - "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q13, q3, d0[0]\n" - "vmla.f32 q15, q3, d0[1]\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - - "vmla.f32 q4, q2, d1[0]\n" - "vmla.f32 q5, q3, d1[0]\n" - "vst1.32 {d8-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q6, q2, d1[1]\n" - "vmla.f32 q7, q3, d1[1]\n" - "vst1.32 {d12-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q8, q2, d2[0]\n" - "vmla.f32 q9, q3, d2[0]\n" - "vst1.32 {d16-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q10, q2, d2[1]\n" - "vmla.f32 q11, q3, d2[1]\n" - "vst1.32 {d20-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q12, q2, d3[0]\n" - "vmla.f32 q13, q3, d3[0]\n" - "vst1.32 {d24-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q14, q2, d3[1]\n" - "vmla.f32 q15, q3, d3[1]\n" - "b 2f\n" - - "3:\n" - "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q4, q2, d0[0]\n" - "vmla.f32 q5, q3, d0[0]\n" - "vst1.32 {d8-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q6, q2, d0[1]\n" - "vmla.f32 q7, q3, d0[1]\n" - "vst1.32 {d12-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q8, q2, d1[0]\n" - "vld1.32 {d2}, [%[lhs_ptr]]!\n" - "vmla.f32 q9, q3, d1[0]\n" - "vst1.32 {d16-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q10, q2, d1[1]\n" - "vmla.f32 q11, q3, d1[1]\n" - "vst1.32 {d20-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q12, q2, d2[0]\n" - "vmla.f32 q13, q3, d2[0]\n" - "vst1.32 {d24-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q14, q2, d2[1]\n" - "vmla.f32 q15, q3, d2[1]\n" - "b 2f\n" - - "4:\n" - "vmla.f32 q4, q2, d3[0]\n" - "vmla.f32 q5, q3, d3[0]\n" - "vst1.32 {d8-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q6, q2, d3[1]\n" - "vmla.f32 q7, q3, d3[1]\n" - "vst1.32 {d12-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q8, q2, d0[0]\n" - "vmla.f32 q9, q3, d0[0]\n" - "vst1.32 {d16-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q10, q2, d0[1]\n" - "vmla.f32 q11, q3, d0[1]\n" - "vst1.32 {d20-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q12, q2, d1[0]\n" - "vmla.f32 q13, q3, d1[0]\n" - "vst1.32 {d24-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q14, q2, d1[1]\n" - "vmla.f32 q15, q3, d1[1]\n" - "b 2f\n" - - "5:\n" - "vld1.32 {d0}, [%[lhs_ptr]]!\n" - "vmla.f32 q4, q2, d2[0]\n" - "vmla.f32 q5, q3, d2[0]\n" - "vst1.32 {d8-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q6, q2, d2[1]\n" - "vmla.f32 q7, q3, d2[1]\n" - "vst1.32 {d12-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q8, q2, d3[0]\n" - "vmla.f32 q9, q3, d3[0]\n" - "vst1.32 {d16-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q10, q2, 
d3[1]\n" - "vmla.f32 q11, q3, d3[1]\n" - "vst1.32 {d20-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q12, q2, d0[0]\n" - "vmla.f32 q13, q3, d0[0]\n" - "vst1.32 {d24-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q14, q2, d0[1]\n" - "vmla.f32 q15, q3, d0[1]\n" - "2:\n" - "vst1.32 {d28-d31}, [r0]\n" - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk), [rk] "+r"(rk) - : [k0] "r"(k0), [nstride] "r"(nstride) - : "r0", "r1", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "q14", "q15", "cc"); -} - -static void sgemm_rowmajor_micro_kernel_4x12(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int rk = (k & 1); - int nk = (k + 1) / 2; - - const int nstride = stride << 2; - - asm volatile("vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov r1, %[res_ptr]\n" - - "subs %[nk], %[nk], #1\n" - "mov r0, r1\n" - "vld1.f32 {d8-d9}, [r0]!\n" - "add r1, %[nstride]\n" - "vld1.f32 {d16-d17}, [r0]!\n" - "vld1.f32 {d24-d25}, [r0]\n" - "mov r0, r1\n" - "vld1.f32 {d10-d11}, [r0]!\n" - "add r1, %[nstride]\n" - "vld1.f32 {d18-d19}, [r0]!\n" - "vld1.f32 {d26-d27}, [r0]\n" - "mov r0, r1\n" - "vld1.f32 {d12-d13}, [r0]!\n" - "add r1, %[nstride]\n" - "vld1.f32 {d20-d21}, [r0]!\n" - "vld1.f32 {d28-d29}, [r0]\n" - "mov r0, r1\n" - "vld1.f32 {d14-d15}, [r0]!\n" - "vld1.f32 {d22-d23}, [r0]!\n" - "vld1.f32 {d30-d31}, [r0]\n" - "beq 2f\n" - - "b 1f\n" - - "0:\n" - "veor q4, q4\n" - "subs %[nk],%[nk], #1\n" - "vmov.f32 q8, q4\n" - "vmov.f32 q12, q4\n" - "vmov.f32 q5, q4\n" - "vmov.f32 q9, q4\n" - "vmov.f32 q13, q4\n" - "vmov.f32 q6, q4\n" - "vmov.f32 q10, q4\n" - "vmov.f32 q14, q4\n" - "vmov.f32 q7, q4\n" - "vmov.f32 q11, q4\n" - "vmov.f32 q15, q4\n" - - "beq 2f\n" - - "1:\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q4, q2, d0[0]\n" - "vmla.f32 q5, q2, d0[1]\n" - "vmla.f32 q6, q2, d1[0]\n" - "vmla.f32 q7, q2, d1[1]\n" - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - "vmla.f32 q8, q3, d0[0]\n" - "vmla.f32 q9, q3, d0[1]\n" - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q10, q3, d1[0]\n" - "vmla.f32 q11, q3, d1[1]\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q12, q2, d0[0]\n" - "vmla.f32 q13, q2, d0[1]\n" - "pld [%[lhs_ptr], #208]\n" - "vmla.f32 q14, q2, d1[0]\n" - "pld [%[rhs_ptr], #192]\n" - "vmla.f32 q15, q2, d1[1]\n" - - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - "vmla.f32 q4, q3, d2[0]\n" - "vmla.f32 q5, q3, d2[1]\n" - "vmla.f32 q6, q3, d3[0]\n" - "vmla.f32 q7, q3, d3[1]\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q8, q2, d2[0]\n" - "vmla.f32 q9, q2, d2[1]\n" - "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q10, q2, d3[0]\n" - "vmla.f32 q11, q2, d3[1]\n" - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - "vmla.f32 q12, q3, d2[0]\n" - "vmla.f32 q13, q3, d2[1]\n" - "subs %[nk],%[nk], #1\n" - "pld [%[lhs_ptr], #240]\n" - "vmla.f32 q14, q3, d3[0]\n" - "pld [%[rhs_ptr], #208]\n" - "vmla.f32 q15, q3, d3[1]\n" - "bne 1b\n" - - "2:\n" - "cmp %[rk], #1\n" - "beq 3f\n" - - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q4, q2, d0[0]\n" - "vmla.f32 q5, q2, d0[1]\n" - "vmla.f32 q6, q2, d1[0]\n" - "vmla.f32 q7, q2, d1[1]\n" - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - "vmla.f32 q8, q3, d0[0]\n" - "vmla.f32 q9, q3, d0[1]\n" - "vmla.f32 q10, q3, d1[0]\n" - "vmla.f32 q11, q3, d1[1]\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q12, q2, d0[0]\n" - "vmla.f32 q13, q2, 
d0[1]\n" - "vmla.f32 q14, q2, d1[0]\n" - "vmla.f32 q15, q2, d1[1]\n" - - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - "vld1.f32 {d0-d1}, [%[rhs_ptr]]!\n" - "mov r1, %[res_ptr]\n" - "mov r0, r1\n" - "vmla.f32 q4, q3, d2[0]\n" - "vmla.f32 q8, q2, d2[0]\n" - "vmla.f32 q12, q0, d2[0]\n" - "vst1.f32 {d8-d9}, [r0]!\n" - "add r1, %[nstride]\n" - "vmla.f32 q5, q3, d2[1]\n" - "vst1.f32 {d16-d17}, [r0]!\n" - "vmla.f32 q9, q2, d2[1]\n" - "vst1.f32 {d24-d25}, [r0]\n" - "mov r0, r1\n" - "vmla.f32 q13, q0, d2[1]\n" - "vst1.f32 {d10-d11}, [r0]!\n" - "vmla.f32 q6, q3, d3[0]\n" - "add r1, %[nstride]\n" - "vst1.f32 {d18-d19}, [r0]!\n" - "vmla.f32 q10, q2, d3[0]\n" - "vst1.f32 {d26-d27}, [r0]\n" - "mov r0, r1\n" - "vmla.f32 q14, q0, d3[0]\n" - "vst1.f32 {d12-d13}, [r0]!\n" - "add r1, %[nstride]\n" - "vmla.f32 q7, q3, d3[1]\n" - "vst1.f32 {d20-d21}, [r0]!\n" - "vmla.f32 q11, q2, d3[1]\n" - "vst1.f32 {d28-d29}, [r0]\n" - "mov r0, r1\n" - "vmla.f32 q15, q0, d3[1]\n" - "b 4f\n" - - "3:\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" - "mov r1, %[res_ptr]\n" - "mov r0, r1\n" - "vmla.f32 q4, q2, d0[0]\n" - "vmla.f32 q8, q3, d0[0]\n" - "vmla.f32 q12, q1, d0[0]\n" - "vst1.f32 {d8-d9}, [r0]!\n" - "add r1, %[nstride]\n" - "vmla.f32 q5, q2, d0[1]\n" - "vst1.f32 {d16-d17}, [r0]!\n" - "vmla.f32 q9, q3, d0[1]\n" - "vst1.f32 {d24-d25}, [r0]\n" - "mov r0, r1\n" - "vmla.f32 q13, q1, d0[1]\n" - "vst1.f32 {d10-d11}, [r0]!\n" - "vmla.f32 q6, q2, d1[0]\n" - "add r1, %[nstride]\n" - "vst1.f32 {d18-d19}, [r0]!\n" - "vmla.f32 q10, q3, d1[0]\n" - "vst1.f32 {d26-d27}, [r0]\n" - "mov r0, r1\n" - "vmla.f32 q14, q1, d1[0]\n" - "vst1.f32 {d12-d13}, [r0]!\n" - "add r1, %[nstride]\n" - "vmla.f32 q7, q2, d1[1]\n" - "vst1.f32 {d20-d21}, [r0]!\n" - "vmla.f32 q11, q3, d1[1]\n" - "vst1.f32 {d28-d29}, [r0]\n" - "mov r0, r1\n" - "vmla.f32 q15, q1, d1[1]\n" - - "4:\n" - "vst1.f32 {d14-d15}, [r0]!\n" - "vst1.f32 {d22-d23}, [r0]!\n" - "vst1.f32 {d30-d31}, [r0]\n" - - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk), [rk] "+r"(rk) - : [k0] "r"(k0), [nstride] "r"(nstride) - : "r0", "r1", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "q14", "q15", "cc"); -} - -static void sgemm_rowmajor_micro_kernel_12x4(const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k, const int k0, - const int stride) -{ - int rk = (k & 1); - int nk = (k + 1) / 2; - - const int nstride = stride << 2; - - asm volatile("vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - - "cmp %[k0], #0\n" - "beq 0f\n" - - "mov r0, %[res_ptr]\n" - "subs %[nk], %[nk], #1\n" - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d28-d29}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d30-d31}, [r0]\n" - "beq 2f\n" - "b 1f\n" - - "0:\n" - "veor q4, q4\n" - "subs %[nk],%[nk], #1\n" - "vmov.f32 q5, q4\n" - "vmov.f32 q6, q4\n" - "vmov.f32 q7, q4\n" - "vmov.f32 q8, q4\n" - "vmov.f32 q9, q4\n" - 
"vmov.f32 q10, q4\n" - "vmov.f32 q11, q4\n" - "vmov.f32 q12, q4\n" - "vmov.f32 q13, q4\n" - "vmov.f32 q14, q4\n" - "vmov.f32 q15, q4\n" - - "beq 2f\n" - - "1:\n" - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q4, q2, d0[0]\n" - "vmla.f32 q5, q2, d0[1]\n" - "vmla.f32 q6, q2, d1[0]\n" - "vmla.f32 q7, q2, d1[1]\n" - "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q8, q2, d2[0]\n" - "vmla.f32 q9, q2, d2[1]\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q10, q2, d3[0]\n" - "vmla.f32 q11, q2, d3[1]\n" - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q12, q2, d0[0]\n" - "vmla.f32 q13, q2, d0[1]\n" - "pld [%[rhs_ptr], #208]\n" - "vmla.f32 q14, q2, d1[0]\n" - "pld [%[lhs_ptr], #192]\n" - "vmla.f32 q15, q2, d1[1]\n" - - "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q4, q3, d2[0]\n" - "vmla.f32 q5, q3, d2[1]\n" - "vmla.f32 q6, q3, d3[0]\n" - "vmla.f32 q7, q3, d3[1]\n" - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q8, q3, d0[0]\n" - "vmla.f32 q9, q3, d0[1]\n" - "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" - "vmla.f32 q10, q3, d1[0]\n" - "vmla.f32 q11, q3, d1[1]\n" - "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q12, q3, d2[0]\n" - "vmla.f32 q13, q3, d2[1]\n" - "subs %[nk],%[nk], #1\n" - "pld [%[rhs_ptr], #240]\n" - "vmla.f32 q14, q3, d3[0]\n" - "pld [%[lhs_ptr], #208]\n" - "vmla.f32 q15, q3, d3[1]\n" - "bne 1b\n" - - "2:\n" - "cmp %[rk], #1\n" - "beq 3f\n" - - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q4, q2, d0[0]\n" - "vmla.f32 q5, q2, d0[1]\n" - "vmla.f32 q6, q2, d1[0]\n" - "vmla.f32 q7, q2, d1[1]\n" - "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q8, q2, d2[0]\n" - "vmla.f32 q9, q2, d2[1]\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vmla.f32 q10, q2, d3[0]\n" - "vmla.f32 q11, q2, d3[1]\n" - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q12, q2, d0[0]\n" - "vmla.f32 q13, q2, d0[1]\n" - "vmla.f32 q14, q2, d1[0]\n" - "vmla.f32 q15, q2, d1[1]\n" - - "mov r0, %[res_ptr]\n" - "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q4, q3, d2[0]\n" - "vst1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q5, q3, d2[1]\n" - "vst1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q6, q3, d3[0]\n" - "vst1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q7, q3, d3[1]\n" - "vst1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q8, q3, d0[0]\n" - "vst1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q9, q3, d0[1]\n" - "vst1.f32 {d18-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q10, q3, d1[0]\n" - "vst1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q11, q3, d1[1]\n" - "vst1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q12, q3, d2[0]\n" - "vst1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q13, q3, d2[1]\n" - "vst1.f32 {d26-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q14, q3, d3[0]\n" - "vst1.f32 {d28-d29}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q15, q3, d3[1]\n" - "b 4f\n" - - "3:\n" - "mov r0, %[res_ptr]\n" - "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" - "vmla.f32 q4, q2, d0[0]\n" - "vst1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q5, q2, d0[1]\n" - "vst1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q6, q2, d1[0]\n" - "vst1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q7, q2, d1[1]\n" - "vst1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" - "vmla.f32 q8, q2, d2[0]\n" - "vst1.f32 {d16-d17}, [r0]\n" - "add r0, r0, 
%[nstride]\n" - "vmla.f32 q9, q2, d2[1]\n" - "vst1.f32 {d18-d19}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q10, q2, d3[0]\n" - "vst1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q11, q2, d3[1]\n" - "vst1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q12, q2, d0[0]\n" - "vst1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q13, q2, d0[1]\n" - "vst1.f32 {d26-d27}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q14, q2, d1[0]\n" - "vst1.f32 {d28-d29}, [r0]\n" - "add r0, r0, %[nstride]\n" - "vmla.f32 q15, q3, d1[1]\n" - - "4:\n" - "vst1.f32 {d30-d31}, [r0]\n" - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), - [nk] "+r"(nk), [rk] "+r"(rk) - : [k0] "r"(k0), [nstride] "r"(nstride) - : "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", - "q12", "q13", "q14", "q15", "cc"); -} -#endif // __aarch64__ - -typedef void (*sgemm_rowmajoy_micro_kernel_func)(const float *, const float *, float *, const int, - const int, const int); - -static sgemm_rowmajoy_micro_kernel_func sgemm_rowmajoy_micro_kernel_table[12][12] = { - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - { - - 0, 0, 0, 0, 0, -#if !__aarch64__ - sgemm_rowmajor_micro_kernel_4x12, -#else // !__aarch64__ - 0, -#endif // !__aarch64__ - 0, 0, 0, 0, 0, -#if __aarch64__ - sgemm_rowmajor_micro_kernel_4x24 -#else // __aarch64__ - 0 -#endif // __aarch64__ - }, - {0, 0, 0, -#if !__aarch64__ - sgemm_rowmajor_micro_kernel_6x8, -#else // !__aarch64__ - 0, -#endif // !__aarch64__ - 0, 0, 0, 0, 0, 0, 0, 0}, - {0, 0, 0, 0, 0, -#if __aarch64__ - sgemm_rowmajor_micro_kernel_8x12, -#else // __aarch64__ - 0, -#endif // __aarch64__ - 0, 0, 0, 0, 0, 0 - - }, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - - }, - {0, -#if !__aarch64__ - sgemm_rowmajor_micro_kernel_12x4, -#else // !__aarch64__ - 0, -#endif // !__aarch64__ - 0, -#if __aarch64__ - sgemm_rowmajor_micro_kernel_12x8, -#else // __aarch64__ - 0, -#endif // __aarch64__ - 0, 0, 0, 0, 0, 0, 0, 0 - - }, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - { - - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - - }, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - - }, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - - }, - {0, -#if __aarch64__ - sgemm_rowmajor_micro_kernel_24x4, -#else // __aarch64__ - 0, -#endif // __aarch64__ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - - }, - -}; - -void _sgemm_rowmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int nstride, - const int kstride) -{ - const int nm = (mb + mr - 1) / mr; - const int nn = (nb + nr - 1) / nr; - const int rm = mb % mr; - const int rn = nb % nr; - - sgemm_rowmajoy_micro_kernel_func sgemm_rowmajoy_micro_kernel = - sgemm_rowmajoy_micro_kernel_table[mr / 2 - 1][nr / 2 - 1]; - if (!sgemm_rowmajoy_micro_kernel) - return; - - for (int j = 0; j < nn; j++) - { - const int _nr = (j != nn - 1 || rn == 0) ? nr : rn; - for (int i = 0; i < nm; i++) - { - const int _mr = (i != nm - 1 || rm == 0) ? 
mr : rm; - if (_mr == mr && _nr == nr) - { - sgemm_rowmajoy_micro_kernel(&lhs_ptr[i * mr * kstride], &rhs_ptr[j * nr * kstride], - &res_ptr[i * mr * nstride + j * nr], kb, k0, nstride); - } - else - { - float res_micro[mr * nr]; - float *res = &res_ptr[i * mr * nstride + j * nr]; - - sgemm_rowmajoy_micro_kernel(&lhs_ptr[i * mr * kstride], &rhs_ptr[j * nr * kstride], - res_micro, kb, 0, nr); - if (k0 == 0) - { - for (int pi = 0; pi < _mr; pi++) - { - for (int pj = 0; pj < _nr; pj++) - { - res[pi * nstride + pj] = res_micro[pi * nr + pj]; - } - } - } - else - { - for (int pi = 0; pi < _mr; pi++) - { - for (int pj = 0; pj < _nr; pj++) - { - res[pi * nstride + pj] += res_micro[pi * nr + pj]; - } - } - } - } - } - } -} - -void _sgemm_rowmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int nstride, - const int kstride) -{ - const int nm = (mb + mr - 1) / mr; - const int nn = (nb + nr - 1) / nr; - const int rm = mb % mr; - const int rn = nb % nr; - - sgemm_rowmajoy_micro_kernel_func sgemm_rowmajoy_micro_kernel = - sgemm_rowmajoy_micro_kernel_table[mr / 2 - 1][nr / 2 - 1]; - if (!sgemm_rowmajoy_micro_kernel) - return; - - for (int j = 0; j < nm; j++) - { - const int _mr = (j != nm - 1 || rm == 0) ? mr : rm; - for (int i = 0; i < nn; i++) - { - const int _nr = (i != nn - 1 || rn == 0) ? nr : rn; - if (_mr == mr && _nr == nr) - { - sgemm_rowmajoy_micro_kernel(&lhs_ptr[j * mr * kstride], &rhs_ptr[i * nr * kstride], - &res_ptr[j * mr * nstride + i * nr], kb, k0, nstride); - } - else - { - float res_micro[mr * nr]; - float *res = &res_ptr[j * mr * nstride + i * nr]; - - sgemm_rowmajoy_micro_kernel(&lhs_ptr[j * mr * kstride], &rhs_ptr[i * nr * kstride], - res_micro, kb, 0, nr); - if (k0 == 0) - { - for (int pi = 0; pi < _mr; pi++) - { - for (int pj = 0; pj < _nr; pj++) - { - res[pi * nstride + pj] = res_micro[pi * nr + pj]; - } - } - } - else - { - for (int pi = 0; pi < _mr; pi++) - { - for (int pj = 0; pj < _nr; pj++) - { - res[pi * nstride + pj] += res_micro[pi * nr + pj]; - } - } - } - } - } - } -} - -void _sgemm_colmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int mstride, - const int kstride) -{ - _sgemm_rowmajor_macro_kernel_divmn(nr, mr, nb, mb, kb, rhs_ptr, lhs_ptr, res_ptr, k0, mstride, - kstride); -} - -void _sgemm_colmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int mstride, - const int kstride) -{ - _sgemm_rowmajor_macro_kernel_divnm(nr, mr, nb, mb, kb, rhs_ptr, lhs_ptr, res_ptr, k0, mstride, - kstride); -} - -#if __aarch64__ -void _sparse_sgemm_kernel(const int nb, float lhs_data, const float *rhs_ptr, float *res_ptr) -{ - int nn = nb >> 3; - int rn = nb & 7; - - if (nn > 0) - { - asm volatile("mov x0, %[res_ptr]\n" - "dup v0.2d, %[lhs_data]\n" - "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v2.4s}, [x0], #16\n" - - "subs %[nn], %[nn], #1\n" - "beq 2f\n" - - "1:\n" - "ld1 {v4.4s}, [x0], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - "fmla v2.4s, v1.4s, v0.s[0]\n" - "st1 {v2.4s}, [%[res_ptr]], #16\n" - - "ld1 {v2.4s}, [x0], #16\n" - "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" - - "fmla v4.4s, v3.4s, v0.s[0]\n" - "st1 {v4.4s}, [%[res_ptr]], #16\n" - - "subs %[nn], %[nn], #1\n" - "bne 1b\n" - - "2:\n" - "ld1 
{v3.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v4.4s}, [x0], #16\n" - - "fmla v2.4s, v1.4s, v0.s[0]\n" - "st1 {v2.4s}, [%[res_ptr]], #16\n" - - "fmla v4.4s, v3.4s, v0.s[0]\n" - "st1 {v4.4s}, [%[res_ptr]], #16\n" - : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), [nn] "+r"(nn) - : [lhs_data] "r"(lhs_data) - : "x0", "v0", "v1", "v2", "v3", "v4", "cc"); - } - if (rn > 0) - { - int _nn = rn >> 2; - int _rn = rn & 3; - - if (_nn > 0) - { - asm volatile("dup v0.2d, %[lhs_data]\n" - "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[res_ptr]]\n" - "fmla v2.4s, v1.4s, v0.s[0]\n" - "st1 {v2.4s}, [%[res_ptr]], #16\n" - : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr) - : [lhs_data] "r"(lhs_data) - : "x0", "x1", "x2", "cc"); - } - if (_rn > 0) - { - for (int i = 0; i < _rn; i++) - { - res_ptr[i] += lhs_data * rhs_ptr[i]; - } - } - } -} - -#else // __aarch64__ -void _sparse_sgemm_kernel(const int nb, float lhs_data, const float *rhs_ptr, float *res_ptr) -{ - int nn = nb >> 3; - int rn = nb & 7; - - if (nn > 0) - { - asm volatile("mov r0, %[res_ptr]\n" - "vdup.32 d0, %[lhs_data]\n" - "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" - "vld1.f32 {d4-d5}, [r0]!\n" - - "subs %[nn], %[nn], #1\n" - "beq 2f\n" - - "1:\n" - "vld1.f32 {d8-d9}, [r0]!\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - - "vmla.f32 q2, q1, d0[0]\n" - "vst1.f32 {d4-d5}, [%[res_ptr]]!\n" - - "vld1.f32 {d4-d5}, [r0]!\n" - "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" - - "vmla.f32 q4, q3, d0[0]\n" - "vst1.f32 {d8-d9}, [%[res_ptr]]!\n" - - "subs %[nn], %[nn], #1\n" - "bne 1b\n" - - "2:\n" - "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" - "vld1.f32 {d8-d9}, [r0]!\n" - - "vmla.f32 q2, q1, d0[0]\n" - "vst1.f32 {d4-d5}, [%[res_ptr]]!\n" - - "vmla.f32 q4, q3, d0[0]\n" - "vst1.f32 {d8-d9}, [%[res_ptr]]!\n" - : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), [nn] "+r"(nn) - : [lhs_data] "r"(lhs_data) - : "r0", "q0", "q1", "q2", "q3", "q4", "cc"); - } - if (rn > 0) - { - int _nn = rn >> 2; - int _rn = rn & 3; - - if (_nn > 0) - { - asm volatile("vdup.32 d0, %[lhs_data]\n" - "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" - "vld1.f32 {d4-d5}, [%[res_ptr]]\n" - "vmla.f32 q2, q1, d0[0]\n" - "vst1.f32 {d4-d5}, [%[res_ptr]]!\n" - : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr) - : [lhs_data] "r"(lhs_data) - : "q0", "q1", "q2", "cc"); - } - if (_rn > 0) - { - for (int i = 0; i < _rn; i++) - { - res_ptr[i] += lhs_data * rhs_ptr[i]; - } - } - } -} -#endif // __aarch64__ - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/sgemm_kernel.h b/compute/ncnn/src/srcn/sgemm_kernel.h deleted file mode 100644 index 9e220bc33..000000000 --- a/compute/ncnn/src/srcn/sgemm_kernel.h +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __NNFW_SRCN_SGEMM_KERNEL_H__ -#define __NNFW_SRCN_SGEMM_KERNEL_H__ - -#include "ncnn/srcn/conv_type.h" - -namespace nnfw -{ -namespace srcn -{ - -void _sgemm_rowmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int nstride, - const int kstride); - -void _sgemm_rowmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int nstride, - const int kstride); - -void _sgemm_colmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int mstride, - const int kstride); - -void _sgemm_colmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, - const int kb, const float *lhs_ptr, const float *rhs_ptr, - float *res_ptr, const int k0, const int mstride, - const int kstride); - -void _sparse_sgemm_kernel(const int nb, float lhs_data, const float *rhs_ptr, float *res_ptr); - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_SGEMM_KERNEL_H__ diff --git a/compute/ncnn/src/srcn/sgemm_pack.cc b/compute/ncnn/src/srcn/sgemm_pack.cc deleted file mode 100644 index 8767f6c0a..000000000 --- a/compute/ncnn/src/srcn/sgemm_pack.cc +++ /dev/null @@ -1,2316 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
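
One genuinely subtle point in the API declared above: the column-major macro kernels are the row-major ones with LHS/RHS (and the m/n tile parameters) swapped, because a buffer read column-major is the transpose of the same buffer read row-major and (A·B)ᵀ = Bᵀ·Aᵀ; swapping the operands also swaps which dimension the outer loop walks, which is why `divnm` forwards to `divmn` and vice versa. A self-contained demonstration of the identity with a naive GEMM (sketch, not the library code):

```cpp
#include <cassert>

// Why the col-major macro kernels can delegate to the row-major ones
// (sketch): a matrix stored column-major is the transpose of the same
// buffer read row-major, and (A * B)^T == B^T * A^T.
static void naive_gemm_rowmajor(int m, int n, int k, const float *a, const float *b, float *c)
{
  for (int i = 0; i < m; i++)
    for (int j = 0; j < n; j++)
    {
      float acc = 0.f;
      for (int p = 0; p < k; p++)
        acc += a[i * k + p] * b[p * n + j];
      c[i * n + j] = acc;
    }
}

int main()
{
  const float a_col[] = {1, 2, 3, 4, 5, 6};     // A is 2x3, column-major
  const float b_col[] = {7, 8, 9, 10, 11, 12};  // B is 3x2, column-major
  float c_col[4];                               // C = A * B is 2x2, column-major
  // Read col-major A(2x3) as row-major A^T(3x2), likewise for B; the
  // row-major product B^T * A^T then lands in c_col as C column-major.
  naive_gemm_rowmajor(2, 2, 3, b_col, a_col, c_col);
  assert(c_col[0] == 1 * 7 + 3 * 8 + 5 * 9); // C(0,0) == 76
  return 0;
}
```
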
- */ - -#include <stdlib.h> -#include <arm_neon.h> - -#include "ncnn/srcn/conv_type.h" -#include "common.h" - -namespace nnfw -{ -namespace srcn -{ - -void _pack_rowmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr) -{ - const int nm = mb / mr; - const int rm = mb % mr; - - switch (mr) - { -#if __aarch64__ - case 24: - for (int i = 0; i < nm; i++) - { - int nk = kb >> 2; - int rk = kb & 0x03; - - const float *lhs_temp = lhs_ptr; - const int _stride = stride << 2; - - if (nk > 0) - { - asm volatile("0:\n" - "mov x0, %[lhs_temp]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v9.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v8.4s, v10.4s\n" - "zip2 v30.4s, v8.4s, v10.4s\n" - "zip1 v29.4s, v9.4s, v11.4s\n" - "zip2 v31.4s, v9.4s, v11.4s\n" - "zip1 v8.4s, v28.4s, v29.4s\n" - "zip2 v9.4s, v28.4s, v29.4s\n" - "zip1 v10.4s, v30.4s, v31.4s\n" - "zip2 v11.4s, v30.4s, v31.4s\n" - - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v13.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v14.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v15.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v12.4s, v14.4s\n" - "zip2 v30.4s, v12.4s, v14.4s\n" - "zip1 v29.4s, v13.4s, v15.4s\n" - "zip2 v31.4s, v13.4s, v15.4s\n" - "zip1 v12.4s, v28.4s, v29.4s\n" - "zip2 v13.4s, v28.4s, v29.4s\n" - "zip1 v14.4s, v30.4s, v31.4s\n" - "zip2 v15.4s, v30.4s, v31.4s\n" - - "ld1 {v16.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v17.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v18.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v19.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v16.4s, v18.4s\n" - "zip2 v30.4s, v16.4s, v18.4s\n" - "zip1 v29.4s, v17.4s, v19.4s\n" - "zip2 v31.4s, v17.4s, v19.4s\n" - "zip1 v16.4s, v28.4s, v29.4s\n" - "zip2 v17.4s, v28.4s, v29.4s\n" - "zip1 v18.4s, v30.4s, v31.4s\n" - "zip2 v19.4s, v30.4s, v31.4s\n" - - "ld1 {v20.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v21.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v22.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v23.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v20.4s, v22.4s\n" - "zip2 v30.4s, v20.4s, v22.4s\n" - "zip1 v29.4s, v21.4s, v23.4s\n" - "zip2 v31.4s, v21.4s, v23.4s\n" - "zip1 v20.4s, v28.4s, v29.4s\n" - "zip2 v21.4s, v28.4s, v29.4s\n" - "zip1 v22.4s, v30.4s, v31.4s\n" - "zip2 v23.4s, v30.4s, v31.4s\n" - - "ld1 {v24.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v25.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v26.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v27.4s}, [x0]\n" - - "zip1 v28.4s, v24.4s, v26.4s\n" - "zip2 v30.4s, v24.4s, v26.4s\n" - "zip1 v29.4s, v25.4s, v27.4s\n" - "zip2 v31.4s, v25.4s, v27.4s\n" - "zip1 v24.4s, v28.4s, v29.4s\n" - "zip2 v25.4s, v28.4s, v29.4s\n" - "zip1 v26.4s, v30.4s, v31.4s\n" - "zip2 v27.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[plhs_ptr]], #16\n" - "st1 
{v8.4s}, [%[plhs_ptr]], #16\n" - "st1 {v12.4s}, [%[plhs_ptr]], #16\n" - "st1 {v16.4s}, [%[plhs_ptr]], #16\n" - "st1 {v20.4s}, [%[plhs_ptr]], #16\n" - "st1 {v24.4s}, [%[plhs_ptr]], #16\n" - "st1 {v5.4s}, [%[plhs_ptr]], #16\n" - "st1 {v9.4s}, [%[plhs_ptr]], #16\n" - "st1 {v13.4s}, [%[plhs_ptr]], #16\n" - "st1 {v17.4s}, [%[plhs_ptr]], #16\n" - "st1 {v21.4s}, [%[plhs_ptr]], #16\n" - "st1 {v25.4s}, [%[plhs_ptr]], #16\n" - "st1 {v6.4s}, [%[plhs_ptr]], #16\n" - "st1 {v10.4s}, [%[plhs_ptr]], #16\n" - "st1 {v14.4s}, [%[plhs_ptr]], #16\n" - "st1 {v18.4s}, [%[plhs_ptr]], #16\n" - "st1 {v22.4s}, [%[plhs_ptr]], #16\n" - "st1 {v26.4s}, [%[plhs_ptr]], #16\n" - "st1 {v7.4s}, [%[plhs_ptr]], #16\n" - "st1 {v11.4s}, [%[plhs_ptr]], #16\n" - "st1 {v15.4s}, [%[plhs_ptr]], #16\n" - "st1 {v19.4s}, [%[plhs_ptr]], #16\n" - "st1 {v23.4s}, [%[plhs_ptr]], #16\n" - "st1 {v27.4s}, [%[plhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); - } - - for (int j = 0; j < rk; j++) - { - plhs_ptr[0] = lhs_temp[0]; - plhs_ptr[1] = lhs_temp[stride]; - plhs_ptr[2] = lhs_temp[stride << 1]; - plhs_ptr[3] = lhs_temp[3 * stride]; - plhs_ptr[4] = lhs_temp[stride << 2]; - plhs_ptr[5] = lhs_temp[5 * stride]; - plhs_ptr[6] = lhs_temp[6 * stride]; - plhs_ptr[7] = lhs_temp[7 * stride]; - plhs_ptr[8] = lhs_temp[stride << 3]; - plhs_ptr[9] = lhs_temp[9 * stride]; - plhs_ptr[10] = lhs_temp[10 * stride]; - plhs_ptr[11] = lhs_temp[11 * stride]; - plhs_ptr[12] = lhs_temp[0]; - plhs_ptr[13] = lhs_temp[13 * stride]; - plhs_ptr[14] = lhs_temp[14 * stride]; - plhs_ptr[15] = lhs_temp[15 * stride]; - plhs_ptr[16] = lhs_temp[stride << 4]; - plhs_ptr[17] = lhs_temp[17 * stride]; - plhs_ptr[18] = lhs_temp[18 * stride]; - plhs_ptr[19] = lhs_temp[19 * stride]; - plhs_ptr[20] = lhs_temp[20 * stride]; - plhs_ptr[21] = lhs_temp[21 * stride]; - plhs_ptr[22] = lhs_temp[22 * stride]; - plhs_ptr[23] = lhs_temp[23 * stride]; - plhs_ptr += mr; - lhs_temp++; - } - - lhs_ptr += mr * stride; - } - break; - case 16: - for (int i = 0; i < nm; i++) - { - int nk = kb >> 2; - int rk = kb & 0x03; - - const float *lhs_temp = lhs_ptr; - const int _stride = stride << 2; - - if (nk > 0) - { - asm volatile("0:\n" - "mov x0, %[lhs_temp]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v9.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v8.4s, v10.4s\n" - "zip2 v30.4s, v8.4s, v10.4s\n" - "zip1 v29.4s, v9.4s, v11.4s\n" - "zip2 v31.4s, v9.4s, v11.4s\n" - "zip1 v8.4s, v28.4s, v29.4s\n" - "zip2 v9.4s, v28.4s, v29.4s\n" - "zip1 v10.4s, v30.4s, v31.4s\n" - "zip2 v11.4s, v30.4s, v31.4s\n" - - "ld1 {v12.4s}, [x0]\n" - "add x0, 
x0, %[_stride]\n" - "ld1 {v13.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v14.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v15.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v12.4s, v14.4s\n" - "zip2 v30.4s, v12.4s, v14.4s\n" - "zip1 v29.4s, v13.4s, v15.4s\n" - "zip2 v31.4s, v13.4s, v15.4s\n" - "zip1 v12.4s, v28.4s, v29.4s\n" - "zip2 v13.4s, v28.4s, v29.4s\n" - "zip1 v14.4s, v30.4s, v31.4s\n" - "zip2 v15.4s, v30.4s, v31.4s\n" - - "ld1 {v16.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v17.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v18.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v19.4s}, [x0]\n" - - "zip1 v28.4s, v16.4s, v18.4s\n" - "zip2 v30.4s, v16.4s, v18.4s\n" - "zip1 v29.4s, v17.4s, v19.4s\n" - "zip2 v31.4s, v17.4s, v19.4s\n" - "zip1 v16.4s, v28.4s, v29.4s\n" - "zip2 v17.4s, v28.4s, v29.4s\n" - "zip1 v18.4s, v30.4s, v31.4s\n" - "zip2 v19.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[plhs_ptr]], #16\n" - "st1 {v8.4s}, [%[plhs_ptr]], #16\n" - "st1 {v12.4s}, [%[plhs_ptr]], #16\n" - "st1 {v16.4s}, [%[plhs_ptr]], #16\n" - "st1 {v5.4s}, [%[plhs_ptr]], #16\n" - "st1 {v9.4s}, [%[plhs_ptr]], #16\n" - "st1 {v13.4s}, [%[plhs_ptr]], #16\n" - "st1 {v17.4s}, [%[plhs_ptr]], #16\n" - "st1 {v6.4s}, [%[plhs_ptr]], #16\n" - "st1 {v10.4s}, [%[plhs_ptr]], #16\n" - "st1 {v14.4s}, [%[plhs_ptr]], #16\n" - "st1 {v18.4s}, [%[plhs_ptr]], #16\n" - "st1 {v7.4s}, [%[plhs_ptr]], #16\n" - "st1 {v11.4s}, [%[plhs_ptr]], #16\n" - "st1 {v15.4s}, [%[plhs_ptr]], #16\n" - "st1 {v19.4s}, [%[plhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v28", "v29", - "v30", "v31"); - } - - for (int j = 0; j < rk; j++) - { - plhs_ptr[0] = lhs_temp[0]; - plhs_ptr[1] = lhs_temp[stride]; - plhs_ptr[2] = lhs_temp[stride << 1]; - plhs_ptr[3] = lhs_temp[3 * stride]; - plhs_ptr[4] = lhs_temp[stride << 2]; - plhs_ptr[5] = lhs_temp[5 * stride]; - plhs_ptr[6] = lhs_temp[6 * stride]; - plhs_ptr[7] = lhs_temp[7 * stride]; - plhs_ptr[8] = lhs_temp[stride << 3]; - plhs_ptr[9] = lhs_temp[9 * stride]; - plhs_ptr[10] = lhs_temp[10 * stride]; - plhs_ptr[11] = lhs_temp[11 * stride]; - plhs_ptr[12] = lhs_temp[0]; - plhs_ptr[13] = lhs_temp[13 * stride]; - plhs_ptr[14] = lhs_temp[14 * stride]; - plhs_ptr[15] = lhs_temp[15 * stride]; - plhs_ptr += mr; - lhs_temp++; - } - - lhs_ptr += mr * stride; - } - break; -#endif // __aarch64__ - case 12: - for (int i = 0; i < nm; i++) - { - int nk = kb >> 2; - int rk = kb & 0x03; - - const float *lhs_temp = lhs_ptr; - const int _stride = stride << 2; - - if (nk > 0) - { -#if __aarch64__ - asm volatile("0:\n" - "mov x0, %[lhs_temp]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v9.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, 
%[_stride]\n" - - "zip1 v28.4s, v8.4s, v10.4s\n" - "zip2 v30.4s, v8.4s, v10.4s\n" - "zip1 v29.4s, v9.4s, v11.4s\n" - "zip2 v31.4s, v9.4s, v11.4s\n" - "zip1 v8.4s, v28.4s, v29.4s\n" - "zip2 v9.4s, v28.4s, v29.4s\n" - "zip1 v10.4s, v30.4s, v31.4s\n" - "zip2 v11.4s, v30.4s, v31.4s\n" - - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v13.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v14.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v15.4s}, [x0]\n" - - "zip1 v28.4s, v12.4s, v14.4s\n" - "zip2 v30.4s, v12.4s, v14.4s\n" - "zip1 v29.4s, v13.4s, v15.4s\n" - "zip2 v31.4s, v13.4s, v15.4s\n" - "zip1 v12.4s, v28.4s, v29.4s\n" - "zip2 v13.4s, v28.4s, v29.4s\n" - "zip1 v14.4s, v30.4s, v31.4s\n" - "zip2 v15.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[plhs_ptr]], #16\n" - "st1 {v8.4s}, [%[plhs_ptr]], #16\n" - "st1 {v12.4s}, [%[plhs_ptr]], #16\n" - "st1 {v5.4s}, [%[plhs_ptr]], #16\n" - "st1 {v9.4s}, [%[plhs_ptr]], #16\n" - "st1 {v13.4s}, [%[plhs_ptr]], #16\n" - "st1 {v6.4s}, [%[plhs_ptr]], #16\n" - "st1 {v10.4s}, [%[plhs_ptr]], #16\n" - "st1 {v14.4s}, [%[plhs_ptr]], #16\n" - "st1 {v7.4s}, [%[plhs_ptr]], #16\n" - "st1 {v11.4s}, [%[plhs_ptr]], #16\n" - "st1 {v15.4s}, [%[plhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v28", "v29", "v30", "v31"); -#else // __aarch64__ - asm volatile("0:\n" - "mov r0, %[lhs_temp]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[_stride]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[_stride]\n" - - "vzip.32 q8, q10\n" - "vzip.32 q9, q11\n" - "vzip.32 q8, q9\n" - "vzip.32 q10, q11\n" - - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d28-d29}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d30-d31}, [r0]\n" - - "vzip.32 q12, q14\n" - "vzip.32 q13, q15\n" - "vzip.32 q12, q13\n" - "vzip.32 q14, q15\n" - - "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" - "vst1.f32 {d16-d17}, [%[plhs_ptr]]!\n" - "vst1.f32 {d24-d25}, [%[plhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" - "vst1.f32 {d18-d19}, [%[plhs_ptr]]!\n" - "vst1.f32 {d26-d27}, [%[plhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" - "vst1.f32 {d20-d21}, [%[plhs_ptr]]!\n" - "vst1.f32 {d28-d29}, [%[plhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" - "vst1.f32 {d22-d23}, [%[plhs_ptr]]!\n" - "vst1.f32 {d30-d31}, [%[plhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", - "q12", "q13", "q14", "q15"); -#endif // __aarch64__ - } - - for (int j = 0; j < rk; j++) - { - plhs_ptr[0] = lhs_temp[0]; - plhs_ptr[1] = lhs_temp[stride]; - plhs_ptr[2] = lhs_temp[stride << 1]; - plhs_ptr[3] = lhs_temp[3 * stride]; - 
plhs_ptr[4] = lhs_temp[stride << 2]; - plhs_ptr[5] = lhs_temp[5 * stride]; - plhs_ptr[6] = lhs_temp[6 * stride]; - plhs_ptr[7] = lhs_temp[7 * stride]; - plhs_ptr[8] = lhs_temp[stride << 3]; - plhs_ptr[9] = lhs_temp[9 * stride]; - plhs_ptr[10] = lhs_temp[10 * stride]; - plhs_ptr[11] = lhs_temp[11 * stride]; - plhs_ptr += mr; - lhs_temp++; - } - - lhs_ptr += mr * stride; - } - break; - case 8: - for (int i = 0; i < nm; i++) - { - int nk = kb >> 2; - int rk = kb & 0x03; - - const float *lhs_temp = lhs_ptr; - const int _stride = stride << 2; - - if (nk > 0) - { -#if __aarch64__ - asm volatile("0:\n" - "mov x0, %[lhs_temp]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v9.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "zip1 v28.4s, v8.4s, v10.4s\n" - "zip2 v30.4s, v8.4s, v10.4s\n" - "zip1 v29.4s, v9.4s, v11.4s\n" - "zip2 v31.4s, v9.4s, v11.4s\n" - "zip1 v8.4s, v28.4s, v29.4s\n" - "zip2 v9.4s, v28.4s, v29.4s\n" - "zip1 v10.4s, v30.4s, v31.4s\n" - "zip2 v11.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[plhs_ptr]], #16\n" - "st1 {v8.4s}, [%[plhs_ptr]], #16\n" - "st1 {v5.4s}, [%[plhs_ptr]], #16\n" - "st1 {v9.4s}, [%[plhs_ptr]], #16\n" - "st1 {v6.4s}, [%[plhs_ptr]], #16\n" - "st1 {v10.4s}, [%[plhs_ptr]], #16\n" - "st1 {v7.4s}, [%[plhs_ptr]], #16\n" - "st1 {v11.4s}, [%[plhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v28", "v29", "v30", "v31"); -#else // __aarch64__ - asm volatile("0:\n" - "mov r0, %[lhs_temp]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[_stride]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vzip.32 q8, q10\n" - "vzip.32 q9, q11\n" - "vzip.32 q8, q9\n" - "vzip.32 q10, q11\n" - - "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" - "vst1.f32 {d16-d17}, [%[plhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" - "vst1.f32 {d18-d19}, [%[plhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" - "vst1.f32 {d20-d21}, [%[plhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" - "vst1.f32 {d22-d23}, [%[plhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); -#endif // __aarch64__ - } - - for (int j = 0; j < rk; j++) - { - plhs_ptr[0] = lhs_temp[0]; - plhs_ptr[1] = 
lhs_temp[stride]; - plhs_ptr[2] = lhs_temp[stride << 1]; - plhs_ptr[3] = lhs_temp[3 * stride]; - plhs_ptr[4] = lhs_temp[stride << 2]; - plhs_ptr[5] = lhs_temp[5 * stride]; - plhs_ptr[6] = lhs_temp[6 * stride]; - plhs_ptr[7] = lhs_temp[7 * stride]; - plhs_ptr += mr; - lhs_temp++; - } - - lhs_ptr += mr * stride; - } - break; - case 6: - for (int i = 0; i < nm; i++) - { - int nk = kb >> 2; - int rk = kb & 0x03; - - const float *lhs_temp = lhs_ptr; - const int _stride = stride << 2; - - if (nk > 0) - { -#if __aarch64__ - // TODO: 4--->6 - asm volatile("0:\n" - "mov x0, %[lhs_temp]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v8.4s}, [x0]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[plhs_ptr]], #16\n" - "st1 {v5.4s}, [%[plhs_ptr]], #16\n" - "st1 {v6.4s}, [%[plhs_ptr]], #16\n" - "st1 {v7.4s}, [%[plhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v28", "v29", "v30", "v31"); -#else // __aarch64__ - asm volatile("0:\n" - "mov r0, %[lhs_temp]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - "vzip.32 q8, q9\n" - - "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" - "vst1.f32 {d16}, [%[plhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" - "vst1.f32 {d17}, [%[plhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" - "vst1.f32 {d18}, [%[plhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" - "vst1.f32 {d19}, [%[plhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9"); -#endif // __aarch64__ - } - - for (int j = 0; j < rk; j++) - { - plhs_ptr[0] = lhs_temp[0]; - plhs_ptr[1] = lhs_temp[stride]; - plhs_ptr[2] = lhs_temp[stride << 1]; - plhs_ptr[3] = lhs_temp[3 * stride]; - plhs_ptr[4] = lhs_temp[stride << 2]; - plhs_ptr[5] = lhs_temp[5 * stride]; - plhs_ptr += mr; - lhs_temp++; - } - - lhs_ptr += mr * stride; - } - break; - case 4: - for (int i = 0; i < nm; i++) - { - int nk = kb >> 2; - int rk = kb & 0x03; - - const float *lhs_temp = lhs_ptr; - const int _stride = stride << 2; - - if (nk > 0) - { -#if __aarch64__ - asm volatile("0:\n" - "mov x0, %[lhs_temp]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - 
"zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[plhs_ptr]], #16\n" - "st1 {v5.4s}, [%[plhs_ptr]], #16\n" - "st1 {v6.4s}, [%[plhs_ptr]], #16\n" - "st1 {v7.4s}, [%[plhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v28", "v29", "v30", "v31"); -#else // __aarch64__ - asm volatile("0:\n" - "mov r0, %[lhs_temp]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - - "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[lhs_temp], %[lhs_temp], #16\n" - "bne 0b\n" - : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7"); -#endif // __aarch64__ - } - - for (int j = 0; j < rk; j++) - { - plhs_ptr[0] = lhs_temp[0]; - plhs_ptr[1] = lhs_temp[stride]; - plhs_ptr[2] = lhs_temp[stride << 1]; - plhs_ptr[3] = lhs_temp[3 * stride]; - plhs_ptr += mr; - lhs_temp++; - } - - lhs_ptr += mr * stride; - } - break; - default: - break; - } - - if (rm > 0) - { - for (int j = 0; j < kb; j++) - { - for (int i = 0; i < rm; i++) - { - plhs_ptr[i] = lhs_ptr[i * stride]; - } - for (int i = rm; i < mr; i++) - { - plhs_ptr[i] = 0.f; - } - plhs_ptr += mr; - lhs_ptr++; - } - } -} - -void _pack_rowmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr) -{ - const int nn = nb / nr; - const int rn = nb % nr; - - switch (nr) - { - case 24: - for (int j = 0; j < nn; j++) - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2, q3, q4, q5; - for (int i = 0; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + 8); - q3 = vld1q_f32(rhs_temp + 12); - q4 = vld1q_f32(rhs_temp + 16); - q5 = vld1q_f32(rhs_temp + 20); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - vst1q_f32(prhs_ptr + 16, q4); - vst1q_f32(prhs_ptr + 20, q5); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 16: - for (int j = 0; j < nn; j++) - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2, q3; - for (int i = 0; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + 8); - q3 = vld1q_f32(rhs_temp + 12); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 12: - for (int j = 0; j < nn; j++) - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2; - for (int i = 0; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + 8); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 8: - for (int j = 0; j < nn; j++) - - { - const float *rhs_temp = rhs_ptr; 
- float32x4_t q0, q1, q2, q3; - - int i = 0; - for (; i + 1 < kb; i += 2) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + stride); - q3 = vld1q_f32(rhs_temp + stride + 4); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - - rhs_temp += stride << 1; - prhs_ptr += nr << 1; - } - - for (; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 6: - for (int j = 0; j < nn; j++) - - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q2; - float32x2_t q1, q3; - - int i = 0; - for (; i + 1 < kb; i += 2) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1_f32(rhs_temp + 4); - - q2 = vld1q_f32(rhs_temp + stride); - q3 = vld1_f32(rhs_temp + stride + 4); - vst1q_f32(prhs_ptr, q0); - vst1_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 6, q2); - vst1_f32(prhs_ptr + 10, q3); - - rhs_temp += stride << 1; - prhs_ptr += nr << 1; - } - - for (; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1_f32(rhs_temp + 4); - - vst1q_f32(prhs_ptr, q0); - vst1_f32(prhs_ptr + 4, q1); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 4: - for (int j = 0; j < nn; j++) - - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2, q3; - - int i = 0; - for (; i + 3 < kb; i += 4) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + stride); - q2 = vld1q_f32(rhs_temp + (stride << 1)); - q3 = vld1q_f32(rhs_temp + (stride * 3)); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - - rhs_temp += stride << 2; - prhs_ptr += nr << 2; - } - for (; i + 1 < kb; i += 2) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + stride); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - - rhs_temp += stride << 1; - prhs_ptr += nr << 1; - } - for (; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - vst1q_f32(prhs_ptr, q0); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - default: - break; - } - - if (rn > 0) - { - for (int i = 0; i < kb; i++) - { - for (int j = 0; j < rn; j++) - { - prhs_ptr[j] = rhs_ptr[j]; - } - for (int j = rn; j < nr; j++) - { - prhs_ptr[j] = 0.f; - } - prhs_ptr += nr; - rhs_ptr += stride; - } - } -} - -void _pack_rowmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr) -{ - _pack_rowmajor_notrans_rhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr); -} - -void _pack_rowmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr) -{ - _pack_rowmajor_notrans_lhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr); -} - -static inline void _pack_rowmajor_image_subn(const int nr, const int nb, const int stride, - const float *buffer, float *prhs_ptr) -{ - const int nn = nb / nr; - const int rn = nb % nr; - - switch (nr) - { - case 24: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1, q2, q3, q4, q5; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - q2 = vld1q_f32(buffer + 8); - q3 = vld1q_f32(buffer + 12); - q4 = vld1q_f32(buffer + 16); - q5 = vld1q_f32(buffer + 20); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - vst1q_f32(prhs_ptr + 16, q4); - vst1q_f32(prhs_ptr + 20, q5); - prhs_ptr += stride; - buffer += 
nr; - } - break; - case 16: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1, q2, q3; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - q2 = vld1q_f32(buffer + 8); - q3 = vld1q_f32(buffer + 12); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - prhs_ptr += stride; - buffer += nr; - } - break; - case 12: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1, q2; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - q2 = vld1q_f32(buffer + 8); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - prhs_ptr += stride; - buffer += nr; - } - break; - case 8: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - prhs_ptr += stride; - buffer += nr; - } - break; - case 6: - for (int j = 0; j < nn; j++) - { - float32x4_t q0; - float32x2_t q1; - q0 = vld1q_f32(buffer); - q1 = vld1_f32(buffer + 4); - vst1q_f32(prhs_ptr, q0); - vst1_f32(prhs_ptr + 4, q1); - prhs_ptr += stride; - buffer += nr; - } - break; - case 4: - for (int j = 0; j < nn; j++) - { - float32x4_t q0; - q0 = vld1q_f32(buffer); - vst1q_f32(prhs_ptr, q0); - prhs_ptr += stride; - buffer += nr; - } - break; - default: - break; - } - - if (rn > 0) - { - for (int j = 0; j < rn; j++) - { - prhs_ptr[j] = buffer[j]; - } - for (int j = rn; j < nr; j++) - { - prhs_ptr[j] = 0.f; - } - } -} - -void _pack_rowmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *prhs_ptr) -{ - const int w = input->w; - const int h = input->h; - const int outw = output->w; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - - const int in_row0 = n0 / outw * stride_h; - const int in_col0 = n0 % outw * stride_w; - int seg0 = outw - n0 % outw; - if (seg0 > nb) - seg0 = nb; - int rows = (nb - seg0 + outw - 1) / outw; - if (seg0) - rows++; - const int segn = (nb - seg0) % outw; - - float row_data[nb]; - - for (int i = k0; i < kb + k0; i++) - { - const int ic = i / (kernel_w * kernel_h); - const int in_row1 = ((i / kernel_w) % kernel_h) * params->dilation_h + in_row0; - const int in_col1 = i % kernel_w * params->dilation_w; - -#ifdef NCNN - const float *input_data = input->data + ic * alignSize(w * h, 16 / sizeof(float)); -#else // NCNN - const float *input_data = input->data + ic * w * h; -#endif // NCNN - float *buffer = row_data; - int in_row = in_row1 - pad_h; - - for (int out_rows = rows; out_rows; out_rows--) - { - int cols = (out_rows != 1 || segn == 0) ? 
outw : segn; - int in_col = in_col1 - pad_w; - if (out_rows == rows) - { - cols = seg0; - in_col += in_col0; - } - if ((unsigned int)in_row < (unsigned int)h) - { - for (int out_col = cols; out_col; out_col--) - { - if ((unsigned int)in_col < (unsigned int)w) - *(buffer++) = input_data[in_row * w + in_col]; - else - *(buffer++) = 0; - in_col += stride_w; - } - } - else - { - for (int out_col = cols; out_col; out_col--) - { - *(buffer++) = 0; - in_col += stride_w; - } - } - - in_row += stride_h; - } - - _pack_rowmajor_image_subn(nr, nb, nr * kb, row_data, prhs_ptr); - prhs_ptr += nr; - } -} - -void _pack_rowmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, - const int n0, convMat_t *input, convMat_t *output, - convParams_t *params, float *prhs_ptr) -{ - const int w = input->w; - const int h = input->h; - const int c = input->c; - -#ifdef NCNN - const int seg_size = alignSize(output->w * output->h, 16 / sizeof(float)); -#else // NCNN - const int seg_size = output->w * output->h; -#endif // NCNN - -#ifdef NCNN - float *data = input->data + (alignSize(w * h, 16 / sizeof(float)) * c) * (n0 / seg_size); -#else // NCNN - float *data = input->data + (w * h * c) * (n0 / seg_size); -#endif // NCNN - - int seg0 = seg_size - n0 % seg_size; - if (seg0 > nb) - seg0 = nb; - int nseg = (nb - seg0 + seg_size - 1) / seg_size; - if (seg0) - nseg++; - const int segn = (nb - seg0) % seg_size; - convMat_t _input = {w, h, c, 1, data}; - - for (int i = 0; i < nseg; i++) - { - const int _nb = (i == 0 ? seg0 : ((i == nseg - 1 && segn != 0) ? segn : seg_size)); - const int _n0 = (i == 0 ? seg_size - seg0 : 0); - - _pack_rowmajor_image_rhs(nr, _nb, kb, k0, _n0, &_input, output, params, prhs_ptr); - -#ifdef NCNN - _input.data += alignSize(w * h, 16 / sizeof(float)) * c; -#else // NCNN - _input.data += w * h * c; -#endif // NCNN - } -} - -void _unpack_rowmajor_image_res(const int mb, const int nb, const int m0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *pres_ptr) -{ - const int outw = output->w; - const int outh = output->h; - const int w = input->w; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - - const int out_row0 = n0 / w * stride_h; - const int out_col0 = n0 % w * stride_w; - int seg0 = w - n0 % w; - if (seg0 > nb) - seg0 = nb; - int rows = (nb - seg0 + w - 1) / w; - if (seg0) - rows++; - const int segn = (nb - seg0) % w; - - for (int i = m0; i < mb + m0; i++) - { - const int oc = i / (kernel_w * kernel_h); - const int out_row1 = ((i / kernel_w) % kernel_h) * params->dilation_h + out_row0; - const int out_col1 = i % kernel_w * params->dilation_w; - -#ifdef NCNN - float *output_data = output->data + oc * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *output_data = output->data + oc * outw * outh; -#endif // NCNN - int out_row = out_row1 - pad_h; - - for (int in_rows = rows; in_rows; in_rows--) - { - int cols = (in_rows != 1 || segn == 0) ? 
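
`_pack_rowmajor_image_rhs` above is an im2col pack: each k index selects one (channel, kernel row, kernel column) triple, and the inner loops gather the matching input pixels, substituting zeros where the receptive field overhangs the border. The border test `(unsigned int)in_row < (unsigned int)h` is a deliberate trick that folds `in_row >= 0 && in_row < h` into a single compare, because a negative `int` wraps to a huge unsigned value:

```cpp
// The single-compare bounds test used by the packing/unpacking code:
// a negative coordinate becomes a very large unsigned value, so one
// comparison covers both "coord >= 0" and "coord < limit".
static inline bool inside(const int coord, const int limit)
{
  return (unsigned int)coord < (unsigned int)limit;
}
// inside(-1, 5) == false, inside(0, 5) == true, inside(5, 5) == false
```
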
w : segn; - int out_col = out_col1 - pad_w; - if (in_rows == rows) - { - cols = seg0; - out_col += out_col0; - } - if ((unsigned int)out_row < (unsigned int)outh) - { - for (int in_col = cols; in_col; in_col--) - { - if ((unsigned int)out_col < (unsigned int)outw) - output_data[out_row * outw + out_col] += *pres_ptr++; - else - pres_ptr++; - out_col += stride_w; - } - } - else - { - pres_ptr += cols; - } - out_row += stride_h; - } - } -} - -// TODO:v8 & other case. -static inline void _pack_colmajor_image_rhs_sub(const int nr, const int k, const float *buffer, - float *prhs_ptr) -{ - int nk = k >> 2; - int rk = k & 0x03; - - const int _stride = k << 2; - - switch (nr) - { - case 12: - if (nk > 0) - { -#if __aarch64__ - asm volatile("0:\n" - "mov x0, %[buffer]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v9.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v11.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v8.4s, v10.4s\n" - "zip2 v30.4s, v8.4s, v10.4s\n" - "zip1 v29.4s, v9.4s, v11.4s\n" - "zip2 v31.4s, v9.4s, v11.4s\n" - "zip1 v8.4s, v28.4s, v29.4s\n" - "zip2 v9.4s, v28.4s, v29.4s\n" - "zip1 v10.4s, v30.4s, v31.4s\n" - "zip2 v11.4s, v30.4s, v31.4s\n" - - "ld1 {v12.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v13.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v14.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v15.4s}, [x0]\n" - - "zip1 v28.4s, v12.4s, v14.4s\n" - "zip2 v30.4s, v12.4s, v14.4s\n" - "zip1 v29.4s, v13.4s, v15.4s\n" - "zip2 v31.4s, v13.4s, v15.4s\n" - "zip1 v12.4s, v28.4s, v29.4s\n" - "zip2 v13.4s, v28.4s, v29.4s\n" - "zip1 v14.4s, v30.4s, v31.4s\n" - "zip2 v15.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[prhs_ptr]], #16\n" - "st1 {v8.4s}, [%[prhs_ptr]], #16\n" - "st1 {v12.4s}, [%[prhs_ptr]], #16\n" - "st1 {v5.4s}, [%[prhs_ptr]], #16\n" - "st1 {v9.4s}, [%[prhs_ptr]], #16\n" - "st1 {v13.4s}, [%[prhs_ptr]], #16\n" - "st1 {v6.4s}, [%[prhs_ptr]], #16\n" - "st1 {v10.4s}, [%[prhs_ptr]], #16\n" - "st1 {v14.4s}, [%[prhs_ptr]], #16\n" - "st1 {v7.4s}, [%[prhs_ptr]], #16\n" - "st1 {v11.4s}, [%[prhs_ptr]], #16\n" - "st1 {v15.4s}, [%[prhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[buffer], %[buffer], #16\n" - "bne 0b\n" - : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v28", "v29", "v30", "v31"); -#else // __aarch64__ - asm volatile("0:\n" - "mov r0, %[buffer]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[_stride]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[_stride]\n" - 
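
`_unpack_rowmajor_image_res` above is the matching col2im scatter: results computed in im2col coordinates are accumulated back into the output image, and values that map into the padding are consumed and dropped. The per-element rule, restated (sketch):

```cpp
// Scalar restatement of the col2im accumulation in
// _unpack_rowmajor_image_res (sketch): each unpacked value either adds
// into its output pixel or is discarded when it maps into the padding.
static inline void col2im_accumulate(float *output_data, const int outw, const int outh,
                                     const int out_row, const int out_col, const float value)
{
  if ((unsigned int)out_row < (unsigned int)outh && (unsigned int)out_col < (unsigned int)outw)
  {
    output_data[out_row * outw + out_col] += value;
  }
}
```
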
"vld1.f32 {d22-d23}, [r0]\n" - "add r0, r0, %[_stride]\n" - - "vzip.32 q8, q10\n" - "vzip.32 q9, q11\n" - "vzip.32 q8, q9\n" - "vzip.32 q10, q11\n" - - "vld1.f32 {d24-d25}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d26-d27}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d28-d29}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d30-d31}, [r0]\n" - - "vzip.32 q12, q14\n" - "vzip.32 q13, q15\n" - "vzip.32 q12, q13\n" - "vzip.32 q14, q15\n" - - "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" - "vst1.f32 {d16-d17}, [%[prhs_ptr]]!\n" - "vst1.f32 {d24-d25}, [%[prhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" - "vst1.f32 {d18-d19}, [%[prhs_ptr]]!\n" - "vst1.f32 {d26-d27}, [%[prhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" - "vst1.f32 {d20-d21}, [%[prhs_ptr]]!\n" - "vst1.f32 {d28-d29}, [%[prhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" - "vst1.f32 {d22-d23}, [%[prhs_ptr]]!\n" - "vst1.f32 {d30-d31}, [%[prhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[buffer], %[buffer], #16\n" - "bne 0b\n" - : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", - "q12", "q13", "q14", "q15"); -#endif // __aarch64__ - } - - for (int j = 0; j < rk; j++) - { - prhs_ptr[0] = buffer[0]; - prhs_ptr[1] = buffer[k]; - prhs_ptr[2] = buffer[k << 1]; - prhs_ptr[3] = buffer[3 * k]; - prhs_ptr[4] = buffer[k << 2]; - prhs_ptr[5] = buffer[5 * k]; - prhs_ptr[6] = buffer[6 * k]; - prhs_ptr[7] = buffer[7 * k]; - prhs_ptr[8] = buffer[k << 3]; - prhs_ptr[9] = buffer[9 * k]; - prhs_ptr[10] = buffer[10 * k]; - prhs_ptr[11] = buffer[11 * k]; - prhs_ptr += nr; - buffer++; - } - break; - - case 8: - if (nk > 0) - { -#if __aarch64__ - asm volatile("0:\n" - "mov x0, %[buffer]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "ld1 {v8.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v9.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v10.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v11.4s}, [x0]\n" - - "zip1 v28.4s, v8.4s, v10.4s\n" - "zip2 v30.4s, v8.4s, v10.4s\n" - "zip1 v29.4s, v9.4s, v11.4s\n" - "zip2 v31.4s, v9.4s, v11.4s\n" - "zip1 v8.4s, v28.4s, v29.4s\n" - "zip2 v9.4s, v28.4s, v29.4s\n" - "zip1 v10.4s, v30.4s, v31.4s\n" - "zip2 v11.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[prhs_ptr]], #16\n" - "st1 {v8.4s}, [%[prhs_ptr]], #16\n" - "st1 {v5.4s}, [%[prhs_ptr]], #16\n" - "st1 {v9.4s}, [%[prhs_ptr]], #16\n" - "st1 {v6.4s}, [%[prhs_ptr]], #16\n" - "st1 {v10.4s}, [%[prhs_ptr]], #16\n" - "st1 {v7.4s}, [%[prhs_ptr]], #16\n" - "st1 {v11.4s}, [%[prhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[buffer], %[buffer], #16\n" - "bne 0b\n" - : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", - "v28", "v29", "v30", "v31"); -#else // __aarch64__ - asm volatile("0:\n" - "mov r0, %[buffer]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, 
%[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[_stride]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d20-d21}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d22-d23}, [r0]\n" - - "vzip.32 q8, q10\n" - "vzip.32 q9, q11\n" - "vzip.32 q8, q9\n" - "vzip.32 q10, q11\n" - - "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" - "vst1.f32 {d16-d17}, [%[prhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" - "vst1.f32 {d18-d19}, [%[prhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" - "vst1.f32 {d20-d21}, [%[prhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" - "vst1.f32 {d22-d23}, [%[prhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[buffer], %[buffer], #16\n" - "bne 0b\n" - : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); -#endif // __aarch64__ - } - - for (int j = 0; j < rk; j++) - { - prhs_ptr[0] = buffer[0]; - prhs_ptr[1] = buffer[k]; - prhs_ptr[2] = buffer[k << 1]; - prhs_ptr[3] = buffer[3 * k]; - prhs_ptr[4] = buffer[k << 2]; - prhs_ptr[5] = buffer[5 * k]; - prhs_ptr[6] = buffer[6 * k]; - prhs_ptr[7] = buffer[7 * k]; - prhs_ptr += nr; - buffer++; - } - break; -#if !__aarch64__ - case 6: - if (nk > 0) - { - asm volatile("0:\n" - "mov r0, %[buffer]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d16-d17}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d18-d19}, [r0]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - "vzip.32 q8, q9\n" - - "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" - "vst1.f32 {d16}, [%[prhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" - "vst1.f32 {d17}, [%[prhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" - "vst1.f32 {d18}, [%[prhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" - "vst1.f32 {d19}, [%[prhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[buffer], %[buffer], #16\n" - "bne 0b\n" - : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9"); - } - - for (int j = 0; j < rk; j++) - { - prhs_ptr[0] = buffer[0]; - prhs_ptr[1] = buffer[k]; - prhs_ptr[2] = buffer[k << 1]; - prhs_ptr[3] = buffer[3 * k]; - prhs_ptr[4] = buffer[k << 2]; - prhs_ptr[5] = buffer[5 * k]; - prhs_ptr += nr; - buffer++; - } - break; -#endif // !__aarch64__ - case 4: - if (nk > 0) - { -#if __aarch64__ - asm volatile("0:\n" - "mov x0, %[buffer]\n" - - "ld1 {v4.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v5.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v6.4s}, [x0]\n" - "add x0, x0, %[_stride]\n" - "ld1 {v7.4s}, [x0]\n" - - "zip1 v28.4s, v4.4s, v6.4s\n" - "zip2 v30.4s, v4.4s, v6.4s\n" - "zip1 v29.4s, v5.4s, v7.4s\n" - "zip2 v31.4s, v5.4s, v7.4s\n" - "zip1 v4.4s, v28.4s, v29.4s\n" - "zip2 v5.4s, v28.4s, v29.4s\n" - "zip1 v6.4s, v30.4s, v31.4s\n" - "zip2 v7.4s, v30.4s, v31.4s\n" - - "st1 {v4.4s}, [%[prhs_ptr]], #16\n" - "st1 {v5.4s}, [%[prhs_ptr]], #16\n" - "st1 {v6.4s}, [%[prhs_ptr]], #16\n" - "st1 {v7.4s}, [%[prhs_ptr]], #16\n" - - "subs %[nk], %[nk], #1\n" - "add %[buffer], %[buffer], #16\n" - "bne 0b\n" - : [buffer] 
"+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v28", "v29", "v30", "v31"); -#else // __aarch64__ - asm volatile("0:\n" - "mov r0, %[buffer]\n" - - "vld1.f32 {d8-d9}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d10-d11}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d12-d13}, [r0]\n" - "add r0, r0, %[_stride]\n" - "vld1.f32 {d14-d15}, [r0]\n" - - "vzip.32 q4, q6\n" - "vzip.32 q5, q7\n" - "vzip.32 q4, q5\n" - "vzip.32 q6, q7\n" - - "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" - "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" - "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" - "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" - - "subs %[nk], %[nk], #1\n" - "add %[buffer], %[buffer], #16\n" - "bne 0b\n" - : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) - : [_stride] "r"(_stride) - : "cc", "memory", "r0", "q4", "q5", "q6", "q7"); -#endif // __aarch64__ - } - - for (int j = 0; j < rk; j++) - { - prhs_ptr[0] = buffer[0]; - prhs_ptr[1] = buffer[k]; - prhs_ptr[2] = buffer[k << 1]; - prhs_ptr[3] = buffer[3 * k]; - prhs_ptr += nr; - buffer++; - } - break; - default: - break; - } -} - -void _pack_colmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr) -{ - _pack_rowmajor_notrans_rhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr); -} - -void _pack_colmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr) -{ - _pack_rowmajor_notrans_lhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr); -} - -void _pack_colmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr) -{ - _pack_rowmajor_notrans_lhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr); -} - -void _pack_colmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr) -{ - _pack_rowmajor_notrans_rhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr); -} - -void _pack_colmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *prhs_ptr) -{ - const int w = input->w; - const int h = input->h; - const int c = input->c; - const int outw = output->w; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - const float *input_data = input->data; - - int c0 = c - k0 % c; - if (c0 > kb) - c0 = kb; - int nc = (kb - c0 + c - 1) / c; - if (c0) - nc++; - const int cn = (kb - c0) % c; - - int seg0 = outw - n0 % outw; - if (seg0 > nb) - seg0 = nb; - int rows = (nb - seg0 + outw - 1) / outw; - if (seg0) - rows++; - const int segn = (nb - seg0) % outw; - - const int in_row0 = n0 / outw * stride_h; - const int in_col0 = n0 % outw * stride_w; - - for (int i = 0; i < nc; i++) - { - const int channels = (i == 0 && c0 != 0) ? c0 : ((i == nc - 1 && cn != 0) ? cn : c); - const int c1 = (i == 0) ? k0 % c : 0; - - float tmp_data[channels * nr]; - int nindex = 0; - float *buffer = tmp_data; - float *prhs_tmp = prhs_ptr; - - const int in_row1 = (k0 / c + i) / kernel_w % kernel_h * params->dilation_h + in_row0; - const int in_col1 = (k0 / c + i) % kernel_w * params->dilation_w; - - int in_row = in_row1 - pad_h; - - for (int out_rows = rows; out_rows; out_rows--) - { - int cols = (out_rows != 1 || segn == 0) ? 
outw : segn; - int in_col = in_col1 - pad_w; - if (out_rows == rows) - { - cols = seg0; - in_col += in_col0; - } - if ((unsigned int)in_row < (unsigned int)h) - { - for (int out_col = cols; out_col; out_col--) - { - if ((unsigned int)in_col < (unsigned int)w) - { - for (int j = c1; j < c1 + channels; j++) - { - *(buffer++) = input_data[(in_row * w + in_col) * c + j]; - } - } - else - { - for (int j = 0; j < channels; j++) - { - *(buffer++) = 0; - } - } - in_col += stride_w; - - nindex++; - if (nindex == nr) - { - nindex = 0; - buffer = tmp_data; - _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); - prhs_tmp += kb * nr; - } - } - } - else - { - for (int out_col = cols; out_col; out_col--) - { - for (int j = 0; j < channels; j++) - { - *(buffer++) = 0; - } - in_col += stride_w; - - nindex++; - if (nindex == nr) - { - nindex = 0; - buffer = tmp_data; - _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); - prhs_tmp += kb * nr; - } - } - } - - in_row += stride_h; - } - - if (nindex > 0) - { - float *data = tmp_data; - for (int i = 0; i < channels; i++) - { - for (int j = 0; j < nindex; j++) - { - prhs_tmp[j] = data[j * channels]; - } - for (int j = nindex; j < nr; j++) - { - prhs_tmp[j] = 0.f; - } - prhs_tmp += nr; - data++; - } - } - - prhs_ptr += channels * nr; - } -} - -void _pack_colmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, - const int n0, convMat_t *input, convMat_t *output, - convParams_t *params, float *prhs_ptr) -{ - const int w = input->w; - const int h = input->h; - const int c = input->c; - const int outw = output->w; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - - int c0 = c - k0 % c; - if (c0 > kb) - c0 = kb; - int nc = (kb - c0 + c - 1) / c; - if (c0) - nc++; - const int cn = (kb - c0) % c; - - const int seg_size = output->w * output->h; - - const float *indata = input->data + (w * h * c) * (n0 / seg_size); - - int bseg0 = seg_size - n0 % seg_size; - if (bseg0 > nb) - bseg0 = nb; - int bnseg = (nb - bseg0 + seg_size - 1) / seg_size; - if (bseg0) - bnseg++; - const int bsegn = (nb - bseg0) % seg_size; - - for (int ll = 0; ll < nc; ll++) - { - const float *input_data = indata; - - const int channels = (ll == 0 && c0 != 0) ? c0 : ((ll == nc - 1 && cn != 0) ? cn : c); - const int c1 = (ll == 0) ? k0 % c : 0; - - int nindex = 0; - float *prhs_tmp = prhs_ptr; - float tmp_data[channels * nr]; - float *buffer = tmp_data; - - for (int i = 0; i < bnseg; i++) - { - const int _nb = - ((i == 0 && bseg0 != 0) ? bseg0 : ((i == bnseg - 1 && bsegn != 0) ? bsegn : seg_size)); - const int _n0 = (i == 0 ? n0 % seg_size : 0); - - int seg0 = outw - _n0 % outw; - if (seg0 > _nb) - seg0 = _nb; - int rows = (_nb - seg0 + outw - 1) / outw; - if (seg0) - rows++; - const int segn = (_nb - seg0) % outw; - - const int in_row0 = _n0 / outw * stride_h; - const int in_col0 = _n0 % outw * stride_w; - - const int in_row1 = (k0 / c + ll) / kernel_w % kernel_h + in_row0; - const int in_col1 = (k0 / c + ll) % kernel_w; - - int in_row = in_row1; - - for (int out_rows = rows; out_rows; out_rows--) - { - int cols = (out_rows != 1 || segn == 0) ? 
outw : segn; - int in_col = in_col1; - if (out_rows == rows) - { - cols = seg0; - in_col += in_col0; - } - if ((unsigned int)in_row < (unsigned int)h) - { - for (int out_col = cols; out_col; out_col--) - { - if ((unsigned int)in_col < (unsigned int)w) - { - for (int j = c1; j < c1 + channels; j++) - { - *(buffer++) = input_data[(in_row * w + in_col) * c + j]; - } - } - else - { - for (int j = 0; j < channels; j++) - { - *(buffer++) = 0; - } - } - in_col += stride_w; - - nindex++; - if (nindex == nr) - { - nindex = 0; - buffer = tmp_data; - _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); - prhs_tmp += kb * nr; - } - } - } - else - { - for (int out_col = cols; out_col; out_col--) - { - for (int j = 0; j < channels; j++) - { - *(buffer++) = 0; - } - in_col += stride_w; - - nindex++; - if (nindex == nr) - { - nindex = 0; - buffer = tmp_data; - _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); - prhs_tmp += kb * nr; - } - } - } - - in_row += stride_h; - } - - input_data += w * h * c; - } - - if (nindex > 0) - { - float *data = tmp_data; - for (int ii = 0; ii < channels; ii++) - { - for (int jj = 0; jj < nindex; jj++) - { - prhs_tmp[jj] = data[jj * channels]; - } - for (int jj = nindex; jj < nr; jj++) - { - prhs_tmp[jj] = 0.f; - } - prhs_tmp += nr; - data++; - } - } - - prhs_ptr += channels * nr; - } -} - -void _unpack_colmajor_image_res(const int mb, const int nb, const int m0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *pres_ptr) -{ - const int w = input->w; - const int outw = output->w; - const int outh = output->h; - const int outc = output->c; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - float *output_data = output->data; - - int c0 = outc - m0 % outc; - if (c0 > mb) - c0 = mb; - int nc = (mb - c0 + outc - 1) / outc; - if (c0) - nc++; - const int cn = (mb - c0) % outc; - - int seg0 = w - n0 % w; - if (seg0 > nb) - seg0 = nb; - int rows = (nb - seg0 + w - 1) / w; - if (seg0) - rows++; - const int segn = (nb - seg0) % w; - - const int out_row0 = n0 / w * stride_h; - const int out_col0 = n0 % w * stride_w; - - for (int i = 0; i < nc; i++) - { - const int channels = (i == 0 && c0 != 0) ? c0 : ((i == nc - 1 && cn != 0) ? cn : outc); - const int c1 = (i == 0) ? m0 % outc : 0; - - float *buffer = pres_ptr; - - const int out_row1 = (m0 / outc + i) / kernel_w % kernel_h * params->dilation_h + out_row0; - const int out_col1 = (m0 / outc + i) % kernel_w * params->dilation_w; - - int out_row = out_row1 - pad_h; - - for (int in_rows = rows; in_rows; in_rows--) - { - int cols = (in_rows != 1 || segn == 0) ? 
w : segn; - int out_col = out_col1 - pad_w; - if (in_rows == rows) - { - cols = seg0; - out_col += out_col0; - } - if ((unsigned int)out_row < (unsigned int)outh) - { - for (int in_col = cols; in_col; in_col--) - { - if ((unsigned int)out_col < (unsigned int)outw) - { - for (int j = c1; j < c1 + channels; j++) - { - // Note:Data competition for multi-threads - //#pragma omp atomic //low performance - output_data[(out_row * outw + out_col) * outc + j] += *(buffer + j - c1); - } - } - buffer += mb; - out_col += stride_w; - } - } - else - { - buffer += cols * mb; - } - out_row += stride_h; - } - - pres_ptr += channels; - } -} - -void _sparse_pack_rowmajor_image(const int nb, const int k0, const int n0, convMat_t *input, - convMat_t *output, convParams_t *params, float *prhs_ptr) -{ - const int w = input->w; - const int h = input->h; - const int outw = output->w; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - - const int in_row0 = n0 / outw * stride_h; - const int in_col0 = n0 % outw * stride_w; - int seg0 = outw - n0 % outw; - if (seg0 > nb) - seg0 = nb; - int rows = (nb - seg0 + outw - 1) / outw; - if (seg0) - rows++; - const int segn = (nb - seg0) % outw; - - const int ic = k0 / (kernel_w * kernel_h); - const int in_row1 = ((k0 / kernel_w) % kernel_h) * params->dilation_h + in_row0; - const int in_col1 = k0 % kernel_w * params->dilation_w; - -#ifdef NCNN - const float *input_data = input->data + ic * alignSize(w * h, 16 / sizeof(float)); -#else // NCNN - const float *input_data = input->data + ic * w * h; -#endif // NCNN - - int in_row = in_row1 - pad_h; - - for (int out_rows = rows; out_rows; out_rows--) - { - int cols = (out_rows != 1 || segn == 0) ? outw : segn; - int in_col = in_col1 - pad_w; - if (out_rows == rows) - { - cols = seg0; - in_col += in_col0; - } - if ((unsigned int)in_row < (unsigned int)h) - { - for (int out_col = cols; out_col; out_col--) - { - if ((unsigned int)in_col < (unsigned int)w) - *(prhs_ptr++) = input_data[in_row * w + in_col]; - else - *(prhs_ptr++) = 0; - in_col += stride_w; - } - } - else - { - for (int out_col = cols; out_col; out_col--) - { - *(prhs_ptr++) = 0; - in_col += stride_w; - } - } - - in_row += stride_h; - } -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/sgemm_pack.h b/compute/ncnn/src/srcn/sgemm_pack.h deleted file mode 100644 index d64843ebb..000000000 --- a/compute/ncnn/src/srcn/sgemm_pack.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef __NNFW_SRCN_SGEMM_PACK_H__ -#define __NNFW_SRCN_SGEMM_PACK_H__ - -#include "ncnn/srcn/conv_type.h" - -namespace nnfw -{ -namespace srcn -{ - -void _pack_rowmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr); -void _pack_rowmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr); -void _pack_rowmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr); -void _pack_rowmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr); -void _pack_rowmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *prhs_ptr); -void _pack_rowmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, - const int n0, convMat_t *input, convMat_t *output, - convParams_t *params, float *prhs_ptr); - -void _unpack_rowmajor_image_res(const int mb, const int nb, const int m0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *pres_ptr); - -void _pack_colmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr); -void _pack_colmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr); -void _pack_colmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr); -void _pack_colmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr); - -void _pack_colmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *prhs_ptr); - -void _pack_colmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, - const int n0, convMat_t *input, convMat_t *output, - convParams_t *params, float *prhs_ptr); - -void _unpack_colmajor_image_res(const int mb, const int nb, const int m0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *pres_ptr); - -void _sparse_pack_rowmajor_image(const int nb, const int k0, const int n0, convMat_t *input, - convMat_t *output, convParams_t *params, float *prhs_ptr); - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_SGEMM_PACK_H__ diff --git a/compute/ncnn/src/srcn/sgemm_singlethread.cc b/compute/ncnn/src/srcn/sgemm_singlethread.cc deleted file mode 100644 index 3de3e1214..000000000 --- a/compute/ncnn/src/srcn/sgemm_singlethread.cc +++ /dev/null @@ -1,689 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <stdexcept> - -#include "common.h" -#include "sgemm_kernel.h" -#include "sgemm_pack.h" -#include "sgemm_singlethread.h" - -namespace nnfw -{ -namespace srcn -{ - -void sgemm_singlethread::param_init() -{ - if (n_ >= m_) - { - shard_type_ = shardByRow; - } - else - { - shard_type_ = shardByCol; - } - -#if __aarch64__ - if (major_type_ == rowMajor) - { - if (shard_type_ == shardByRow) - { - mr_ = 8; - nr_ = 12; - } - else - { - mr_ = 12; - nr_ = 8; - } - } - else if (major_type_ == colMajor) - { - mr_ = 12; - nr_ = 8; - } -#else // __aarch64__ - if (major_type_ == rowMajor) - { - // it is a bug, but i do not know why as now. - if (ltrans_ == notrans && rtrans_ == trans) - { - mr_ = 4; - nr_ = 12; - } - else - { - mr_ = 6; - nr_ = 8; - } - } - else if (major_type_ == colMajor) - { - mr_ = 8; - nr_ = 6; - } -#endif // __aarch64__ - - int k_div = (nr_ * sizeof_RhsScalar); - int k_sub = (mr_ * nr_ * sizeof_ResScalar); - - int gen_col = GEN_COL / cache_div_; - int min_k = MAX_K / cache_div_; - - const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div), min_k); - bk_ = MIN(k_cache, k_); - - if (shard_type_ == shardByCol) - { - int m_sub = (bk_ * nr_ * sizeof_RhsScalar); - int m_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_); - if (L3_CACHE_SIZE) - m_div = (sizeof_LhsScalar * bk_ * 2); - int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); - bm_ = MIN(m_cache, m_); - - bn_ = MIN(gen_col, n_); - if (L3_CACHE_SIZE) - { - int n_sub = (bk_ * bm_ * sizeof_RhsScalar); - int n_cache = divup((L3_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2)); - bn_ = MIN(n_cache, bn_); - } - } - else - { - int n_sub = (bk_ * mr_ * sizeof_RhsScalar); - int n_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_); - if (L3_CACHE_SIZE) - n_div = (sizeof_LhsScalar * bk_ * 2); - int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); - bn_ = MIN(n_cache, n_); - - bm_ = MIN(gen_col, m_); - if (L3_CACHE_SIZE) - { - int m_sub = (bk_ * bn_ * sizeof_RhsScalar); - int m_cache = divup((L3_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2)); - bm_ = MIN(m_cache, bm_); - } - } - - nm_ = divup(m_, bm_); - nn_ = divup(n_, bn_); - nk_ = divup(k_, bk_); - - rm_ = m_ % bm_; - rn_ = n_ % bn_; - rk_ = k_ % bk_; -} - -sgemm_singlethread::sgemm_singlethread(sgemmType_t major_type, sgemmTrans_t ltrans, - sgemmTrans_t rtrans, const int m, const int n, const int k, - const float *lhs_data, const float *rhs_data, - float *res_data, int cache_div) - : lhs_data_(lhs_data), rhs_data_(rhs_data), res_data_(res_data), major_type_(major_type), - ltrans_(ltrans), rtrans_(rtrans), m_(m), n_(n), k_(k), cache_div_(cache_div) -{ - param_init(); -} - -sgemm_singlethread::~sgemm_singlethread() {} - -void sgemm_singlethread::run() -{ - if (major_type_ == rowMajor) - { - if (ltrans_ == notrans && rtrans_ == notrans) - { - compute_rowmajor_nn(); - } - else if (ltrans_ == notrans && rtrans_ == trans) - { - compute_rowmajor_nt(); - } - else if (ltrans_ == trans && rtrans_ == notrans) - { - compute_rowmajor_tn(); - } - else if (ltrans_ == trans && rtrans_ == trans) - { - compute_rowmajor_tt(); - } - else - { - throw std::runtime_error{"error trans type."}; - } - } - else if (major_type_ == colMajor) - { - if (ltrans_ == notrans && rtrans_ == notrans) - { - compute_colmajor_nn(); - } - else if (ltrans_ == notrans && rtrans_ == trans) - { - compute_colmajor_nt(); - } - else if (ltrans_ == trans && rtrans_ == notrans) - { - compute_colmajor_tn(); - } - else if (ltrans_ == trans && rtrans_ == trans) - { - compute_colmajor_tt(); - } - else - { - 
throw std::runtime_error{"error trans type."}; - } - } - else - { - throw std::runtime_error{"error major type."}; - } -} - -void sgemm_singlethread::compute_rowmajor_nn() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -void sgemm_singlethread::compute_rowmajor_nt() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - - _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -void sgemm_singlethread::compute_rowmajor_tn() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -void sgemm_singlethread::compute_rowmajor_tt() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - - _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -void sgemm_singlethread::compute_colmajor_nn() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -void sgemm_singlethread::compute_colmajor_nt() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - - _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -void sgemm_singlethread::compute_colmajor_tn() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - - _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); - - _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -void sgemm_singlethread::compute_colmajor_tt() -{ - int mstride = (bm_ + mr_ - 1) / mr_ * mr_; - int nstride = (bn_ + nr_ - 1) / nr_ * nr_; - - float plhs_ptr[mstride * bk_]; - float prhs_ptr[nstride * bk_]; - - if (shard_type_ == shardByCol) - { - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else if (shard_type_ == shardByRow) - { - for (int i = 0; i < nm_; i++) - { - const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; - - for (int l = 0; l < nk_; l++) - { - const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; - - _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); - - for (int j = 0; j < nn_; j++) - { - const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; - - _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); - - _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, - &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); - } - } - } - } - else - { - throw std::runtime_error{"error shard type."}; - } -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/sgemm_singlethread.h b/compute/ncnn/src/srcn/sgemm_singlethread.h deleted file mode 100644 index 47954e028..000000000 --- a/compute/ncnn/src/srcn/sgemm_singlethread.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_SRCN_SGEMM_SINGLETHREAD_H__ -#define __NNFW_SRCN_SGEMM_SINGLETHREAD_H__ - -#include "common.h" - -namespace nnfw -{ -namespace srcn -{ - -typedef enum { rowMajor = 0, colMajor } sgemmType_t; - -typedef enum { trans = 0, notrans } sgemmTrans_t; - -class sgemm_singlethread -{ -public: - sgemm_singlethread(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, const int m, - const int n, const int k, const float *lhs_data, const float *rhs_data, - float *res_data, int cache_div); - ~sgemm_singlethread(); - - void run(); - -private: - void param_init(); - - void compute_rowmajor_nn(); - void compute_rowmajor_nt(); - void compute_rowmajor_tn(); - void compute_rowmajor_tt(); - - void compute_colmajor_nn(); - void compute_colmajor_nt(); - void compute_colmajor_tn(); - void compute_colmajor_tt(); - - const float *lhs_data_; - const float *rhs_data_; - float *res_data_; - - sgemmType_t major_type_; - sgemmTrans_t ltrans_; - sgemmTrans_t rtrans_; - - int m_; - int n_; - int k_; - - int bm_; - int bn_; - int bk_; - - int rm_; - int rn_; - int rk_; - - int nm_; - int nn_; - int nk_; - - int mr_; - int nr_; - - shardType_t shard_type_; - int cache_div_; -}; - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_SGEMM_SINGLETHREAD_H__ diff --git a/compute/ncnn/src/srcn/sgemm_test.cc b/compute/ncnn/src/srcn/sgemm_test.cc deleted file mode 100644 index 1b10970bb..000000000 --- a/compute/ncnn/src/srcn/sgemm_test.cc +++ /dev/null @@ -1,1883 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include <stdio.h> -#include <stdlib.h> -#include <sys/time.h> -#include <unistd.h> - -#include "ncnn/srcn/conv_type.h" -#include "srcn/srcn_conv.h" -//#include "srcn_sgemm.h" -#include "conv_sgemm_singlethread.h" -#include "conv_sgemm_multithreads.h" -//#include "conv_sgemm_batch.h" -#include "sgemm_singlethread.h" -#include "conv_winograd.h" -#include "winograd.h" - -//#include "conv_gpu.h" -//#include "convolutiondepthwise_3x3.h" - -namespace nnfw -{ -namespace srcn -{ - -static void direct_conv_rowmajor(convMat_t *input, convMat_t *output, convMat_t *filter, - convParams_t *params) -{ - const int w = input->w; - const int h = input->h; - const int inch = input->c; - const int outw = output->w; - const int outh = output->h; - const int outch = output->c; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - const int dilation_w = params->dilation_w; - const int dilation_h = params->dilation_h; - const float *input_data = input->data; - const float *filter_data = filter->data; - float *output_data = output->data; - - for (int out_c = 0; out_c < outch; out_c++) - { - for (int out_row = 0; out_row < outh; out_row++) - { - for (int out_col = 0; out_col < outw; out_col++) - { - const int in_col0 = (out_col * stride_w) - pad_w; - const int in_row0 = (out_row * stride_h) - pad_h; - float sum = 0.f; - for (int in_c = 0; in_c < inch; in_c++) - { - for (int filter_y = 0; filter_y < kernel_h; filter_y++) - { - for (int filter_x = 0; filter_x < kernel_w; filter_x++) - { - const int in_col = in_col0 + filter_x * dilation_w; - const int in_row = in_row0 + filter_y * dilation_h; - - if (((unsigned int)in_col < (unsigned int)w) && - ((unsigned int)in_row < (unsigned int)h)) - { - float input_value = input_data[(in_c * h + in_row) * w + in_col]; - float filter_value = - filter_data[((out_c * inch + in_c) * kernel_h + filter_y) * kernel_w + - filter_x]; - sum += (input_value * filter_value); - } - } - } - } - output_data[(out_c * outh + out_row) * outw + out_col] = sum; - } - } - } -} - -static void direct_deconv_rowmajor(convMat_t *input, convMat_t *output, convMat_t *filter, - convParams_t *params) -{ - const int w = input->w; - const int h = input->h; - const int inch = input->c; - const int outw = output->w; - const int outh = output->h; - const int outch = output->c; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - const int dilation_w = params->dilation_w; - const int dilation_h = params->dilation_h; - const float *input_data = input->data; - const float *filter_data = filter->data; - float *output_data = output->data; - - for (int i = 0; i < outw * outh * outch; i++) - { - output_data[i] = 0; - } - - for (int in_c = 0; in_c < inch; in_c++) - { - for (int in_row = 0; in_row < h; in_row++) - { - for (int in_col = 0; in_col < w; in_col++) - { - const int out_col0 = (in_col * stride_w) - pad_w; - const int out_row0 = (in_row * stride_h) - pad_h; - float in_value = input_data[(in_c * h + in_row) * w + in_col]; - for (int out_c = 0; out_c < outch; out_c++) - { - for (int filter_y = 0; filter_y < kernel_h; filter_y++) - { - for (int filter_x = 0; filter_x < kernel_w; filter_x++) - { - const int out_col = out_col0 + filter_x * 
dilation_w; - const int out_row = out_row0 + filter_y * dilation_h; - - if (((unsigned int)out_col < (unsigned int)outw) && - ((unsigned int)out_row < (unsigned int)outh)) - { - float filter_value = - filter_data[((in_c * outch + out_c) * kernel_h + filter_y) * kernel_w + - filter_x]; - output_data[(out_c * outh + out_row) * outw + out_col] += filter_value * in_value; - } - } - } - } - } - } - } -} - -static void direct_sgemm_rowmajor(int Atrans, int Btrans, int m, int n, int k, float *A, float *B, - float *C) -{ - float *aa, *bb; - - if (Atrans == trans) - { - aa = (float *)malloc(m * k * sizeof(float)); - if (!aa) - return; - - for (int i = 0; i < k; i++) - { - for (int j = 0; j < m; j++) - { - aa[j * k + i] = A[i * m + j]; - } - } - } - else - { - aa = A; - } - - if (Btrans == trans) - { - bb = (float *)malloc(n * k * sizeof(float)); - if (!bb) - return; - - for (int i = 0; i < n; i++) - { - for (int j = 0; j < k; j++) - { - bb[j * n + i] = B[i * k + j]; - } - } - } - else - { - bb = B; - } - - for (int i = 0; i < m; i++) - { - for (int j = 0; j < n; j++) - { - float res = 0.f; - for (int l = 0; l < k; l++) - { - res += aa[i * k + l] * bb[l * n + j]; - } - C[i * n + j] = res; - } - } -} - -/*static void direct_sgemm_kernel(const int k, const int lhs_stride, const int rhs_stride, const int -res_stride, - const float *lhs_ptr, const float *rhs_ptr, float *res_ptr) -{ - int lstride = lhs_stride << 2; - int rstride = rhs_stride << 2; - int estride = res_stride << 2; - int rstep = rstride << 2; - - int nk = (k >> 2) - 1; - - __asm __volatile ( - "movi v16.4s, #0x0\n" - "movi v17.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - "movi v21.4s, #0x0\n" - "movi v22.4s, #0x0\n" - "movi v23.4s, #0x0\n" - "movi v24.4s, #0x0\n" - "movi v25.4s, #0x0\n" - "movi v26.4s, #0x0\n" - "movi v27.4s, #0x0\n" - "movi v28.4s, #0x0\n" - "movi v29.4s, #0x0\n" - "movi v30.4s, #0x0\n" - "movi v31.4s, #0x0\n" - - "mov x0, %[lhs_ptr]\n" - "add %[lhs_ptr], %[lhs_ptr], #16\n" - "ld1 {v0.4s}, [x0]\n" - "add x0, x0, %[lstride]\n" - "ld1 {v1.4s}, [x0]\n" - "add x0, x0, %[lstride]\n" - "ld1 {v2.4s}, [x0]\n" - "add x0, x0, %[lstride]\n" - "ld1 {v3.4s}, [x0]\n" - "add x0, x0, %[lstride]\n" - - "mov x1, %[rhs_ptr]\n" - "add %[rhs_ptr], %[rhs_ptr], %[rstep]\n" - "ld1 {v8.4s, v9.4s}, [x1]\n" - "add x1, x1, %[rstride]\n" - "ld1 {v10.4s, v11.4s}, [x1]\n" - "add x1, x1, %[rstride]\n" - - "1:\n" - "fmla v16.4s, v8.4s, v0.s[0]\n" - "fmla v17.4s, v9.4s, v0.s[0]\n" - "fmla v16.4s, v10.4s, v0.s[1]\n" - "fmla v17.4s, v11.4s, v0.s[1]\n" - "fmla v18.4s, v8.4s, v1.s[0]\n" - "fmla v19.4s, v9.4s, v1.s[0]\n" - "fmla v18.4s, v10.4s, v1.s[1]\n" - "fmla v19.4s, v11.4s, v1.s[1]\n" - "ld1 {v12.4s, v13.4s}, [x1]\n" - "fmla v20.4s, v8.4s, v2.s[0]\n" - "add x1, x1, %[rstride]\n" - "fmla v21.4s, v9.4s, v2.s[0]\n" - "ld1 {v14.4s, v15.4s}, [x1]\n" - "fmla v20.4s, v10.4s, v2.s[1]\n" - "add x1, x1, %[rstride]\n" - "fmla v21.4s, v11.4s, v2.s[1]\n" - "fmla v22.4s, v8.4s, v3.s[0]\n" - "fmla v23.4s, v9.4s, v3.s[0]\n" - "fmla v22.4s, v10.4s, v3.s[1]\n" - "fmla v23.4s, v11.4s, v3.s[1]\n" - - "ld1 {v4.4s}, [x0]\n" - "fmla v16.4s, v12.4s, v0.s[2]\n" - "add x0, x0, %[lstride]\n" - "fmla v17.4s, v13.4s, v0.s[2]\n" - "ld1 {v5.4s}, [x0]\n" - "fmla v16.4s, v14.4s, v0.s[3]\n" - "add x0, x0, %[lstride]\n" - "fmla v17.4s, v15.4s, v0.s[3]\n" - "ld1 {v6.4s}, [x0]\n" - "fmla v18.4s, v12.4s, v1.s[2]\n" - "add x0, x0, %[lstride]\n" - "fmla v19.4s, v13.4s, v1.s[2]\n" - "ld1 {v7.4s}, [x0]\n" - "fmla v18.4s, v14.4s, v1.s[3]\n" - "add x0, x0, 
%[lstride]\n" - "fmla v19.4s, v15.4s, v1.s[3]\n" - "fmla v20.4s, v12.4s, v2.s[2]\n" - "fmla v21.4s, v13.4s, v2.s[2]\n" - "fmla v20.4s, v14.4s, v2.s[3]\n" - "fmla v21.4s, v15.4s, v2.s[3]\n" - "fmla v22.4s, v12.4s, v3.s[2]\n" - "fmla v23.4s, v13.4s, v3.s[2]\n" - "fmla v22.4s, v14.4s, v3.s[3]\n" - "fmla v23.4s, v15.4s, v3.s[3]\n" - - "mov x0, %[lhs_ptr]\n" - "add %[lhs_ptr], %[lhs_ptr], #16\n" - - "fmla v24.4s, v8.4s, v4.s[0]\n" - "fmla v25.4s, v9.4s, v4.s[0]\n" - "ld1 {v0.4s}, [x0]\n" - "fmla v24.4s, v10.4s, v4.s[1]\n" - "add x0, x0, %[lstride]\n" - "fmla v25.4s, v11.4s, v4.s[1]\n" - "ld1 {v1.4s}, [x0]\n" - "fmla v26.4s, v8.4s, v5.s[0]\n" - "add x0, x0, %[lstride]\n" - "fmla v27.4s, v9.4s, v5.s[0]\n" - "ld1 {v2.4s}, [x0]\n" - "fmla v26.4s, v10.4s, v5.s[1]\n" - "add x0, x0, %[lstride]\n" - "fmla v27.4s, v11.4s, v5.s[1]\n" - "ld1 {v3.4s}, [x0]\n" - "fmla v28.4s, v8.4s, v6.s[0]\n" - "add x0, x0, %[lstride]\n" - "fmla v29.4s, v9.4s, v6.s[0]\n" - "fmla v28.4s, v10.4s, v6.s[1]\n" - "fmla v29.4s, v11.4s, v6.s[1]\n" - "fmla v30.4s, v8.4s, v7.s[0]\n" - "fmla v31.4s, v9.4s, v7.s[0]\n" - "fmla v30.4s, v10.4s, v7.s[1]\n" - "fmla v31.4s, v11.4s, v7.s[1]\n" - - "mov x1, %[rhs_ptr]\n" - "add %[rhs_ptr], %[rhs_ptr], %[rstep]\n" - - "fmla v24.4s, v12.4s, v4.s[2]\n" - "fmla v25.4s, v13.4s, v4.s[2]\n" - "ld1 {v8.4s, v9.4s}, [x1]\n" - "fmla v24.4s, v14.4s, v4.s[3]\n" - "add x1, x1, %[rstride]\n" - "fmla v25.4s, v15.4s, v4.s[3]\n" - "ld1 {v10.4s, v11.4s}, [x1]\n" - "fmla v26.4s, v12.4s, v5.s[2]\n" - "add x1, x1, %[rstride]\n" - "fmla v27.4s, v13.4s, v5.s[2]\n" - "fmla v26.4s, v14.4s, v5.s[3]\n" - "fmla v27.4s, v15.4s, v5.s[3]\n" - "fmla v28.4s, v12.4s, v6.s[2]\n" - "fmla v29.4s, v13.4s, v6.s[2]\n" - "fmla v28.4s, v14.4s, v6.s[3]\n" - "fmla v29.4s, v15.4s, v6.s[3]\n" - "fmla v30.4s, v12.4s, v7.s[2]\n" - "fmla v31.4s, v13.4s, v7.s[2]\n" - "subs %w[nk], %w[nk], #1\n" - "fmla v30.4s, v14.4s, v7.s[3]\n" - "fmla v31.4s, v15.4s, v7.s[3]\n" - "bne 1b\n" - - "fmla v16.4s, v8.4s, v0.s[0]\n" - "fmla v17.4s, v9.4s, v0.s[0]\n" - "fmla v16.4s, v10.4s, v0.s[1]\n" - "fmla v17.4s, v11.4s, v0.s[1]\n" - "fmla v18.4s, v8.4s, v1.s[0]\n" - "fmla v19.4s, v9.4s, v1.s[0]\n" - "fmla v18.4s, v10.4s, v1.s[1]\n" - "fmla v19.4s, v11.4s, v1.s[1]\n" - "ld1 {v12.4s, v13.4s}, [x1]\n" - "fmla v20.4s, v8.4s, v2.s[0]\n" - "add x1, x1, %[rstride]\n" - "fmla v21.4s, v9.4s, v2.s[0]\n" - "ld1 {v14.4s, v15.4s}, [x1]\n" - "fmla v20.4s, v10.4s, v2.s[1]\n" - "add x1, x1, %[rstride]\n" - "fmla v21.4s, v11.4s, v2.s[1]\n" - "fmla v22.4s, v8.4s, v3.s[0]\n" - "fmla v23.4s, v9.4s, v3.s[0]\n" - "fmla v22.4s, v10.4s, v3.s[1]\n" - "fmla v23.4s, v11.4s, v3.s[1]\n" - - "ld1 {v4.4s}, [x0]\n" - "fmla v16.4s, v12.4s, v0.s[2]\n" - "add x0, x0, %[lstride]\n" - "fmla v17.4s, v13.4s, v0.s[2]\n" - "ld1 {v5.4s}, [x0]\n" - "fmla v16.4s, v14.4s, v0.s[3]\n" - "add x0, x0, %[lstride]\n" - "fmla v17.4s, v15.4s, v0.s[3]\n" - "ld1 {v6.4s}, [x0]\n" - "fmla v18.4s, v12.4s, v1.s[2]\n" - "add x0, x0, %[lstride]\n" - "fmla v19.4s, v13.4s, v1.s[2]\n" - "ld1 {v7.4s}, [x0]\n" - "fmla v18.4s, v14.4s, v1.s[3]\n" - "add x0, x0, %[lstride]\n" - "fmla v19.4s, v15.4s, v1.s[3]\n" - "fmla v20.4s, v12.4s, v2.s[2]\n" - "fmla v21.4s, v13.4s, v2.s[2]\n" - "fmla v20.4s, v14.4s, v2.s[3]\n" - "fmla v21.4s, v15.4s, v2.s[3]\n" - "fmla v22.4s, v12.4s, v3.s[2]\n" - "fmla v23.4s, v13.4s, v3.s[2]\n" - "fmla v22.4s, v14.4s, v3.s[3]\n" - "fmla v23.4s, v15.4s, v3.s[3]\n" - - "mov x0, %[res_ptr]\n" - "fmla v24.4s, v8.4s, v4.s[0]\n" - "fmla v25.4s, v9.4s, v4.s[0]\n" - "st1 {v16.4s, v17.4s}, [x0]\n" - "add x0, x0, 
%[estride]\n" - "fmla v24.4s, v10.4s, v4.s[1]\n" - "fmla v25.4s, v11.4s, v4.s[1]\n" - "st1 {v18.4s, v19.4s}, [x0]\n" - "add x0, x0, %[estride]\n" - "fmla v26.4s, v8.4s, v5.s[0]\n" - "fmla v27.4s, v9.4s, v5.s[0]\n" - "st1 {v20.4s, v21.4s}, [x0]\n" - "add x0, x0, %[estride]\n" - "fmla v26.4s, v10.4s, v5.s[1]\n" - "fmla v27.4s, v11.4s, v5.s[1]\n" - "st1 {v22.4s, v23.4s}, [x0]\n" - "add x0, x0, %[estride]\n" - "fmla v28.4s, v8.4s, v6.s[0]\n" - "fmla v29.4s, v9.4s, v6.s[0]\n" - "fmla v28.4s, v10.4s, v6.s[1]\n" - "fmla v29.4s, v11.4s, v6.s[1]\n" - "fmla v30.4s, v8.4s, v7.s[0]\n" - "fmla v31.4s, v9.4s, v7.s[0]\n" - "fmla v30.4s, v10.4s, v7.s[1]\n" - "fmla v31.4s, v11.4s, v7.s[1]\n" - - "fmla v24.4s, v12.4s, v4.s[2]\n" - "fmla v25.4s, v13.4s, v4.s[2]\n" - "fmla v24.4s, v14.4s, v4.s[3]\n" - "fmla v25.4s, v15.4s, v4.s[3]\n" - "fmla v26.4s, v12.4s, v5.s[2]\n" - "fmla v27.4s, v13.4s, v5.s[2]\n" - "st1 {v24.4s, v25.4s}, [x0]\n" - "add x0, x0, %[estride]\n" - "fmla v26.4s, v14.4s, v5.s[3]\n" - "fmla v27.4s, v15.4s, v5.s[3]\n" - "fmla v28.4s, v12.4s, v6.s[2]\n" - "fmla v29.4s, v13.4s, v6.s[2]\n" - "st1 {v26.4s, v27.4s}, [x0]\n" - "add x0, x0, %[estride]\n" - "fmla v28.4s, v14.4s, v6.s[3]\n" - "fmla v29.4s, v15.4s, v6.s[3]\n" - "fmla v30.4s, v12.4s, v7.s[2]\n" - "fmla v31.4s, v13.4s, v7.s[2]\n" - "st1 {v28.4s, v29.4s}, [x0]\n" - "add x0, x0, %[estride]\n" - "fmla v30.4s, v14.4s, v7.s[3]\n" - "fmla v31.4s, v15.4s, v7.s[3]\n" - "st1 {v30.4s, v31.4s}, [x0]\n" - :[lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), [res_ptr] "+r" (res_ptr), - [nk] "+r" (nk) - : [lstride] "r" (lstride), [rstride] "r" (rstride), [estride] "r" (estride), [rstep] "r" -(rstep) - : "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", - "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); -}*/ - -static void direct_conv_colmajor(convMat_t *input, convMat_t *output, convMat_t *filter, - convParams_t *params) -{ - const int w = input->w; - const int h = input->h; - const int inch = input->c; - const int outw = output->w; - const int outh = output->h; - const int outch = output->c; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - const int dilation_w = params->dilation_w; - const int dilation_h = params->dilation_h; - const float *input_data = input->data; - const float *filter_data = filter->data; - float *output_data = output->data; - - for (int out_row = 0; out_row < outh; out_row++) - { - for (int out_col = 0; out_col < outw; out_col++) - { - const int in_col0 = (out_col * stride_w) - pad_w; - const int in_row0 = (out_row * stride_h) - pad_h; - - for (int out_c = 0; out_c < outch; out_c++) - { - float sum = 0.f; - for (int filter_y = 0; filter_y < kernel_h; filter_y++) - { - for (int filter_x = 0; filter_x < kernel_w; filter_x++) - { - const int in_col = in_col0 + filter_x * dilation_w; - const int in_row = in_row0 + filter_y * dilation_h; - - if (((unsigned int)in_col < (unsigned int)w) && - ((unsigned int)in_row < (unsigned int)h)) - { - for (int in_c = 0; in_c < inch; in_c++) - { - float input_value = input_data[(in_row * w + in_col) * inch + in_c]; - float filter_value = - filter_data[((filter_y * kernel_w + filter_x) * inch + in_c) * outch + out_c]; - sum += (input_value * filter_value); - } - } - } - } - 
output_data[(out_row * outw + out_col) * outch + out_c] = sum; - } - } - } -} - -static void direct_sgemm_colmajor(int Atrans, int Btrans, int m, int n, int k, float *A, float *B, - float *C) -{ - float *aa, *bb; - - if (Atrans) - { - aa = (float *)malloc(m * k * sizeof(float)); - if (!aa) - return; - - for (int i = 0; i < k; i++) - { - for (int j = 0; j < m; j++) - { - aa[i * m + j] = A[j * k + i]; - } - } - } - else - { - aa = A; - } - - if (Btrans) - { - bb = (float *)malloc(n * k * sizeof(float)); - if (!bb) - return; - - for (int i = 0; i < n; i++) - { - for (int j = 0; j < k; j++) - { - bb[i * k + j] = B[j * n + i]; - } - } - } - else - { - bb = B; - } - - for (int i = 0; i < m; i++) - { - for (int j = 0; j < n; j++) - { - float res = 0.f; - for (int l = 0; l < k; l++) - { - res += bb[j * k + l] * aa[l * m + i]; - } - C[j * m + i] = res; - } - } -} - -#if 0 -static int test_sgemm(int m, int n, int k, int loops) -{ - struct timeval start, end; - float total_time = 0.f; - - const int mb = 180; - const int nb = 1440; - const int kb = 512; - - const int mr = 4; - const int nr = 12; - -#if 0 - const int pm = (m + mr - 1) / mr * mr; - const int pn = (n + nr - 1) / nr * nr; - const int pk = k; -#else - const int pm = (mb + mr - 1) / mr * mr; - const int pn = (nb + nr - 1) / nr * nr; - const int pk = kb; -#endif - const int nm = (m + mb - 1) / mb; - const int nn = (n + nb - 1) / nb; - const int nk = (k + kb - 1) / kb; - - const int rm = m % mb; - const int rn = n % nb; - const int rk = k % kb; - - float *A = (float *)malloc(m * k * sizeof(float)); - if(!A) return 0; - - for(int i = 0 ; i < m * k; i++) - { - A[i] = 0.001 + i * 0.000001; - } - - float *B = (float *)malloc(k * n * sizeof(float)); - if(!B) return 0; - - for(int i = 0 ; i < n * k; i++) - { - B[i] = 0.001 - i * 0.000001; - } - - float *C = (float *)malloc(m * n * sizeof(float)); - if(!C) return 0; - -#if 0 - float *PA = (float *)malloc(pm * pk * sizeof(float)); - if(!PA) return 0; - - float *PB = (float *)malloc(pk * pn * sizeof(float)); - if(!PB) return 0; -#else - float PA[pm * pk]; - float PB[pk * pn]; -#endif - - for(int nloop = 0; nloop < loops; nloop++) - - { - gettimeofday(&start, NULL); - - //pack_rowmajor_notrans_lhs(mr, m, k, k, A, PA); - //pack_rowmajor_notrans_rhs(nr, n, k, n, B, PB); -#if 1 - for (int j = 0; j < nn; j++) - { - const int _nb = (j != nn - 1 || rn == 0) ? nb : rn; - for (int l = 0; l < nk; l++) - { - const int _kb = (l != nk - 1 || rk == 0) ? kb : rk; - pack_rowmajor_notrans_rhs(nr, _nb, _kb, 1, n, &B[l * kb * n + j * nb], PB); - for(int i = 0; i < nm; i++) - { - const int _mb = (i != nm - 1 || rm == 0) ? mb : rm; - pack_rowmajor_notrans_lhs(mr, _mb, _kb, 1, k, &A[i * mb * k + l * kb], PA); - sgemm_rowmajor_macro_kernel_divnm(mr, nr, _mb, _nb, _kb, PA, PB, &C[i * mb * n + j * nb], l, n, _kb); - //sgemm_rowmajor_macro_kernel_divnm(mr, nr, _mb, _nb, _kb, &PA[i * mb * k + l * kb], &PB[l * kb * pn + j * nb], &C[i * mb * n + j * nb], l, n, pk); - } - } - } -#else - for (int j = 0; j < nm; j++) - { - const int _mb = (j != nm - 1 || rm == 0) ? mb : rm; - for (int l = 0; l < nk; l++) - { - const int _kb = (l != nk - 1 || rk == 0) ? kb : rk; - pack_rowmajor_notrans_lhs(mr, _mb, _kb, 1, k, &A[j * mb * k + l * kb], PA); - for(int i = 0; i < nn; i++) - { - const int _nb = (i != nn - 1 || rn == 0) ? 
nb : rn; - pack_rowmajor_notrans_rhs(nr, _nb, _kb, 1, n, &B[l * kb * n + i * nb], PB); - sgemm_rowmajor_macro_kernel_divmn(mr, nr, _mb, _nb, _kb, PA, PB, &C[j * mb * n + i * nb], l, n, _kb); - //sgemm_rowmajor_macro_kernel_divmn(mr, nr, _mb, _nb, _kb, &PA[i * mb * k + l * kb], &PB[l * kb * pn + j * nb], &C[i * mb * n + j * nb], l, n, pk); - } - } - } -#endif - gettimeofday(&end, NULL); - total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec))/1000; - } - - int div = m * n < 16 ? m * n : 16; - int num = m * n > 64 ? 64 : m * n; - - float *c_ptr = &C[0]; - for(int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if((i + 1) % div == 0) printf("\n"); - } - - printf("\n"); - - c_ptr = &C[m * n - num]; - for(int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if((i + 1) % div == 0) printf("\n"); - } - - printf("\n"); - - long long total_size = (long long)m *n * k * 2; - printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops , total_size, (double)total_size/(total_time / loops)/1000000); - - free(A); - free(B); - free(C); - - //free(PA); - //free(PB); - -} -#endif - -static int test_sgemm(int m, int n, int k, int type, int loops) -{ - struct timeval start, end; - float total_time = 0.f; - - // printf("1.\n"); - - float *A = (float *)malloc(m * k * sizeof(float)); - if (!A) - return 0; - - for (int i = 0; i < m * k; i++) - { - A[i] = 0.001 + i * 0.001; // i * 0.000001; - } - - float *B = (float *)malloc(k * n * sizeof(float)); - if (!B) - return 0; - - for (int i = 0; i < n * k; i++) - { - B[i] = 0.001 - i * 0.001; // - i * 0.000001; - } - - float *C = (float *)malloc(m * n * sizeof(float)); - if (!C) - return 0; - - for (int nloop = 0; nloop < loops; nloop++) - - { - gettimeofday(&start, NULL); - - if (type == 0) - { - // direct_sgemm_rowmajor(notrans, notrans, m, n, k, A, B, C); - direct_sgemm_colmajor(notrans, notrans, m, n, k, A, B, C); - } - - else if (type == 1) - { - class sgemm_singlethread my_gemm(colMajor, notrans, notrans, m, n, k, A, B, C, 1); - my_gemm.run(); - } - - /*else if(type == 2) - { - for(int i = 0; i < m / 8; i++) - { - for(int j = 0; j < n / 8; j++) - { - direct_sgemm_kernel(k, k, n, n, A + i * 8 * k, B + j * 8, C + i * 8 * n + j * 8); - } - } - }*/ - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - - int div = m * n < 16 ? m * n : 16; - int num = m * n > 64 ? 
64 : m * n; - - float *c_ptr = &C[0]; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - c_ptr = &C[m * n - num]; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - long long total_size = (long long)m * n * k * 2; - printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, - total_size, (double)total_size / (total_time / loops) / 1000000); - - free(A); - free(B); - free(C); - - return 0; -} - -void weight_tensorflow2caffe(float *out, float *in, int H, int W, int C, int N) -{ // HWCN ---> NCHW - for (int h = 0; h < H; ++h) - { - for (int w = 0; w < W; ++w) - { - for (int c = 0; c < C; ++c) - { - for (int n = 0; n < N; ++n) - { - int index_in = h * W * C * N + w * C * N + c * N + n; - int index_out = n * C * H * W + c * H * W + h * W + w; - // printf("%3d <--- %3d\n", index_out, index_in); - out[index_out] = in[index_in]; - } - } - } - } -} - -void trans_weight2winograd(const convMat_t &_kernel, float **winograd_weight) -{ - const double *G; - const int kernel_size = _kernel.h; - const int channels = _kernel.c; - const int num_output = _kernel.n; - - int tile_h_in_, tile_w_in_; - int M, N; - - /*Step 1: transfer weight to winograd domain*/ - if (kernel_size == 3) - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - G = winograd_para_3x3s1::getG(); - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - G = winograd_para_5x5s1::getG(); - } - - tile_h_in_ = tile_w_in_ = M; - - float *winograd_g = new float[M * M * N * N]; - if (NULL == winograd_g) - return; - kronecker_product(winograd_g, G, G, M, N, M, N); - - *winograd_weight = new float[tile_h_in_ * tile_w_in_ * channels * num_output]; - - if (NULL == *winograd_weight) - return; - - float *weight_data_tran = new float[_kernel.h * _kernel.w * _kernel.c * _kernel.n]; - if (NULL == weight_data_tran) - return; - weight_tensorflow2caffe(weight_data_tran, _kernel.data, kernel_size, kernel_size, channels, - num_output); - - class sgemm_singlethread sgemm(rowMajor, notrans, trans, tile_h_in_ * tile_w_in_, - channels * num_output, kernel_size * kernel_size, winograd_g, - weight_data_tran, *winograd_weight, 1); - - sgemm.run(); - - delete[] weight_data_tran; - - /*With winograd, original weight data is useless.*/ - delete[] winograd_g; -} - -static int test_conv(const int w, const int h, const int kernel_size, const int stride, - const int inch, const int outch, const int padding, const int conv_type, - const int thread_num, const int loops) -{ - struct timeval start, end; - float total_time = 0.f; - - struct timeval start1, end1; - float total_time1 = 0.f; - - const int dilation = 1; - - const int kernel_dilation = dilation * (kernel_size - 1) + 1; - - convMat_t input; - convMat_t output; - convMat_t filter; - convParams_t params; - - int pad_l, pad_r, pad_t, pad_b; - if (padding) - { - int pad_w = kernel_dilation + (w - 1) / stride * stride - w; - int pad_h = kernel_dilation + (h - 1) / stride * stride - h; - pad_l = pad_w / 2; - pad_r = pad_w - pad_l; - pad_t = pad_h / 2; - pad_b = pad_h - pad_t; - } - else - { - pad_l = pad_r = pad_t = pad_b = 0; - } - - input.w = w; - input.h = h; - input.c = inch; - input.n = 1; -#ifdef NCNN - input.data = - (float *)malloc(alignSize(input.w * input.h, 16 / sizeof(float)) * input.c * sizeof(float)); -#else - input.data = (float *)malloc(input.w * input.h * input.c * sizeof(float)); 
-#endif - - if (!input.data) - return 0; - - output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; - output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; - output.c = outch; - output.n = 1; -#ifdef NCNN - output.data = (float *)malloc(alignSize(output.w * output.h, 16 / sizeof(float)) * output.c * - sizeof(float)); -#else - output.data = (float *)malloc(output.w * output.h * output.c * sizeof(float)); -#endif - - if (!output.data) - return 0; - - for (int i = 0; i < output.w * output.h * output.c; i++) - { - output.data[i] = 0; - } - - filter.w = kernel_size; - filter.h = kernel_size; - filter.c = inch; - filter.n = outch; - filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); - if (!filter.data) - return 0; - - for (int i = 0; i < input.w * input.h * input.c; i++) - { - input.data[i] = 0.001 + i * 0.000001; - } - -#if 1 - for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) - { - filter.data[i] = 0.001 - i * 0.000001; - } -#else - for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) - { - if ((i + 1) % 15 == 0) - filter.data[i] = 0.001 - i * 0.000001; - else - filter.data[i] = 0; - } -#endif - params.kernel_w = kernel_size; - params.kernel_h = kernel_size; - params.stride_w = stride; - params.stride_h = stride; - params.padding = padding; - params.pad_w = pad_l; - params.pad_h = pad_t; - params.dilation_w = dilation; - params.dilation_h = dilation; - - const int m = output.c; - const int n = output.w * output.h; - const int k = params.kernel_h * params.kernel_w * input.c; - - // ocl_context_t context; - size_t local_min[2]; - /** - if(conv_type == 14 || conv_type == 15 || conv_type == 6) - { - if(init_gpu(&context) < 0) return -1; - //if(conv_type ==14 || conv_type == 5) sgemm_ocltune(&context, m, n, (k < 1024 ? 
k : - 1024), local_min); - //else if(conv_type == 6) - { - if(kernel_size == 3) directconv_3x3S1_tune(&context, &input, &filter, &output, - local_min); - else if(kernel_size == 1) directconv_1x1S1_tune(&context, &input, &filter, &output, - local_min); - } - //local_min[0] = 1; local_min[1] = 1; - } - **/ - if (conv_type == 0) - { - for (int nloop = 0; nloop < loops; nloop++) - { - gettimeofday(&start, NULL); - - direct_conv_rowmajor(&input, &output, &filter, ¶ms); - // direct_conv_colmajor(&input, &output, &filter, ¶ms); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - } - else if (conv_type == 1) - { - for (int nloop = 0; nloop < loops; nloop++) - { - // printf("nloop = %d, thread_num = %d\n", nloop, thread_num); - // class srcn_sgemm my_gemm(input, filter, output, params, thread_num, col_major); - gettimeofday(&start, NULL); - - /*if(thread_num == 1) - { - class conv_sgemm_singlethread my_gemm(input, filter, output, params, col_major); - my_gemm.run(); - } - else - { - class conv_sgemm_multithreads my_gemm(input, filter, output, params, thread_num, - col_major); - my_gemm.run(); - }*/ - - srcn_convolution2D(input, filter, output, params, NULL, thread_num, row_major); - - // printf("sync\n"); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - } - else if (conv_type == 2) - { - float *winograd_weight; - - // trans_weight2winograd(filter, &winograd_weight); - - winogradParams_t wparams = {params.kernel_w, - params.kernel_h, - params.stride_w, - params.stride_h, - params.dilation_w, - params.dilation_h, - 1, - w, - h, - input.c, - output.c, - thread_num, - col_major, - filter.data}; - winograd_weight = trans_weight2winograd(wparams); - - for (int nloop = 0; nloop < loops; nloop++) - { - gettimeofday(&start, NULL); - - // class conv_winograd my_sgemm(input, output, params, col_major, winograd_weight, thread_num, - // w * h, n); - // my_sgemm.run(); - - srcn_convolution2D(input, filter, output, params, winograd_weight, thread_num, row_major); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - } - else if (conv_type == 3) - { - void *sparse_weight = trans_weight2sparse(filter); - - for (int nloop = 0; nloop < loops; nloop++) - { - gettimeofday(&start, NULL); - - srcn_sparse_convolution2D(input, output, params, sparse_weight, thread_num, row_major); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - - sparse_release(outch, sparse_weight); - } /** -else if(conv_type == 4) -{ -#if 0 - cl_int err; - convlib::load_opencl("./libmali.so"); - const int mpad = (m + 4 - 1) / 4 * 4; - const int npad = (n + 4 - 1) / 4 * 4; - cl_mem lhs_gpu = convlib::clCreateBuffer(context.context, CL_MEM_READ_WRITE | -CL_MEM_ALLOC_HOST_PTR, mpad * k * sizeof(float), NULL, &err); - if(err != CL_SUCCESS) - { - printf("err = %d@%s:%d\n", err, __FUNCTION__, __LINE__); - return -1; - } - - cl_image_format rhs_format = {CL_RGBA, CL_FLOAT}; - cl_image_desc desc = - { - CL_MEM_OBJECT_IMAGE2D, - (size_t)npad / 4, - (size_t)k, - 0, 0, - 0, - 0, 0, 0, 0 - }; - cl_mem rhs_gpu = convlib::clCreateImage(context.context, CL_MEM_READ_ONLY | -CL_MEM_ALLOC_HOST_PTR, &rhs_format, &desc, NULL, &err); - if(err != CL_SUCCESS) - { - printf("err = %d@%s:%d\n", err, __FUNCTION__, 
__LINE__); - return -1; - } - - cl_mem rhs_gpu = convlib::clCreateBuffer(context.context, CL_MEM_READ_WRITE | -CL_MEM_ALLOC_HOST_PTR, npad * k * sizeof(float), NULL, &err); - if(err != CL_SUCCESS) - { - printf("err = %d@%s:%d\n", err, __FUNCTION__, __LINE__); - return -1;; - } - - cl_mem res_gpu = convlib::clCreateBuffer(context.context, CL_MEM_READ_WRITE | -CL_MEM_ALLOC_HOST_PTR, mpad * npad * sizeof(float), NULL, &err); - if(err != CL_SUCCESS) - { - printf("err = %d@%s:%d\n", err, __FUNCTION__, __LINE__); - return -1; - } -#endif - for(int nloop = 0; nloop < loops + 1; nloop++) - { - gettimeofday(&start, NULL); - - //cl_mem _res_gpu = conv2D_gpu_sgemm(&context, &input, &filter, &output, ¶ms, local_min, -lhs_gpu, rhs_gpu, res_gpu); - - //get_result_gpu(&context, output.data + gpu_data_off, _res_gpu, m, n); - srcn_convolution2D_gpu(input, filter, output, params, row_major); - - gettimeofday(&end, NULL); - - if(nloop > 0) total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 -+ start.tv_usec))/1000; - } -} -else if(conv_type == 5) -{ - - for(int nloop = 0; nloop < loops + 1; nloop++) - { - gettimeofday(&start, NULL); - - //cl_mem res_gpu = conv2D_gpu_sgemm(&context, &input, &filter, &output, ¶ms, local_min); - - //clFlush(context.cmdQueue); - gettimeofday(&start1, NULL); - #if 1 - srcn_convolution2D(input, filter, output, params, NULL, thread_num, row_major - - #endif - //usleep(80 * 1000); - gettimeofday(&end1, NULL); - total_time1 += ((end1.tv_sec * 1000000 + end1.tv_usec) - (start1.tv_sec * 1000000 + -start1.tv_usec))/1000; - - //get_result_gpu(&context, output.data + gpu_data_off, res_gpu, m, n); - - srcn_convolution2D_dpu(input, filter, output, params, row_major); - - gettimeofday(&end, NULL); - if(nloop > 0) total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 -+ start.tv_usec))/1000; - } -} -else if(conv_type == 6) -{ - for(int nloop = 0; nloop < loops; nloop++) - { - gettimeofday(&start, NULL); - - if(kernel_size == 3 && stride == 1 && padding == 0) - { - conv2D_gpu_directconv_3x3S1(&context, &input, &filter, &output, ¶ms, local_min); - } - else if(kernel_size == 1 && stride == 1 && padding == 0) - { - conv2D_gpu_directconv_1x1S1(&context, &input, &filter, &output, ¶ms, local_min); - } - - gettimeofday(&end, NULL); - total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + -start.tv_usec))/1000; - } -}**/ - - int div = m * n < 16 ? m * n : 16; - int num = m * n > 64 ? 
64 : m * n; - - if (conv_type < 4) - printf("[CPU RESULT]\n"); - else if (conv_type == 4) - printf("[GPU RESULT]\n"); - else if (conv_type == 5) - printf("[DPU RESULT]\n"); - float *c_ptr = output.data; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - c_ptr = &output.data[m * n - num]; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - long long total_size = (long long)m * n * k * 2; - printf( - "AVER Time consuming: %.2fms, CPU Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", - total_time / loops, total_time1 / loops, total_size, - (double)total_size / (total_time / loops) / 1000000); - - free(input.data); - free(output.data); - free(filter.data); - - return 0; -} - -static int test_deconv(const int w, const int h, const int kernel_size, const int stride, - const int inch, const int outch, const int padding, const int conv_type, - const int thread_num, const int loops) -{ - struct timeval start, end; - float total_time = 0.f; - - const int dilation = 1; - - const int kernel_dilation = dilation * (kernel_size - 1) + 1; - - convMat_t input; - convMat_t output; - convMat_t filter; - convParams_t params; - - int pad_l, pad_r, pad_t, pad_b; - if (padding) - { - int pad_w = kernel_dilation - 1; - int pad_h = kernel_dilation - 1; - pad_l = pad_w / 2; - pad_r = pad_w - pad_l; - pad_t = pad_h / 2; - pad_b = pad_h - pad_t; - } - else - { - pad_l = pad_r = pad_t = pad_b = 0; - } - - input.w = w; - input.h = h; - input.c = inch; - input.data = (float *)malloc(input.w * input.h * input.c * sizeof(float)); - if (!input.data) - return 0; - - // output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; - // output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; - output.w = stride * (w - 1) + kernel_dilation - (pad_l + pad_r); - output.h = stride * (h - 1) + kernel_dilation - (pad_t + pad_b); - output.c = outch; - output.data = (float *)malloc(output.w * output.h * output.c * sizeof(float)); - if (!output.data) - return 0; - - filter.w = kernel_size; - filter.h = kernel_size; - filter.c = outch; - filter.n = inch; - filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); - if (!filter.data) - return 0; - - for (int i = 0; i < input.w * input.h * input.c; i++) - { - input.data[i] = 0.001 + i * 0.000001; - } - - for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) - { - filter.data[i] = 0.001 - i * 0.000001; - } - - params.kernel_w = kernel_size; - params.kernel_h = kernel_size; - params.stride_w = stride; - params.stride_h = stride; - params.padding = padding; - params.pad_w = pad_l; - params.pad_h = pad_t; - params.dilation_w = dilation; - params.dilation_h = dilation; - - const int m = params.kernel_h * params.kernel_w * output.c; - const int n = input.w * input.h; - const int k = input.c; - - if (conv_type == 0) - { - for (int nloop = 0; nloop < loops; nloop++) - - { - gettimeofday(&start, NULL); - - direct_deconv_rowmajor(&input, &output, &filter, ¶ms); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - } - else if (conv_type == 1) - { - for (int nloop = 0; nloop < loops; nloop++) - - { - gettimeofday(&start, NULL); - - for (int i = 0; i < output.w * output.h * output.c; i++) - { - output.data[i] = 0; - } - - srcn_deconvolution2D(input, filter, output, params, thread_num, 
row_major); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - } - - const int output_size = output.w * output.h * output.c; - - int div = output_size < 16 ? output_size : 16; - int num = output_size > 64 ? 64 : output_size; - - float *c_ptr = output.data; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - c_ptr = &output.data[output_size - num]; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - long long total_size = (long long)m * n * k * 2; - printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, - total_size, (double)total_size / (total_time / loops) / 1000000); - - free(input.data); - free(output.data); - free(filter.data); - - return 0; -} - -static int test_batch_conv(const int batch, const int w, const int h, const int kernel_size, - const int stride, const int inch, const int outch, const int padding, - const int conv_type, const int thread_num, const int loops) -{ - struct timeval start, end; - float total_time = 0.f; - - const int dilation = 1; - - const int kernel_dilation = dilation * (kernel_size - 1) + 1; - - convMat_t input; - convMat_t output; - convMat_t filter; - convParams_t params; - - int pad_l, pad_r, pad_t, pad_b; - if (padding) - { - int pad_w = kernel_dilation + (w - 1) / stride * stride - w; - int pad_h = kernel_dilation + (h - 1) / stride * stride - h; - pad_l = pad_w / 2; - pad_r = pad_w - pad_l; - pad_t = pad_h / 2; - pad_b = pad_h - pad_t; - } - else - { - pad_l = pad_r = pad_t = pad_b = 0; - } - - input.w = w; - input.h = h; - input.c = inch; - input.n = batch; - input.data = (float *)malloc(input.n * input.w * input.h * input.c * sizeof(float)); - if (!input.data) - return 0; - - output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; - output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; - output.c = outch; - output.n = batch; - output.data = (float *)malloc(output.n * output.w * output.h * output.c * sizeof(float)); - if (!output.data) - return 0; - - filter.w = kernel_size; - filter.h = kernel_size; - filter.c = inch; - filter.n = outch; - filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); - if (!filter.data) - return 0; - - for (int i = 0; i < input.w * input.h * input.c * input.n; i++) - { - input.data[i] = 0.001 + i * 0.000001; - } - - for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) - { - filter.data[i] = 0.001 - i * 0.000001; - } - - params.kernel_w = kernel_size; - params.kernel_h = kernel_size; - params.stride_w = stride; - params.stride_h = stride; - params.padding = padding; - params.pad_w = pad_l; - params.pad_h = pad_t; - params.dilation_w = dilation; - params.dilation_h = dilation; - - const int m = output.c; - const int n = output.w * output.h; - const int k = params.kernel_h * params.kernel_w * input.c; - - if (conv_type == 1) - { - for (int nloop = 0; nloop < loops; nloop++) - - { - // printf("nloop = %d, thread_num = %d\n", nloop, thread_num); - // class srcn_sgemm my_gemm(input, filter, output, params, thread_num, col_major); - - gettimeofday(&start, NULL); - - srcn_batch_convolution2D(input, filter, output, params, NULL, thread_num, col_major); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 
1000; - } - } - else if (conv_type == 2) - { - float *winograd_weight; - - // trans_weight2winograd(filter, &winograd_weight); - - winogradParams_t wparams = {params.kernel_w, - params.kernel_h, - params.stride_w, - params.stride_h, - params.dilation_w, - params.dilation_h, - input.n, - w, - h, - input.c, - output.c, - thread_num, - col_major, - filter.data}; - winograd_weight = trans_weight2winograd(wparams); - - for (int nloop = 0; nloop < loops; nloop++) - - { - gettimeofday(&start, NULL); - - srcn_batch_convolution2D(input, filter, output, params, winograd_weight, thread_num, - col_major); - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - } - } - - int div = m * n < 16 ? m * n : 16; - int num = m * n > 64 ? 64 : m * n; - - float *c_ptr = output.data; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - c_ptr = &output.data[m * n * batch - num]; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - long long total_size = (long long)batch * m * n * k * 2; - printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, - total_size, (double)total_size / (total_time / loops) / 1000000); - - free(input.data); - free(output.data); - free(filter.data); - - return 0; -} - -static int test_depthwise_conv(const int w, const int h, const int kernel_size, const int stride, - const int inch, const int outch, const int padding, - const int conv_type, const int thread_num, const int loops) -{ - if (outch != inch) - return -1; - struct timeval start, end; - float total_time = 0.f; - - const int dilation = 1; - - const int kernel_dilation = dilation * (kernel_size - 1) + 1; - - convMat_t input; - convMat_t output; - convMat_t filter; - convMat_t bias; - convParams_t params; - - int pad_l, pad_r, pad_t, pad_b; - if (padding) - { - int pad_w = kernel_dilation + (w - 1) / stride * stride - w; - int pad_h = kernel_dilation + (h - 1) / stride * stride - h; - pad_l = pad_w / 2; - pad_r = pad_w - pad_l; - pad_t = pad_h / 2; - pad_b = pad_h - pad_t; - } - else - { - pad_l = pad_r = pad_t = pad_b = 0; - } - - input.w = w; - input.h = h; - input.c = inch; - input.n = 1; -#ifdef NCNN - input.data = - (float *)malloc(alignSize(input.w * input.h, 16 / sizeof(float)) * input.c * sizeof(float)); -#else - input.data = (float *)malloc(input.w * input.h * input.c * sizeof(float)); -#endif - if (!input.data) - return 0; - - output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; - output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; - output.c = outch; - output.n = 1; - -#ifdef NCNN - output.data = (float *)malloc(alignSize(output.w * output.h, 16 / sizeof(float)) * output.c * - sizeof(float)); -#else - output.data = (float *)malloc(output.w * output.h * output.c * sizeof(float)); -#endif - const int gpu_data_off = output.w * output.h * output.c; - if (!output.data) - return 0; - - for (int i = 0; i < output.w * output.h * output.c; i++) - { - output.data[i] = 1.f; - } - - filter.w = kernel_size; - filter.h = kernel_size; - filter.c = 1; - filter.n = outch; - filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); - if (!filter.data) - return 0; - - for (int i = 0; i < input.w * input.h * input.c; i++) - { - input.data[i] = 0.001 + i * 0.000001; - } - - for (int i = 0; i < filter.w * filter.h * 
filter.c * filter.n; i++) - { - filter.data[i] = 0.001 - i * 0.000001; - } - - bias.w = outch; - bias.data = (float *)malloc(bias.w * sizeof(float)); - if (!bias.data) - return 0; - for (int i = 0; i < bias.w; i++) - { - bias.data[i] = 0.f; - } - - params.kernel_w = kernel_size; - params.kernel_h = kernel_size; - params.stride_w = stride; - params.stride_h = stride; - params.padding = padding; - params.pad_w = pad_l; - params.pad_h = pad_t; - params.dilation_w = dilation; - params.dilation_h = dilation; - - const int m = output.c; - const int n = output.w * output.h; - const int k = params.kernel_h * params.kernel_w * input.c; - - // ocl_context_t context; - size_t local_min[2] = {4, 4}; - /** - if(conv_type == 1) - { - if(init_gpu(&context) < 0) return -1; - depthwise_conv_3x3S1_tune(&context, &input, &filter, &output, local_min); - }**/ - - gettimeofday(&start, NULL); - if (conv_type == 0) - srcn_depthwise_conv(input, filter, output, bias, params, 4, - row_major); // convdw3x3s1_neon(input, output, filter, filter); - // else if(conv_type == 1) depthwise_conv_gpu3x3S1(&context, &input, &filter, &output, ¶ms, - // local_min); - else if (conv_type == 2) - { - for (int i = 0; i < input.c; i++) - { - convMat_t _input; - convMat_t _output; - convMat_t _filter; - convParams_t _params = params; - - _input.w = input.w; - _input.h = input.h; - _input.c = 1; - _input.n = 1; -#ifdef NCNN - _input.data = input.data + i * alignSize(input.w * input.h, 16 / sizeof(float)); -#else - _input.data = input.data + i * input.w * input.h; -#endif - - _output.w = output.w; - _output.h = output.h; - _output.c = 1; - _output.n = 1; -#ifdef NCNN - _output.data = output.data + i * alignSize(output.w * output.h, 16 / sizeof(float)); -#else - _output.data = output.data + i * output.w * output.h; -#endif - _filter.w = filter.w; - _filter.h = filter.h; - _filter.c = 1; // filter.c; - _filter.n = 1; // filter.n; - _filter.data = filter.data + i * 9; - - srcn_convolution2D(_input, _filter, _output, _params, NULL, 1, row_major); - // direct_conv_rowmajor(&_input, &_output, &_filter, &_params); - } - } - - gettimeofday(&end, NULL); - total_time += - ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; - - int div = m * n < 16 ? m * n : 16; - int num = m * n > 64 ? 
64 : m * n; - - if (conv_type == 0) - printf("[CPU RESULT]\n"); - else if (conv_type == 1) - printf("[GPU RESULT]\n"); - float *c_ptr = output.data; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - c_ptr = &output.data[m * n - num]; - for (int i = 0; i < num; i++) - { - printf("%f ", c_ptr[i]); - if ((i + 1) % div == 0) - printf("\n"); - } - - printf("\n"); - - long long total_size = (long long)m * n * k * 2; - printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, - total_size, (double)total_size / (total_time / loops) / 1000000); - - free(input.data); - free(output.data); - free(filter.data); - free(bias.data); - - return 0; -} - -//#define TEST_SGEMM -#define TEST_CONV -//#define TEST_DECONV -//#define TEST_BATCH_CONV -//#define TEST_DEPTHWISE_CONV - -int main(int argc, char **argv) -{ -#ifdef TEST_SGEMM - if (argc < 6) - return 0; - - const int m = atoi(argv[1]); - const int n = atoi(argv[2]); - const int k = atoi(argv[3]); - const int type = atoi(argv[4]); - const int loops = atoi(argv[5]); - - test_sgemm(m, n, k, type, loops); -#elif (defined TEST_CONV) - if (argc < 10) - return 0; - const int w = atoi(argv[1]); - const int h = atoi(argv[2]); - const int kernel_size = atoi(argv[3]); - const int stride = atoi(argv[4]); - const int outch = atoi(argv[5]); - const int inch = atoi(argv[6]); - const int padding = atoi(argv[7]); - const int conv_type = atoi(argv[8]); - const int thread_num = atoi(argv[9]); - int loops = 1; - if (argc > 10) - loops = atoi(argv[10]); - test_conv(w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, loops); -#elif (defined TEST_DECONV) - if (argc < 10) - return 0; - const int w = atoi(argv[1]); - const int h = atoi(argv[2]); - const int kernel_size = atoi(argv[3]); - const int stride = atoi(argv[4]); - const int outch = atoi(argv[5]); - const int inch = atoi(argv[6]); - const int padding = atoi(argv[7]); - const int conv_type = atoi(argv[8]); - const int thread_num = atoi(argv[9]); - int loops = 1; - if (argc > 10) - loops = atoi(argv[10]); - test_deconv(w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, loops); -#elif (defined TEST_BATCH_CONV) - if (argc < 11) - return 0; - const int batch = atoi(argv[1]); - const int w = atoi(argv[2]); - const int h = atoi(argv[3]); - const int kernel_size = atoi(argv[4]); - const int stride = atoi(argv[5]); - const int outch = atoi(argv[6]); - const int inch = atoi(argv[7]); - const int padding = atoi(argv[8]); - const int conv_type = atoi(argv[9]); - const int thread_num = atoi(argv[10]); - int loops = 1; - if (argc > 11) - loops = atoi(argv[11]); - test_batch_conv(batch, w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, - loops); -#elif (defined TEST_DEPTHWISE_CONV) - if (argc < 10) - return 0; - const int w = atoi(argv[1]); - const int h = atoi(argv[2]); - const int kernel_size = atoi(argv[3]); - const int stride = atoi(argv[4]); - const int outch = atoi(argv[5]); - const int inch = atoi(argv[6]); - const int padding = atoi(argv[7]); - const int conv_type = atoi(argv[8]); - const int thread_num = atoi(argv[9]); - int loops = 1; - if (argc > 10) - loops = atoi(argv[10]); - test_depthwise_conv(w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, - loops); -#endif - - return 0; -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/srcn_conv.cc b/compute/ncnn/src/srcn/srcn_conv.cc deleted file mode 
100644 index bb8e4f13e..000000000 --- a/compute/ncnn/src/srcn/srcn_conv.cc +++ /dev/null @@ -1,614 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef _OPENMP -#include <omp.h> -#endif - -#include "ncnn/srcn/conv_type.h" -#include "common.h" -#include "sgemm_singlethread.h" -#include "conv_sgemm_singlethread.h" -#include "conv_sgemm_multithreads.h" -#include "conv_winograd.h" -#include "direct_conv_colmajor.h" -#include "winograd.h" - -#include "deconv_sgemm_multithreads.h" -#include "conv_sparse.h" -#include "conv_winograd_batch.h" - -namespace nnfw -{ -namespace srcn -{ - -static inline void weight_transfer(float *out, float *in, int H, int W, int C, int N) -{ - // HWCN ---> NCHW - for (int h = 0; h < H; ++h) - { - for (int w = 0; w < W; ++w) - { - for (int c = 0; c < C; ++c) - { - for (int n = 0; n < N; ++n) - { - int index_in = h * W * C * N + w * C * N + c * N + n; - int index_out = n * C * H * W + c * H * W + h * W + w; - out[index_out] = in[index_in]; - } - } - } - } -} - -int check_winograd(winogradParams_t ¶ms) -{ - int winograd_flag = - ((params.kernel_w == params.kernel_h) && (params.stride_w == params.stride_h) && - (params.kernel_w == 3 || params.kernel_w == 5) && (params.stride_w == 1) && - (params.dilation_w == 1) && (params.dilation_h == 1)); - - int winograd_channel_cond = 64 * 64; - int winograd_image_cond = 10 * 10; - -#ifdef TIZEN - if (params.num_threads > 1) - { - winograd_channel_cond = 128 * 128; - winograd_image_cond = 20 * 20; - } -#endif // TIZEN - - winograd_flag &= (params.inch * params.outch >= winograd_channel_cond); - - if (params.w > 0 && params.h > 0 && params.batch == 1) - { - winograd_flag &= (params.w * params.h >= winograd_image_cond); - } - - return winograd_flag; -} - -float *trans_weight2winograd(winogradParams_t ¶ms, unsigned int *size = NULL) -{ - int M, N; - const double *G; - - float *winograd_weight; - - int winograd_channel_cond = 64 * 64; - int winograd_image_cond = 10 * 10; - -#ifdef TIZEN - if (params.num_threads > 1) - { - winograd_channel_cond = 128 * 128; - // int winograd_image_cond = 20 * 20; - } -#endif // TIZEN - - int winograd_flag = - ((params.kernel_w == params.kernel_h) && (params.stride_w == params.stride_h) && - (params.kernel_w == 3 || params.kernel_w == 5) && (params.stride_w == 1) && - (params.dilation_w == 1) && (params.dilation_h == 1)); - if (!winograd_flag) - return NULL; - - winograd_flag = (params.inch * params.outch >= winograd_channel_cond); - - if (!winograd_flag) - return NULL; - - if (params.w > 0 && params.h > 0 && params.batch == 1) - { - winograd_flag &= (params.w * params.h >= winograd_image_cond); - if (!winograd_flag) - return NULL; - } - - const int kernel_size = params.kernel_w; - const int inch = params.inch; - const int outch = params.outch; - float *weight_data = params.weight_data; - - /*Step 1: transfer weight to winograd domain*/ - if (kernel_size == 3) - { - if (params.w == 4 && params.batch > 1) 
- { - M = winograd_para_3x3s1_2::M; - N = winograd_para_3x3s1_2::N; - G = winograd_para_3x3s1_2::getG(); - } - else - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - G = winograd_para_3x3s1::getG(); - } - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - G = winograd_para_5x5s1::getG(); - } - - int tile_h_in_, tile_w_in_; - tile_h_in_ = tile_w_in_ = M; - - if (size) - *size = tile_h_in_ * tile_w_in_ * inch * outch; - - winograd_weight = new float[tile_h_in_ * tile_w_in_ * inch * outch]; - if (!winograd_weight) - return NULL; - - float *winograd_g = new float[M * M * N * N]; - if (!winograd_g) - { - delete[] winograd_weight; - return NULL; - } - - kronecker_product(winograd_g, G, G, M, N, M, N); - - if (params.conv_type == col_major) - { - weight_data = new float[kernel_size * kernel_size * inch * outch]; - if (!weight_data) - { - delete[] winograd_weight; - delete[] winograd_g; - return NULL; - } - weight_transfer(weight_data, params.weight_data, kernel_size, kernel_size, inch, outch); - } - - class sgemm_singlethread sgemm(rowMajor, notrans, trans, tile_h_in_ * tile_w_in_, inch * outch, - kernel_size * kernel_size, winograd_g, weight_data, - winograd_weight, 1); - - sgemm.run(); - - if (params.conv_type == col_major) - delete[] weight_data; - - delete[] winograd_g; - - return winograd_weight; -} - -void winograd_release(float *winograd_weight) -{ - if (winograd_weight) - delete[] winograd_weight; -} - -void srcn_convolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, - const convParams_t &in_param, const float *winograd_weight, int num_threads, - convType_t conv_type) -{ - const int outw = out_mat.w; - const int outh = out_mat.h; - const int inch = in_mat.c; - const int outch = out_mat.c; - - int winograd_flag = - ((in_param.kernel_w == in_param.kernel_h) && (in_param.stride_w == in_param.stride_h) && - (in_param.kernel_w == 3 || in_param.kernel_w == 5) && (in_param.stride_w == 1) && - (winograd_weight) && (in_param.dilation_w == 1) && (in_param.dilation_h == 1)); - - int direct_flag = ((conv_type == col_major) && (in_param.stride_w == in_param.stride_h) && - (in_param.dilation_w == 1) && (in_param.dilation_h == 1)); - - int winograd_image_cond = 10 * 10; - int winograd_channel_cond = 64 * 64; - int direct_image_cond = 4 * 4; - int direct_channel_cond = 16 * 16; - -#ifdef TIZEN - if (num_threads > 1) - { - winograd_image_cond = 20 * 20; - winograd_channel_cond = 128 * 128; - } -#endif - - winograd_flag &= - ((outw * outh >= winograd_image_cond) && (inch * outch >= winograd_channel_cond)); - direct_flag &= ((outw * outh <= direct_image_cond) || (inch * outch <= direct_channel_cond)); - - if (num_threads == 1) - { - if (winograd_flag) - { - class conv_winograd conv(in_mat, out_mat, in_param, conv_type, winograd_weight, num_threads, - in_mat.w * in_mat.h, outw * outh, outch); - conv.run(); - } - else if (direct_flag) - { - direct_conv_colmajor(in_mat, out_mat, weights_mat, in_param, num_threads); - } - else - { - class conv_sgemm_singlethread conv(in_mat, weights_mat, out_mat, in_param, conv_type); - conv.run(); - } - } - else if (num_threads > 1) - { - if (winograd_flag) - { - const int npart = num_threads > 4 ? 
4 : num_threads; - - omp_set_num_threads(npart); - - if (conv_type == col_major) - { - if (outch < 512) - { - const int _H = (outh + npart - 1) / npart; - - if (_H < in_param.pad_h) - { - class conv_winograd conv(in_mat, out_mat, in_param, conv_type, winograd_weight, 1, - in_mat.w * in_mat.h, outw * outh, outch); - conv.run(); - return; - } - - // const int ih = (_H - 1) * in_param.stride_w + in_param.kernel_w; - // const int oh = _H; - const int nh = (outh + _H - 1) / _H; - int rh = outh % _H; - if (rh == 0) - rh = _H; - -#pragma omp parallel for - for (int i = 0; i < nh; i++) - { - int pad_h_part = 0; - convMat_t in_part; - convMat_t out_part; - const int oh = (i != nh - 1 || rh == 0) ? _H : rh; - const int ih = (oh - 1) * in_param.stride_w + in_param.kernel_w; - - in_part.w = in_mat.w; - in_part.c = inch; - out_part.w = outw; - out_part.c = outch; - in_part.h = ih; - out_part.h = oh; - - int bottom_offset = i * _H - in_param.pad_h; - if (bottom_offset < 0) - { - bottom_offset = 0; - pad_h_part = in_param.pad_h; - } - in_part.data = in_mat.data + bottom_offset * in_mat.w * inch * in_param.stride_w; - if (ih + bottom_offset > in_mat.h) - { - in_part.h = in_mat.h - bottom_offset; - } - - out_part.data = out_mat.data + i * _H * outw * outch; - - convParams_t params = { - in_param.kernel_w, in_param.kernel_h, in_param.stride_w, in_param.stride_h, 1, 1, - in_param.padding, in_param.pad_w, pad_h_part}; - - class conv_winograd conv(in_part, out_part, params, conv_type, winograd_weight, - num_threads, in_mat.w * in_mat.h, outw * outh, outch); - conv.run(); - } - } - else - { - const int _OUTC = (outch + npart - 1) / npart; - - const int nc = (outch + _OUTC - 1) / _OUTC; - int rc = out_mat.c % _OUTC; - if (rc == 0) - rc = _OUTC; - -#pragma omp parallel for - for (int i = 0; i < nc; i++) - { - const float *weight_part; - convMat_t out_part; - - const int oc = (i != nc - 1 || rc == 0) ? _OUTC : rc; - - out_part.w = outw; - out_part.h = outh; - out_part.c = oc; - out_part.data = out_mat.data + i * _OUTC; - weight_part = winograd_weight + i * _OUTC * inch; - class conv_winograd conv(in_mat, out_part, in_param, conv_type, weight_part, - num_threads, in_mat.w * in_mat.h, outw * outh, outch); - conv.run(); - } - } - } - else if (conv_type == row_major) - { -#ifdef TIZEN - if (outch < 512) -#else // TIZEN - if (outh >= 20) -#endif // TIZEN - { - const int _H = (outh + npart - 1) / npart; - - if (_H < in_param.pad_h) - { - class conv_winograd conv(in_mat, out_mat, in_param, conv_type, winograd_weight, 1, - in_mat.w * in_mat.h, outw * outh, outch); - conv.run(); - return; - } - - // const int ih = (_H - 1) * in_param.stride_w + in_param.kernel_w; - // const int oh = _H; - const int nh = (outh + _H - 1) / _H; - int rh = outh % _H; - if (rh == 0) - rh = _H; - -#pragma omp parallel for - for (int i = 0; i < nh; i++) - { - int pad_h_part = 0; - convMat_t in_part; - convMat_t out_part; - const int oh = (i != nh - 1 || rh == 0) ? 
_H : rh; - const int ih = (oh - 1) * in_param.stride_w + in_param.kernel_w; - - in_part.w = in_mat.w; - in_part.c = inch; - out_part.w = outw; - out_part.c = outch; - in_part.h = ih; - out_part.h = oh; - - int bottom_offset = i * _H - in_param.pad_h; - if (bottom_offset < 0) - { - bottom_offset = 0; - pad_h_part = in_param.pad_h; - } - in_part.data = in_mat.data + bottom_offset * in_mat.w * in_param.stride_w; - if (ih + bottom_offset > in_mat.h) - { - in_part.h = in_mat.h - bottom_offset; - } - - out_part.data = out_mat.data + i * _H * outw; - - convParams_t params = { - in_param.kernel_w, in_param.kernel_h, in_param.stride_w, 1, 1, - in_param.stride_h, in_param.padding, in_param.pad_w, pad_h_part}; - - class conv_winograd conv(in_part, out_part, params, conv_type, winograd_weight, - num_threads, in_mat.w * in_mat.h, outw * outh, outch); - conv.run(); - } - } - else - { - const int _OUTC = (outch + npart - 1) / npart; - - const int nc = (outch + _OUTC - 1) / _OUTC; - int rc = out_mat.c % _OUTC; - if (rc == 0) - rc = _OUTC; - -#pragma omp parallel for - for (int i = 0; i < nc; i++) - { - const float *weight_part; - convMat_t out_part; - - const int oc = (i != nc - 1 || rc == 0) ? _OUTC : rc; - - out_part.w = outw; - out_part.h = outh; - out_part.c = oc; - out_part.data = out_mat.data + i * _OUTC * outw * outh; - weight_part = winograd_weight + i * _OUTC * inch; - class conv_winograd conv(in_mat, out_part, in_param, conv_type, weight_part, - num_threads, in_mat.w * in_mat.h, outw * outh, outch); - conv.run(); - } - } - } - } - else if (direct_flag) - { - direct_conv_colmajor(in_mat, out_mat, weights_mat, in_param, num_threads); - } - else - { - class conv_sgemm_multithreads conv(in_mat, weights_mat, out_mat, in_param, num_threads, - conv_type); - conv.run(); - } - } -} - -void srcn_deconvolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, - const convParams_t &in_param, int num_threads, convType_t conv_type) -{ - class deconv_sgemm_multithreads deconv(in_mat, weights_mat, out_mat, in_param, num_threads, - conv_type); - deconv.run(); -} - -void *trans_weight2sparse(const convMat_t &weights_mat) -{ - const int kernel_w = weights_mat.w; - const int kernel_h = weights_mat.h; - const int inch = weights_mat.c; - const int outch = weights_mat.n; - - const int nch = (outch + BCH - 1) / BCH; - const int rch = outch % BCH; - - const float *data = weights_mat.data; - const int klength = inch * kernel_h * kernel_w; - - sparse_weight_t *sparse_weight = new sparse_weight_t[nch]; - if (!sparse_weight) - return NULL; - - for (int i = 0; i < nch; i++) - { - int _bch = (i != nch - 1 || rch == 0) ? BCH : rch; - sparse_weight_t *sparse_weight_n = &sparse_weight[i]; - sparse_weight_n->mxk = 0; - - for (int j = 0; j < _bch; j++) - { - for (int l = 0; l < klength; l++) - { - float val = *(data + (i * BCH + j) * klength + l); - if (val != 0) - { - sparse_weight_n->mxk++; - } - } - } - } - - for (int i = 0; i < nch; i++) - { - int _bch = (i != nch - 1 || rch == 0) ? 
BCH : rch; - sparse_weight_t *sparse_weight_n = &sparse_weight[i]; - sparse_weight_n->wdata = new weight_data_t[sparse_weight_n->mxk]; - int index = 0; - - for (int l = 0; l < klength; l++) - { - for (int j = 0; j < _bch; j++) - { - float val = *(data + (i * BCH + j) * klength + l); - if (val != 0) - { - sparse_weight_n->wdata[index].m = i * BCH + j; - sparse_weight_n->wdata[index].k = l; - sparse_weight_n->wdata[index++].data = val; - } - } - } - } - - return (void *)sparse_weight; -} - -void sparse_release(const int outch, void *ptr) -{ - sparse_weight_t *sparse_weight = (sparse_weight_t *)ptr; - const int nch = (outch + BCH - 1) / BCH; - - if (!sparse_weight) - return; - - for (int i = 0; i < nch; i++) - { - sparse_weight_t *sparse_weight_n = &sparse_weight[i]; - if (sparse_weight_n->wdata) - delete[] sparse_weight_n->wdata; - } - - if (sparse_weight) - delete[] sparse_weight; -} - -void srcn_sparse_convolution2D(const convMat_t &in_mat, convMat_t &out_mat, - const convParams_t &in_param, const void *sparse_weight, - int number_threas, convType_t conv_type) -{ - class conv_sparse conv(in_mat, out_mat, in_param, (const sparse_weight_t *)sparse_weight, - number_threas, conv_type); - - for (int i = 0; i < out_mat.c * out_mat.h * out_mat.w; i++) - { - *(out_mat.data + i) = 0; - } - - conv.run(); -} - -void srcn_batch_convolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, - convMat_t &out_mat, const convParams_t &in_param, - const float *winograd_weight, int num_threads, convType_t conv_type) -{ - int winograd_flag = (winograd_weight != NULL); - - if (winograd_flag) - { - if (num_threads > 1) - { - omp_set_num_threads(num_threads); - const int batch = in_mat.n; - const int npart = (batch + num_threads - 1) / num_threads; - const int nn = (batch + npart - 1) / npart; - const int rn = batch % npart; - -#pragma omp parallel for - for (int i = 0; i < nn; i++) - { - const int pn = (i != nn - 1 || rn == 0) ? npart : rn; - convMat_t in_mat_part = {in_mat.w, in_mat.h, in_mat.c, pn, - in_mat.data + i * npart * in_mat.w * in_mat.h * in_mat.c}; - convMat_t out_mat_part = {out_mat.w, out_mat.h, out_mat.c, pn, - out_mat.data + i * npart * out_mat.w * out_mat.h * out_mat.c}; - - class conv_winograd_batch conv(in_mat_part, out_mat_part, in_param, conv_type, - winograd_weight, num_threads); - conv.run(); - } - } - else - { - class conv_winograd_batch conv(in_mat, out_mat, in_param, conv_type, winograd_weight, - num_threads); - conv.run(); - } - } - else - { - if (num_threads == 1) - { - class conv_sgemm_singlethread conv(in_mat, weights_mat, out_mat, in_param, conv_type); - conv.run(); - } - else - { - class conv_sgemm_multithreads conv(in_mat, weights_mat, out_mat, in_param, num_threads, - conv_type); - conv.run(); - } - } -} - -} // namespace srcn -} // namespace nnfw diff --git a/compute/ncnn/src/srcn/winograd.h b/compute/ncnn/src/srcn/winograd.h deleted file mode 100644 index 5ad8f1126..000000000 --- a/compute/ncnn/src/srcn/winograd.h +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __NNFW_SRCN_WINOGRAD_H__ -#define __NNFW_SRCN_WINOGRAD_H__ - -namespace nnfw -{ -namespace srcn -{ - -struct winograd_para_3x3s1 -{ - static const int M = 3 + 4 - 1; - static const int N = 3; - - static const double *getG() - { - static const double G[M * N] = { - 1. / 4., 0, 0, -1. / 6., -1. / 6., -1. / 6., -1. / 6., 1. / 6., -1. / 6., - 1. / 24., 1. / 12., 1. / 6., 1. / 24., -1. / 12., 1. / 6., 0, 0, 1, - }; - return G; - } - - static const double *getA() - { - static const double A[M * (M - N + 1)] = { - 1, 0, 0, 0, 1, 1, 1, 1, 1, -1, 1, -1, 1, 2, 4, 8, 1, -2, 4, -8, 0, 0, 0, 1, - }; - return A; - } - - static const double *getB() - { - static const double B[M * M] = { - 4, 0, 0, 0, 0, 0, 0, -4, 4, -2, 2, 4, -5, -4, -4, -1, -1, 0, - 0, 1, -1, 2, -2, -5, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, - }; - return B; - }; -}; - -struct winograd_para_3x3s1_2 -{ - static const int M = 3 + 2 - 1; - static const int N = 3; - - static const double *getG() - { - static const double G[M * N] = { - 1, 0, 0, 1. / 2., 1. / 2., 1. / 2., 1. / 2., -1. / 2., 1. / 2., 0, 0, 1, - }; - return G; - } - - static const double *getA() - { - static const double A[M * (M - N + 1)] = { - 1, 0, 1, 1, 1, -1, 0, 1, - }; - return A; - } - - static const double *getB() - { - static const double B[M * M] = { - 1, 0, 0, 0, 0, 1, -1, -1, -1, 1, 1, 0, 0, 0, 0, 1, - }; - return B; - }; -}; - -struct winograd_para_5x5s1 -{ - static const int M = 5 + 4 - 1; - static const int N = 5; - - static const double *getG() - { - static const double G[M * N] = { - 1, 0, 0, 0, 0, -2. / 9., -2. / 9., -2. / 9., - -2. / 9., -2. / 9., -2. / 9., 2. / 9., -2. / 9., 2. / 9., -2. / 9., 1. / 90., - 1. / 45., 2. / 45., 4. / 45., 8. / 45., 1. / 90., -1. / 45., 2. / 45., -4. / 45., - 8. / 45., 4. / 45., 2. / 45., 1. / 45., 1. / 90., 1. / 180., 4. / 45., -2. / 45., - 1. / 45., -1. / 90., 1. / 180., 0, 0, 0, 0, 1, - }; - return G; - } - - static const double *getA() - { - static const double A[M * (M - N + 1)] = {1, 0, 0, 0, 1, 1, 1, 1, 1, -1, 1, -1, 1, 2, 4, 8, - 1, -2, 4, -8, 8, 4, 2, 1, 8, -4, 2, -1, 0, 0, 0, 1}; - return A; - } - - static const double *getB() - { - static const double B[M * M] = { - 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, - -1, 1. / 2, -1. / 2, 2, -2, -1, -21. / 4, 1, 1, 1. / 4, - 1. / 4, 4, 4, 0, 0, -17. / 4, 17. / 4, -5. / 2, 5. / 2, -5. / 2, - 5. / 2, 21. / 4, 21. / 4, -17. / 4, -17. / 4, -5. / 4, -5. / 4, -5, -5, 0, - 0, 1, -1, 2, -2, 1. / 2, -1. / 2, -21. / 4, -1, 1, - 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, - 0, 0, 0, 1, - }; - return B; - } -}; - -static void kronecker_product(float *out, const double *in1, const double *in2, int m, int n, int p, - int q) -{ - for (int i = 0; i < m; ++i) - { - for (int j = 0; j < n; ++j) - { - for (int k = 0; k < p; ++k) - { - for (int l = 0; l < q; ++l) - { - out[(p * i + k) * n * q + q * j + l] = in1[n * i + j] * in2[k * q + l]; - /* compute in double precision and then convert it back to Dtype for accuracy */ - } - } - } - } -} - -} // namespace srcn -} // namespace nnfw - -#endif // __NNFW_SRCN_WINOGRAD_H__ |
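
Note (not part of the deleted sources): the kronecker_product() helper above, together with the single sgemm in trans_weight2winograd(), relies on the identity vec(G * g * G^T) == (G ⊗ G) * vec(g) for a row-major vec. The minimal standalone sketch below checks that identity for the F(2x2, 3x3) case; G, M, and N are copied from winograd_para_3x3s1_2, while the sample kernel g and every other name here are hypothetical illustrations, not code from the repository.

    // Sketch under the assumptions stated above: compares U1 = G * g * G^T against
    // (G kron G) * vec(g), using the same index layout as kronecker_product().
    #include <cmath>
    #include <cstdio>

    int main()
    {
      const int M = 4, N = 3; // F(2x2, 3x3): 4x4 transformed tile, 3x3 kernel
      const double G[M * N] = {
          1, 0, 0, 1. / 2., 1. / 2., 1. / 2., 1. / 2., -1. / 2., 1. / 2., 0, 0, 1,
      };                                                    // winograd_para_3x3s1_2::getG()
      const double g[N * N] = {1, 2, 3, 4, 5, 6, 7, 8, 9};  // arbitrary 3x3 kernel, row-major

      // Path 1: U1 = G * g * G^T via two small matrix products.
      double Gg[M * N] = {0}, U1[M * M] = {0};
      for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
          for (int k = 0; k < N; ++k)
            Gg[i * N + j] += G[i * N + k] * g[k * N + j];
      for (int i = 0; i < M; ++i)
        for (int j = 0; j < M; ++j)
          for (int k = 0; k < N; ++k)
            U1[i * M + j] += Gg[i * N + k] * G[j * N + k]; // right-multiply by G^T

      // Path 2: (G kron G) * vec(g); the (M*i + k, N*j + l) indexing mirrors
      // kronecker_product(out, G, G, M, N, M, N) in winograd.h.
      double GxG[M * M * N * N];
      for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
          for (int k = 0; k < M; ++k)
            for (int l = 0; l < N; ++l)
              GxG[(M * i + k) * N * N + N * j + l] = G[N * i + j] * G[N * k + l];

      double U2[M * M] = {0};
      for (int r = 0; r < M * M; ++r)
        for (int c = 0; c < N * N; ++c)
          U2[r] += GxG[r * N * N + c] * g[c];

      // The two paths must agree, which is why one sgemm over the precomputed
      // Kronecker product can transform every inch*outch kernel column at once.
      double max_diff = 0.;
      for (int i = 0; i < M * M; ++i)
        max_diff = std::fmax(max_diff, std::fabs(U1[i] - U2[i]));
      std::printf("max |U1 - U2| = %g (expected 0)\n", max_diff);
      return 0;
    }

In the deleted srcn_conv.cc this is exactly the shape of the call `sgemm_singlethread(rowMajor, notrans, trans, M*M, inch*outch, N*N, winograd_g, weight_data, winograd_weight, 1)`: each column of the result is one kernel mapped into the Winograd domain.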