Diffstat (limited to 'compute/ncnn/src/srcn/sgemm_singlethread.cc')
-rw-r--r--  compute/ncnn/src/srcn/sgemm_singlethread.cc  689
1 file changed, 689 insertions(+), 0 deletions(-)
diff --git a/compute/ncnn/src/srcn/sgemm_singlethread.cc b/compute/ncnn/src/srcn/sgemm_singlethread.cc
new file mode 100644
index 000000000..3de3e1214
--- /dev/null
+++ b/compute/ncnn/src/srcn/sgemm_singlethread.cc
@@ -0,0 +1,689 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdexcept>
+
+#include "common.h"
+#include "sgemm_kernel.h"
+#include "sgemm_pack.h"
+#include "sgemm_singlethread.h"
+
+namespace nnfw
+{
+namespace srcn
+{
+
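+// param_init() derives the blocking parameters used by the packed loops below:
+// the shard direction (by rows or by columns), the micro-kernel register tile
+// (mr_ x nr_) and the cache block sizes (bm_, bn_, bk_).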
+void sgemm_singlethread::param_init()
+{
+ if (n_ >= m_)
+ {
+ shard_type_ = shardByRow;
+ }
+ else
+ {
+ shard_type_ = shardByCol;
+ }
+
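+ // Register tile sizes for the micro-kernel depend on the target ISA and on the
+ // storage order / shard direction (AArch64 provides twice as many 128-bit NEON
+ // registers as 32-bit ARM, so it can afford larger tiles).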
+#if __aarch64__
+ if (major_type_ == rowMajor)
+ {
+ if (shard_type_ == shardByRow)
+ {
+ mr_ = 8;
+ nr_ = 12;
+ }
+ else
+ {
+ mr_ = 12;
+ nr_ = 8;
+ }
+ }
+ else if (major_type_ == colMajor)
+ {
+ mr_ = 12;
+ nr_ = 8;
+ }
+#else // __aarch64__
+ if (major_type_ == rowMajor)
+ {
+ // FIXME: this special case works around a bug; the root cause is not yet understood.
+ if (ltrans_ == notrans && rtrans_ == trans)
+ {
+ mr_ = 4;
+ nr_ = 12;
+ }
+ else
+ {
+ mr_ = 6;
+ nr_ = 8;
+ }
+ }
+ else if (major_type_ == colMajor)
+ {
+ mr_ = 8;
+ nr_ = 6;
+ }
+#endif // __aarch64__
+
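+ // Choose the depth block bk_ so that a packed bk_ x nr_ rhs panel, plus the
+ // mr_ x nr_ accumulator tile, fits the L1 cache budget, capped at MAX_K / cache_div_.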
+ int k_div = (nr_ * sizeof_RhsScalar);
+ int k_sub = (mr_ * nr_ * sizeof_ResScalar);
+
+ int gen_col = GEN_COL / cache_div_;
+ int min_k = MAX_K / cache_div_;
+
+ const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div), min_k);
+ bk_ = MIN(k_cache, k_);
+
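+ // Size the remaining blocks against L2/L3: when sharding by columns, bm_ comes
+ // from the L2 budget and bn_ from the L3 budget (or GEN_COL when no L3 is
+ // present); when sharding by rows, the roles of bm_ and bn_ are swapped.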
+ if (shard_type_ == shardByCol)
+ {
+ int m_sub = (bk_ * nr_ * sizeof_RhsScalar);
+ int m_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_);
+ if (L3_CACHE_SIZE)
+ m_div = (sizeof_LhsScalar * bk_ * 2);
+ int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div);
+ bm_ = MIN(m_cache, m_);
+
+ bn_ = MIN(gen_col, n_);
+ if (L3_CACHE_SIZE)
+ {
+ int n_sub = (bk_ * bm_ * sizeof_RhsScalar);
+ int n_cache = divup((L3_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2));
+ bn_ = MIN(n_cache, bn_);
+ }
+ }
+ else
+ {
+ int n_sub = (bk_ * mr_ * sizeof_RhsScalar);
+ int n_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_);
+ if (L3_CACHE_SIZE)
+ n_div = (sizeof_LhsScalar * bk_ * 2);
+ int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div);
+ bn_ = MIN(n_cache, n_);
+
+ bm_ = MIN(gen_col, m_);
+ if (L3_CACHE_SIZE)
+ {
+ int m_sub = (bk_ * bn_ * sizeof_RhsScalar);
+ int m_cache = divup((L3_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2));
+ bm_ = MIN(m_cache, bm_);
+ }
+ }
+
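+ // Number of blocks along each dimension, and the size of the final partial block.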
+ nm_ = divup(m_, bm_);
+ nn_ = divup(n_, bn_);
+ nk_ = divup(k_, bk_);
+
+ rm_ = m_ % bm_;
+ rn_ = n_ % bn_;
+ rk_ = k_ % bk_;
+}
+
+sgemm_singlethread::sgemm_singlethread(sgemmType_t major_type, sgemmTrans_t ltrans,
+ sgemmTrans_t rtrans, const int m, const int n, const int k,
+ const float *lhs_data, const float *rhs_data,
+ float *res_data, int cache_div)
+ : lhs_data_(lhs_data), rhs_data_(rhs_data), res_data_(res_data), major_type_(major_type),
+ ltrans_(ltrans), rtrans_(rtrans), m_(m), n_(n), k_(k), cache_div_(cache_div)
+{
+ param_init();
+}
+
+sgemm_singlethread::~sgemm_singlethread() {}
+
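+// run() dispatches to the compute routine matching the storage order and the
+// lhs/rhs transpose flags. Hypothetical usage sketch (names are taken from this
+// file; the trailing argument is cache_div, assumed here to be 1):
+//
+//   nnfw::srcn::sgemm_singlethread gemm(rowMajor, notrans, notrans,
+//                                       m, n, k, lhs, rhs, res, 1);
+//   gemm.run(); // expected to compute res = lhs * rhs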
+void sgemm_singlethread::run()
+{
+ if (major_type_ == rowMajor)
+ {
+ if (ltrans_ == notrans && rtrans_ == notrans)
+ {
+ compute_rowmajor_nn();
+ }
+ else if (ltrans_ == notrans && rtrans_ == trans)
+ {
+ compute_rowmajor_nt();
+ }
+ else if (ltrans_ == trans && rtrans_ == notrans)
+ {
+ compute_rowmajor_tn();
+ }
+ else if (ltrans_ == trans && rtrans_ == trans)
+ {
+ compute_rowmajor_tt();
+ }
+ else
+ {
+ throw std::runtime_error{"error trans type."};
+ }
+ }
+ else if (major_type_ == colMajor)
+ {
+ if (ltrans_ == notrans && rtrans_ == notrans)
+ {
+ compute_colmajor_nn();
+ }
+ else if (ltrans_ == notrans && rtrans_ == trans)
+ {
+ compute_colmajor_nt();
+ }
+ else if (ltrans_ == trans && rtrans_ == notrans)
+ {
+ compute_colmajor_tn();
+ }
+ else if (ltrans_ == trans && rtrans_ == trans)
+ {
+ compute_colmajor_tt();
+ }
+ else
+ {
+ throw std::runtime_error{"error trans type."};
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error major type."};
+ }
+}
+
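+// Every compute_* routine below shares the same blocked GEMM skeleton: plhs_ptr and
+// prhs_ptr are stack buffers holding one packed block of the lhs and rhs, rounded up
+// to multiples of the register tile, and the loop nest order follows shard_type_
+// (outer loop over column blocks when sharding by columns, over row blocks otherwise).
+// For each (row block, column block, depth block) triple the panels are packed and the
+// macro kernel is called with the current depth-block index l and the output leading
+// dimension; the variants differ only in the packing routines and operand indexing.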
+void sgemm_singlethread::compute_rowmajor_nn()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
+void sgemm_singlethread::compute_rowmajor_nt()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
+void sgemm_singlethread::compute_rowmajor_tn()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
+void sgemm_singlethread::compute_rowmajor_tt()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
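+// The column-major variants mirror the row-major ones: they use the column-major
+// packing routines and address the result block as res_data_[j * bn_ * m_ + i * bm_],
+// with m_ as the output leading dimension.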
+void sgemm_singlethread::compute_colmajor_nn()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
+void sgemm_singlethread::compute_colmajor_nt()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
+void sgemm_singlethread::compute_colmajor_tn()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
+void sgemm_singlethread::compute_colmajor_tt()
+{
+ int mstride = (bm_ + mr_ - 1) / mr_ * mr_;
+ int nstride = (bn_ + nr_ - 1) / nr_ * nr_;
+
+ float plhs_ptr[mstride * bk_];
+ float prhs_ptr[nstride * bk_];
+
+ if (shard_type_ == shardByCol)
+ {
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else if (shard_type_ == shardByRow)
+ {
+ for (int i = 0; i < nm_; i++)
+ {
+ const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_;
+
+ for (int l = 0; l < nk_; l++)
+ {
+ const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_;
+
+ _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr);
+
+ for (int j = 0; j < nn_; j++)
+ {
+ const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_;
+
+ _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr);
+
+ _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr,
+ &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk);
+ }
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error{"error shard type."};
+ }
+}
+
+} // namespace srcn
+} // namespace nnfw