Diffstat (limited to 'compute/ncnn/src/srcn/conv_sgemm_singlethread.cc')
-rw-r--r--  compute/ncnn/src/srcn/conv_sgemm_singlethread.cc  366
1 files changed, 366 insertions, 0 deletions
diff --git a/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc b/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc
new file mode 100644
index 000000000..4cbbf217f
--- /dev/null
+++ b/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc
@@ -0,0 +1,366 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdexcept>
+
+#include "common.h"
+#include "sgemm_kernel.h"
+#include "sgemm_pack.h"
+#include "conv_sgemm_singlethread.h"
+
+namespace nnfw
+{
+namespace srcn
+{
+
+void conv_sgemm_singlethread::param_init()
+{
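+  // Choose the sharding scheme from the GEMM shape: shard by row when n_ is much
+  // larger than m_ (n_ > 3 * m_), otherwise shard by column.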
+ if (n_ > 3 * m_)
+ {
+ shard_type_ = shardByRow;
+ }
+ else
+ {
+ shard_type_ = shardByCol;
+ }
+
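+  // Register tile sizes (mr_ x nr_) of the SGEMM micro kernel, chosen per target
+  // ISA and per data layout.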
+#if __aarch64__
+ if (conv_type_ == row_major)
+ {
+ if (shard_type_ == shardByRow)
+ {
+ mr_ = 8;
+ nr_ = 12;
+ }
+ else
+ {
+ mr_ = 12;
+ nr_ = 8;
+ }
+ }
+ else if (conv_type_ == col_major)
+ {
+#ifndef BATCH_DILATION_FIX
+ mr_ = 12;
+ nr_ = 8;
+#else // BATCH_DILATION_FIX
+ // TODO: batch(dilation) + inw * inh
+ if (out_mat_.n > 1)
+ {
+ mr_ = 24;
+ nr_ = 4;
+ }
+ else
+ {
+ mr_ = 12;
+ nr_ = 8;
+ }
+#endif // BATCH_DILATION_FIX
+ }
+#else // __aarch64__
+ if (conv_type_ == row_major)
+ {
+ mr_ = 6;
+ nr_ = 8;
+ }
+ else if (conv_type_ == col_major)
+ {
+ mr_ = 8;
+ nr_ = 6;
+ }
+#endif // __aarch64__
+
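+  // Derive the block sizes bk_/bm_/bn_ from the L1/L2/L3 cache sizes so that the
+  // packed panels used by the macro kernel stay cache-resident.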
+ int k_div = (nr_ * sizeof_RhsScalar);
+ int k_sub = (mr_ * nr_ * sizeof_ResScalar);
+
+ const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div), MAX_K);
+ bk_ = MIN(k_cache, k_);
+
+ if (shard_type_ == shardByCol)
+ {
+ int m_sub = (bk_ * nr_ * sizeof_RhsScalar);
+ int m_cache = divup((L2_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2));
+ bm_ = MIN(m_cache, m_);
+
+ bn_ = MIN(GEN_COL, n_);
+ if (L3_CACHE_SIZE)
+ {
+ int n_sub = (bk_ * bm_ * sizeof_RhsScalar);
+ int n_cache = divup((L3_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2));
+ bn_ = MIN(n_cache, bn_);
+ }
+ }
+ else
+ {
+ int n_sub = (bk_ * mr_ * sizeof_RhsScalar);
+ int n_cache = divup((L2_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2));
+ bn_ = MIN(n_cache, n_);
+
+ bm_ = MIN(GEN_COL, m_);
+ if (L3_CACHE_SIZE)
+ {
+ int m_sub = (bk_ * bn_ * sizeof_RhsScalar);
+ int m_cache = divup((L3_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2));
+ bm_ = MIN(m_cache, bm_);
+ }
+ }
+
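+  // Number of blocks along each GEMM dimension, plus the sizes of the trailing
+  // partial blocks (zero when the dimension divides evenly).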
+ nm_ = divup(m_, bm_);
+ nn_ = divup(n_, bn_);
+ nk_ = divup(k_, bk_);
+
+ rm_ = m_ % bm_;
+ rn_ = n_ % bn_;
+ rk_ = k_ % bk_;
+}
+
+conv_sgemm_singlethread::conv_sgemm_singlethread(const convMat_t &in_mat,
+ const convMat_t &weights_mat, convMat_t &out_mat,
+ const convParams_t &in_param, convType_t conv_type)
+ : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param),
+ conv_type_(conv_type)
+{
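+  // GEMM dimensions of the im2col-lowered convolution: m_ = output channels,
+  // n_ = batch x output spatial size, k_ = kernel_h x kernel_w x input channels.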
+ m_ = out_mat_.c;
+#ifdef NCNN
+ n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float));
+#else
+ n_ = out_mat_.n * out_mat_.w * out_mat_.h;
+#endif
+ k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c;
+
+ param_init();
+
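+  // im2col packing can only be skipped for a 1x1, stride-1, unpadded, single-batch
+  // convolution, where the input already has the layout of the GEMM rhs.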
+ if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 ||
+ in_param_.stride_h != 1 || in_param_.padding != 0 || out_mat_.n > 1)
+ {
+ need_im2col_ = 1;
+ }
+ else
+ {
+ need_im2col_ = 0;
+ }
+}
+
+conv_sgemm_singlethread::~conv_sgemm_singlethread() {}
+
+void conv_sgemm_singlethread::run()
+{
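+  // Leading dimensions of the packed panels, rounded up to whole register tiles.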
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
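+  // Scratch buffers holding one packed lhs block and one packed rhs block.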
+ float *plhs_ptr = new float[mstride * bk_];
+ float *prhs_ptr = new float[nstride * bk_];
+
+ if (conv_type_ == row_major)
+ {
+ if (shard_type_ == shardByCol)
+ {
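+      // Shard by column: pack one rhs panel per (j, l) block and reuse it for every
+      // lhs row block i.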
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ else
+ {
+ _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ }
+ else
+ {
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_],
+ prhs_ptr);
+ }
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_],
+ plhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
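+      // Shard by row: pack one lhs panel per (i, l) block and reuse it for every
+      // rhs column block j.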
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_],
+ plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ else
+ {
+ _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ }
+ else
+ {
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_],
+ prhs_ptr);
+ }
+
+ _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+      throw std::runtime_error{"Error shard type!"};
+ }
+ }
+ else if (conv_type_ == col_major)
+ {
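+    // Column-major path: same blocking scheme as the row-major path above, using the
+    // column-major packing routines and macro kernels.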
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ else
+ {
+ _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ }
+ else
+ {
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_],
+ prhs_ptr);
+ }
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_],
+ plhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_],
+ plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ if (need_im2col_)
+ {
+ if (out_mat_.n == 1)
+ {
+ _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ else
+ {
+ _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_,
+ const_cast<convMat_t *>(&in_mat_), &out_mat_,
+ const_cast<convParams_t *>(&in_param_), prhs_ptr);
+ }
+ }
+ else
+ {
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_],
+ prhs_ptr);
+ }
+
+ _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+      throw std::runtime_error{"Error shard type!"};
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"Error conv type!"};
+ }
+
+ delete[] plhs_ptr;
+ delete[] prhs_ptr;
+}
+
+} // namespace srcn
+} // namespace nnfw