/* * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifdef _OPENMP #include #endif #include "srcn/conv_type.h" #include "common.h" #include "sgemm_kernel.h" #include "sgemm_pack.h" #include "conv_sgemm_multithreads.h" namespace nnfw { namespace srcn { void conv_sgemm_multithreads::param_init() { #if __aarch64__ if (conv_type_ == row_major) { mr_ = 8; nr_ = 12; } else if (conv_type_ == col_major) { #ifdef BATCH_DILATION_FIX if (out_mat_.n > 1) { mr_ = 24; nr_ = 4; } else #endif // BATCH_DILATION_FIX { if (m_ > n_) { mr_ = 24; nr_ = 4; } else { mr_ = 12; nr_ = 8; } } } #else // __aarch64__ if (conv_type_ == row_major) { mr_ = 6; nr_ = 8; } else if (conv_type_ == col_major) { mr_ = 8; nr_ = 6; } #endif // __aarch64__ int col = n_; if (m_ > n_) { shard_type_ = shardByRow; col = m_; } else { shard_type_ = shardByCol; } int th_base = divup(col, num_threads_); th_base = MIN(MAX(th_base, MIN_COL), MAX_COL); int k_div = (nr_ * sizeof_RhsScalar); int k_sub = (mr_ * nr_ * sizeof_ResScalar); const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div * 2), MAX_K); bk_ = MIN(k_cache, k_); if (shard_type_ == shardByCol) { int m_sub = (bk_ * nr_ * sizeof_RhsScalar); int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); if (L3_CACHE_SIZE) m_div = (sizeof_LhsScalar * bk_ * 2); int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); bm_ = MIN(m_cache, m_); bn_ = MIN(th_base, n_); if (L3_CACHE_SIZE) { int n_sub = (bk_ * bm_ * sizeof_RhsScalar); int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); int n_cache = divup((L3_CACHE_SIZE - n_sub), n_div); bn_ = MIN(n_cache, bn_); } } else { int n_sub = (bk_ * mr_ * sizeof_LhsScalar); int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); if (L3_CACHE_SIZE) n_div = (sizeof_LhsScalar * bk_ * 2); int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); bn_ = MIN(n_cache, n_); bm_ = MIN(th_base, m_); if (L3_CACHE_SIZE) { int m_sub = (bk_ * bn_ * sizeof_RhsScalar); int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); int m_cache = divup((L3_CACHE_SIZE - m_sub), m_div); bm_ = MIN(m_cache, bm_); } } nm_ = divup(m_, bm_); nn_ = divup(n_, bn_); nk_ = divup(k_, bk_); rm_ = m_ % bm_; rn_ = n_ % bn_; rk_ = k_ % bk_; } conv_sgemm_multithreads::conv_sgemm_multithreads(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, const convParams_t &in_param, int num_threads, convType_t conv_type) : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), num_threads_(num_threads), conv_type_(conv_type) { m_ = out_mat_.c; #ifdef NCNN #ifdef WITH_DPU np_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); n_ = (np_ + 1) / 2; #else // WITH_DPU n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); #endif // WITH_DPU #else // NCNN #ifdef WITH_DPU np_ = out_mat_.n * out_mat_.w * out_mat_.h; n_ = (np_ + 1) / 2; #else // WITH_DPU n_ = out_mat_.n * out_mat_.w * out_mat_.h; #endif // WITH_DPU #endif // NCNN k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c; param_init(); int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; if (shard_type_ == shardByCol) { plhs_buffer_ = new float[lhs_stride * 1 * nm_]; prhs_buffer_ = new float[rhs_stride * num_threads_]; } else { plhs_buffer_ = new float[lhs_stride * num_threads_]; prhs_buffer_ = new float[rhs_stride * 1 * nn_]; } if (plhs_buffer_ == NULL || prhs_buffer_ == NULL) { error_ = 1; } if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || in_param_.stride_h != 1 || in_param_.padding != 0) { need_im2col_ = 1; } else { need_im2col_ = 0; } omp_set_num_threads(num_threads_); error_ = 0; } conv_sgemm_multithreads::~conv_sgemm_multithreads() { if (plhs_buffer_) delete[] plhs_buffer_; if (prhs_buffer_) delete[] prhs_buffer_; } void conv_sgemm_multithreads::run() { if (error_) return; if (shard_type_ == shardByCol && conv_type_ == col_major) { compute_colmajor_colshard(); } else if (shard_type_ == shardByRow && conv_type_ == col_major) { compute_colmajor_rowshard(); } else if (shard_type_ == shardByCol && conv_type_ == row_major) { compute_rowmajor_colshard(); } else if (shard_type_ == shardByRow && conv_type_ == row_major) { compute_rowmajor_rowshard(); } } void conv_sgemm_multithreads::compute_rowmajor_colshard() { int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; for (int l = 0; l < nk_; l++) { const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; #pragma omp parallel for for (int i = 0; i < nm_; i++) { const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], &plhs_buffer_[i * lhs_stride]); } #pragma omp parallel for for (int j = 0; j < nn_; j++) { int thread_num = omp_get_thread_num(); // float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; if (need_im2col_) { if (out_mat_.n == 1) { _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), prhs_ptr); } else { _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), prhs_ptr); } } else { #ifdef WITH_DPU _pack_rowmajor_notrans_rhs(nr_, bn, bk, np_, &in_mat_.data[n_ + l * bk_ * np_ + j * bn_], prhs_ptr); #else _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], prhs_ptr); #endif } for (int i = 0; i < nm_; i++) { const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; #ifdef WITH_DPU _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], prhs_ptr, &out_mat_.data[n_ + i * bm_ * np_ + j * bn_], l, np_, bk); #else // WITH_DPU _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], prhs_ptr, &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); #endif // WITH_DPU } } } } void conv_sgemm_multithreads::compute_rowmajor_rowshard() { int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; for (int l = 0; l < nk_; l++) { const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; #pragma omp parallel for for (int j = 0; j < nn_; j++) { const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; if (need_im2col_) { if (out_mat_.n == 1) { _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), &prhs_buffer_[j * rhs_stride]); } else { _pack_rowmajor_image_rhs_batch( nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), &prhs_buffer_[j * rhs_stride]); } } else { _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], &prhs_buffer_[j * rhs_stride]); } } #pragma omp parallel for for (int i = 0; i < nm_; i++) { int thread_num = omp_get_thread_num(); float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], plhs_ptr); for (int j = 0; j < nn_; j++) { const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, &prhs_buffer_[j * rhs_stride], &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); } } } } void conv_sgemm_multithreads::compute_colmajor_colshard() { int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; for (int l = 0; l < nk_; l++) { const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; #pragma omp parallel for for (int i = 0; i < nm_; i++) { const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], &plhs_buffer_[i * lhs_stride]); } #pragma omp parallel for for (int j = 0; j < nn_; j++) { int thread_num = omp_get_thread_num(); float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; if (need_im2col_) { if (out_mat_.n == 1) { _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), prhs_ptr); } else { _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), prhs_ptr); } } else { _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], prhs_ptr); } for (int i = 0; i < nm_; i++) { const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], prhs_ptr, &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); } } } } void conv_sgemm_multithreads::compute_colmajor_rowshard() { int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; for (int l = 0; l < nk_; l++) { const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; #pragma omp parallel for for (int j = 0; j < nn_; j++) { const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; if (need_im2col_) { if (out_mat_.n == 1) { _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), &prhs_buffer_[j * rhs_stride]); } else { _pack_colmajor_image_rhs_batch( nr_, bn, bk, l * bk_, j * bn_, const_cast(&in_mat_), &out_mat_, const_cast(&in_param_), &prhs_buffer_[j * rhs_stride]); } } else { _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], &prhs_buffer_[j * rhs_stride]); } } #pragma omp parallel for for (int i = 0; i < nm_; i++) { int thread_num = omp_get_thread_num(); float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], plhs_ptr); for (int j = 0; j < nn_; j++) { const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, &prhs_buffer_[j * rhs_stride], &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); } } } } } // namespace srcn } // namespace nnfw