summaryrefslogtreecommitdiff
path: root/runtimes/libs/srcn/src/conv_sgemm_multithreads.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/libs/srcn/src/conv_sgemm_multithreads.cc')
-rw-r--r--runtimes/libs/srcn/src/conv_sgemm_multithreads.cc483
1 files changed, 483 insertions, 0 deletions
diff --git a/runtimes/libs/srcn/src/conv_sgemm_multithreads.cc b/runtimes/libs/srcn/src/conv_sgemm_multithreads.cc
new file mode 100644
index 000000000..91a4533bd
--- /dev/null
+++ b/runtimes/libs/srcn/src/conv_sgemm_multithreads.cc
@@ -0,0 +1,483 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "srcn/conv_type.h"
+#include "common.h"
+#include "sgemm_kernel.h"
+#include "sgemm_pack.h"
+#include "conv_sgemm_multithreads.h"
+
+namespace nnfw
+{
+namespace srcn
+{
+
+void conv_sgemm_multithreads::param_init()
+{
+#if __aarch64__
+ if (conv_type_ == row_major)
+ {
+ mr_ = 8;
+ nr_ = 12;
+ }
+ else if (conv_type_ == col_major)
+ {
+#ifdef BATCH_DILATION_FIX
+ if (out_mat_.n > 1)
+ {
+
+ mr_ = 24;
+ nr_ = 4;
+ }
+ else
+#endif // BATCH_DILATION_FIX
+ {
+ if (m_ > n_)
+ {
+ mr_ = 24;
+ nr_ = 4;
+ }
+ else
+ {
+ mr_ = 12;
+ nr_ = 8;
+ }
+ }
+ }
+#else // __aarch64__
+ if (conv_type_ == row_major)
+ {
+ mr_ = 6;
+ nr_ = 8;
+ }
+ else if (conv_type_ == col_major)
+ {
+ mr_ = 8;
+ nr_ = 6;
+ }
+#endif // __aarch64__
+ int col = n_;
+
+ if (m_ > n_)
+ {
+ shard_type_ = shardByRow;
+ col = m_;
+ }
+ else
+ {
+ shard_type_ = shardByCol;
+ }
+
+ int th_base = divup(col, num_threads_);
+
+ th_base = MIN(MAX(th_base, MIN_COL), MAX_COL);
+
+ int k_div = (nr_ * sizeof_RhsScalar);
+ int k_sub = (mr_ * nr_ * sizeof_ResScalar);
+
+ const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div * 2), MAX_K);
+ bk_ = MIN(k_cache, k_);
+
+ if (shard_type_ == shardByCol)
+ {
+ int m_sub = (bk_ * nr_ * sizeof_RhsScalar);
+ int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_);
+ if (L3_CACHE_SIZE)
+ m_div = (sizeof_LhsScalar * bk_ * 2);
+ int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div);
+ bm_ = MIN(m_cache, m_);
+
+ bn_ = MIN(th_base, n_);
+ if (L3_CACHE_SIZE)
+ {
+ int n_sub = (bk_ * bm_ * sizeof_RhsScalar);
+ int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_);
+ int n_cache = divup((L3_CACHE_SIZE - n_sub), n_div);
+ bn_ = MIN(n_cache, bn_);
+ }
+ }
+ else
+ {
+ int n_sub = (bk_ * mr_ * sizeof_LhsScalar);
+ int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_);
+ if (L3_CACHE_SIZE)
+ n_div = (sizeof_LhsScalar * bk_ * 2);
+ int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div);
+ bn_ = MIN(n_cache, n_);
+
+ bm_ = MIN(th_base, m_);
+ if (L3_CACHE_SIZE)
+ {
+ int m_sub = (bk_ * bn_ * sizeof_RhsScalar);
+ int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_);
+ int m_cache = divup((L3_CACHE_SIZE - m_sub), m_div);
+ bm_ = MIN(m_cache, bm_);
+ }
+ }
+
+ nm_ = divup(m_, bm_);
+ nn_ = divup(n_, bn_);
+ nk_ = divup(k_, bk_);
+
+ rm_ = m_ % bm_;
+ rn_ = n_ % bn_;
+ rk_ = k_ % bk_;
+}
+
+conv_sgemm_multithreads::conv_sgemm_multithreads(const convMat_t &in_mat,
+ const convMat_t &weights_mat, convMat_t &out_mat,
+ const convParams_t &in_param, int num_threads,
+ convType_t conv_type)
+
+ : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param),
+ num_threads_(num_threads), conv_type_(conv_type)
+{
+ m_ = out_mat_.c;
+#ifdef NCNN
+#ifdef WITH_DPU
+ np_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float));
+ n_ = (np_ + 1) / 2;
+#else // WITH_DPU
+ n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float));
+#endif // WITH_DPU
+#else // NCNN
+#ifdef WITH_DPU
+ np_ = out_mat_.n * out_mat_.w * out_mat_.h;
+ n_ = (np_ + 1) / 2;
+#else // WITH_DPU
+ n_ = out_mat_.n * out_mat_.w * out_mat_.h;
+#endif // WITH_DPU
+#endif // NCNN
+ k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c;
+
+ param_init();
+
+ int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_;
+ int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_;
+
+ if (shard_type_ == shardByCol)
+ {
+ plhs_buffer_ = new float[lhs_stride * 1 * nm_];
+ prhs_buffer_ = new float[rhs_stride * num_threads_];
+ }
+ else
+ {
+ plhs_buffer_ = new float[lhs_stride * num_threads_];
+ prhs_buffer_ = new float[rhs_stride * 1 * nn_];
+ }
+
+ if (plhs_buffer_ == NULL || prhs_buffer_ == NULL)
+ {
+ error_ = 1;
+ }
+
+ if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 ||
+ in_param_.stride_h != 1 || in_param_.padding != 0)
+ {
+ need_im2col_ = 1;
+ }
+ else
+ {
+ need_im2col_ = 0;
+ }
+
+ omp_set_num_threads(num_threads_);
+
+ error_ = 0;
+}
+
+conv_sgemm_multithreads::~conv_sgemm_multithreads()
+{
+ if (plhs_buffer_)
+ delete[] plhs_buffer_;
+ if (prhs_buffer_)
+ delete[] prhs_buffer_;
+}
+
+void conv_sgemm_multithreads::run()
+{
+ if (error_)
+ return;
+
+ if (shard_type_ == shardByCol && conv_type_ == col_major)
+ {
+ compute_colmajor_colshard();
+ }
+ else if (shard_type_ == shardByRow && conv_type_ == col_major)
+ {
+ compute_colmajor_rowshard();
+ }
+ else if (shard_type_ == shardByCol && conv_type_ == row_major)
+ {
+ compute_rowmajor_colshard();
+ }
+ else if (shard_type_ == shardByRow && conv_type_ == row_major)
+ {
+ compute_rowmajor_rowshard();
+ }
+}
+
+void conv_sgemm_multithreads::compute_rowmajor_colshard()
+{
+ int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_;
+ int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+#pragma omp parallel for
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_],
+ &plhs_buffer_[i * lhs_stride]);
+ }
+
+#pragma omp parallel for
+ for (int j = 0; j < nn_; j++)
+ {
+ int thread_num = omp_get_thread_num();
+ // float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num];
+ float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num];
+
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_),
+ &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ else
+ {
+ _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ }
+ else
+ {
+#ifdef WITH_DPU
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, np_, &in_mat_.data[n_ + l * bk_ * np_ + j * bn_],
+ prhs_ptr);
+#else
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_],
+ prhs_ptr);
+#endif
+ }
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+#ifdef WITH_DPU
+ _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride],
+ prhs_ptr, &out_mat_.data[n_ + i * bm_ * np_ + j * bn_],
+ l, np_, bk);
+#else // WITH_DPU
+ _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride],
+ prhs_ptr, &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_,
+ bk);
+#endif // WITH_DPU
+ }
+ }
+ }
+}
+
+void conv_sgemm_multithreads::compute_rowmajor_rowshard()
+{
+ int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_;
+ int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+#pragma omp parallel for
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_),
+ &out_mat_, const_cast<convParams_t *>(&in_param_),
+ &prhs_buffer_[j * rhs_stride]);
+ }
+ else
+ {
+ _pack_rowmajor_image_rhs_batch(
+ nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]);
+ }
+ }
+ else
+ {
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_],
+ &prhs_buffer_[j * rhs_stride]);
+ }
+ }
+
+#pragma omp parallel for
+ for (int i = 0; i < nm_; i++)
+ {
+ int thread_num = omp_get_thread_num();
+ float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num];
+
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_],
+ plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr,
+ &prhs_buffer_[j * rhs_stride],
+ &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+}
+
+void conv_sgemm_multithreads::compute_colmajor_colshard()
+{
+ int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_;
+ int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+#pragma omp parallel for
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_],
+ &plhs_buffer_[i * lhs_stride]);
+ }
+
+#pragma omp parallel for
+ for (int j = 0; j < nn_; j++)
+ {
+ int thread_num = omp_get_thread_num();
+ float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num];
+
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_),
+ &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ else
+ {
+ _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ }
+ else
+ {
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_],
+ prhs_ptr);
+ }
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride],
+ prhs_ptr, &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_,
+ bk);
+ }
+ }
+ }
+}
+
+void conv_sgemm_multithreads::compute_colmajor_rowshard()
+{
+ int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_;
+ int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+#pragma omp parallel for
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_),
+ &out_mat_, const_cast<convParams_t *>(&in_param_),
+ &prhs_buffer_[j * rhs_stride]);
+ }
+ else
+ {
+ _pack_colmajor_image_rhs_batch(
+ nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]);
+ }
+ }
+ else
+ {
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_],
+ &prhs_buffer_[j * rhs_stride]);
+ }
+ }
+
+#pragma omp parallel for
+ for (int i = 0; i < nm_; i++)
+ {
+ int thread_num = omp_get_thread_num();
+ float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num];
+
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_],
+ plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr,
+ &prhs_buffer_[j * rhs_stride],
+ &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+}
+
+} // namespace srcn
+} // namespace nnfw