diff options
Diffstat (limited to 'runtimes/libs/srcn/src/conv_sgemm_multithreads.cc')
-rw-r--r-- | runtimes/libs/srcn/src/conv_sgemm_multithreads.cc | 483 |
1 files changed, 483 insertions, 0 deletions
diff --git a/runtimes/libs/srcn/src/conv_sgemm_multithreads.cc b/runtimes/libs/srcn/src/conv_sgemm_multithreads.cc new file mode 100644 index 000000000..91a4533bd --- /dev/null +++ b/runtimes/libs/srcn/src/conv_sgemm_multithreads.cc @@ -0,0 +1,483 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include "srcn/conv_type.h" +#include "common.h" +#include "sgemm_kernel.h" +#include "sgemm_pack.h" +#include "conv_sgemm_multithreads.h" + +namespace nnfw +{ +namespace srcn +{ + +void conv_sgemm_multithreads::param_init() +{ +#if __aarch64__ + if (conv_type_ == row_major) + { + mr_ = 8; + nr_ = 12; + } + else if (conv_type_ == col_major) + { +#ifdef BATCH_DILATION_FIX + if (out_mat_.n > 1) + { + + mr_ = 24; + nr_ = 4; + } + else +#endif // BATCH_DILATION_FIX + { + if (m_ > n_) + { + mr_ = 24; + nr_ = 4; + } + else + { + mr_ = 12; + nr_ = 8; + } + } + } +#else // __aarch64__ + if (conv_type_ == row_major) + { + mr_ = 6; + nr_ = 8; + } + else if (conv_type_ == col_major) + { + mr_ = 8; + nr_ = 6; + } +#endif // __aarch64__ + int col = n_; + + if (m_ > n_) + { + shard_type_ = shardByRow; + col = m_; + } + else + { + shard_type_ = shardByCol; + } + + int th_base = divup(col, num_threads_); + + th_base = MIN(MAX(th_base, MIN_COL), MAX_COL); + + int k_div = (nr_ * sizeof_RhsScalar); + int k_sub = (mr_ * nr_ * sizeof_ResScalar); + + const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div * 2), MAX_K); + bk_ = MIN(k_cache, k_); + + if (shard_type_ == shardByCol) + { + int m_sub = (bk_ * nr_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + if (L3_CACHE_SIZE) + m_div = (sizeof_LhsScalar * bk_ * 2); + int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, m_); + + bn_ = MIN(th_base, n_); + if (L3_CACHE_SIZE) + { + int n_sub = (bk_ * bm_ * sizeof_RhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + int n_cache = divup((L3_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, bn_); + } + } + else + { + int n_sub = (bk_ * mr_ * sizeof_LhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + if (L3_CACHE_SIZE) + n_div = (sizeof_LhsScalar * bk_ * 2); + int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, n_); + + bm_ = MIN(th_base, m_); + if (L3_CACHE_SIZE) + { + int m_sub = (bk_ * bn_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + int m_cache = divup((L3_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, bm_); + } + } + + nm_ = divup(m_, bm_); + nn_ = divup(n_, bn_); + nk_ = divup(k_, bk_); + + rm_ = m_ % bm_; + rn_ = n_ % bn_; + rk_ = k_ % bk_; +} + +conv_sgemm_multithreads::conv_sgemm_multithreads(const convMat_t &in_mat, + const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, int num_threads, + convType_t conv_type) + + : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), + num_threads_(num_threads), conv_type_(conv_type) +{ + m_ = out_mat_.c; +#ifdef NCNN +#ifdef WITH_DPU + np_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); + n_ = (np_ + 1) / 2; +#else // WITH_DPU + n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); +#endif // WITH_DPU +#else // NCNN +#ifdef WITH_DPU + np_ = out_mat_.n * out_mat_.w * out_mat_.h; + n_ = (np_ + 1) / 2; +#else // WITH_DPU + n_ = out_mat_.n * out_mat_.w * out_mat_.h; +#endif // WITH_DPU +#endif // NCNN + k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c; + + param_init(); + + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + if (shard_type_ == shardByCol) + { + plhs_buffer_ = new float[lhs_stride * 1 * nm_]; + prhs_buffer_ = new float[rhs_stride * num_threads_]; + } + else + { + plhs_buffer_ = new float[lhs_stride * num_threads_]; + prhs_buffer_ = new float[rhs_stride * 1 * nn_]; + } + + if (plhs_buffer_ == NULL || prhs_buffer_ == NULL) + { + error_ = 1; + } + + if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || + in_param_.stride_h != 1 || in_param_.padding != 0) + { + need_im2col_ = 1; + } + else + { + need_im2col_ = 0; + } + + omp_set_num_threads(num_threads_); + + error_ = 0; +} + +conv_sgemm_multithreads::~conv_sgemm_multithreads() +{ + if (plhs_buffer_) + delete[] plhs_buffer_; + if (prhs_buffer_) + delete[] prhs_buffer_; +} + +void conv_sgemm_multithreads::run() +{ + if (error_) + return; + + if (shard_type_ == shardByCol && conv_type_ == col_major) + { + compute_colmajor_colshard(); + } + else if (shard_type_ == shardByRow && conv_type_ == col_major) + { + compute_colmajor_rowshard(); + } + else if (shard_type_ == shardByCol && conv_type_ == row_major) + { + compute_rowmajor_colshard(); + } + else if (shard_type_ == shardByRow && conv_type_ == row_major) + { + compute_rowmajor_rowshard(); + } +} + +void conv_sgemm_multithreads::compute_rowmajor_colshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + &plhs_buffer_[i * lhs_stride]); + } + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + int thread_num = omp_get_thread_num(); + // float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; + + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { +#ifdef WITH_DPU + _pack_rowmajor_notrans_rhs(nr_, bn, bk, np_, &in_mat_.data[n_ + l * bk_ * np_ + j * bn_], + prhs_ptr); +#else + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], + prhs_ptr); +#endif + } + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + +#ifdef WITH_DPU + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, &out_mat_.data[n_ + i * bm_ * np_ + j * bn_], + l, np_, bk); +#else // WITH_DPU + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, + bk); +#endif // WITH_DPU + } + } + } +} + +void conv_sgemm_multithreads::compute_rowmajor_rowshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), + &prhs_buffer_[j * rhs_stride]); + } + else + { + _pack_rowmajor_image_rhs_batch( + nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]); + } + } + else + { + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], + &prhs_buffer_[j * rhs_stride]); + } + } + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + int thread_num = omp_get_thread_num(); + float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, + &prhs_buffer_[j * rhs_stride], + &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } +} + +void conv_sgemm_multithreads::compute_colmajor_colshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + &plhs_buffer_[i * lhs_stride]); + } + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + int thread_num = omp_get_thread_num(); + float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; + + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], + prhs_ptr); + } + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, + bk); + } + } + } +} + +void conv_sgemm_multithreads::compute_colmajor_rowshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), + &prhs_buffer_[j * rhs_stride]); + } + else + { + _pack_colmajor_image_rhs_batch( + nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]); + } + } + else + { + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], + &prhs_buffer_[j * rhs_stride]); + } + } + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + int thread_num = omp_get_thread_num(); + float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, + &prhs_buffer_[j * rhs_stride], + &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } +} + +} // namespace srcn +} // namespace nnfw |