diff options
Diffstat (limited to 'compute/ncnn/src/srcn/sgemm_singlethread.cc')
-rw-r--r-- | compute/ncnn/src/srcn/sgemm_singlethread.cc | 689 |
1 files changed, 689 insertions, 0 deletions
diff --git a/compute/ncnn/src/srcn/sgemm_singlethread.cc b/compute/ncnn/src/srcn/sgemm_singlethread.cc new file mode 100644 index 000000000..3de3e1214 --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_singlethread.cc @@ -0,0 +1,689 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdexcept> + +#include "common.h" +#include "sgemm_kernel.h" +#include "sgemm_pack.h" +#include "sgemm_singlethread.h" + +namespace nnfw +{ +namespace srcn +{ + +void sgemm_singlethread::param_init() +{ + if (n_ >= m_) + { + shard_type_ = shardByRow; + } + else + { + shard_type_ = shardByCol; + } + +#if __aarch64__ + if (major_type_ == rowMajor) + { + if (shard_type_ == shardByRow) + { + mr_ = 8; + nr_ = 12; + } + else + { + mr_ = 12; + nr_ = 8; + } + } + else if (major_type_ == colMajor) + { + mr_ = 12; + nr_ = 8; + } +#else // __aarch64__ + if (major_type_ == rowMajor) + { + // it is a bug, but i do not know why as now. + if (ltrans_ == notrans && rtrans_ == trans) + { + mr_ = 4; + nr_ = 12; + } + else + { + mr_ = 6; + nr_ = 8; + } + } + else if (major_type_ == colMajor) + { + mr_ = 8; + nr_ = 6; + } +#endif // __aarch64__ + + int k_div = (nr_ * sizeof_RhsScalar); + int k_sub = (mr_ * nr_ * sizeof_ResScalar); + + int gen_col = GEN_COL / cache_div_; + int min_k = MAX_K / cache_div_; + + const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div), min_k); + bk_ = MIN(k_cache, k_); + + if (shard_type_ == shardByCol) + { + int m_sub = (bk_ * nr_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_); + if (L3_CACHE_SIZE) + m_div = (sizeof_LhsScalar * bk_ * 2); + int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, m_); + + bn_ = MIN(gen_col, n_); + if (L3_CACHE_SIZE) + { + int n_sub = (bk_ * bm_ * sizeof_RhsScalar); + int n_cache = divup((L3_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2)); + bn_ = MIN(n_cache, bn_); + } + } + else + { + int n_sub = (bk_ * mr_ * sizeof_RhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_); + if (L3_CACHE_SIZE) + n_div = (sizeof_LhsScalar * bk_ * 2); + int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, n_); + + bm_ = MIN(gen_col, m_); + if (L3_CACHE_SIZE) + { + int m_sub = (bk_ * bn_ * sizeof_RhsScalar); + int m_cache = divup((L3_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2)); + bm_ = MIN(m_cache, bm_); + } + } + + nm_ = divup(m_, bm_); + nn_ = divup(n_, bn_); + nk_ = divup(k_, bk_); + + rm_ = m_ % bm_; + rn_ = n_ % bn_; + rk_ = k_ % bk_; +} + +sgemm_singlethread::sgemm_singlethread(sgemmType_t major_type, sgemmTrans_t ltrans, + sgemmTrans_t rtrans, const int m, const int n, const int k, + const float *lhs_data, const float *rhs_data, + float *res_data, int cache_div) + : lhs_data_(lhs_data), rhs_data_(rhs_data), res_data_(res_data), major_type_(major_type), + ltrans_(ltrans), rtrans_(rtrans), m_(m), n_(n), k_(k), cache_div_(cache_div) +{ + param_init(); +} + +sgemm_singlethread::~sgemm_singlethread() {} + +void sgemm_singlethread::run() +{ + if (major_type_ == rowMajor) + { + if (ltrans_ == notrans && rtrans_ == notrans) + { + compute_rowmajor_nn(); + } + else if (ltrans_ == notrans && rtrans_ == trans) + { + compute_rowmajor_nt(); + } + else if (ltrans_ == trans && rtrans_ == notrans) + { + compute_rowmajor_tn(); + } + else if (ltrans_ == trans && rtrans_ == trans) + { + compute_rowmajor_tt(); + } + else + { + throw std::runtime_error{"error trans type."}; + } + } + else if (major_type_ == colMajor) + { + if (ltrans_ == notrans && rtrans_ == notrans) + { + compute_colmajor_nn(); + } + else if (ltrans_ == notrans && rtrans_ == trans) + { + compute_colmajor_nt(); + } + else if (ltrans_ == trans && rtrans_ == notrans) + { + compute_colmajor_tn(); + } + else if (ltrans_ == trans && rtrans_ == trans) + { + compute_colmajor_tt(); + } + else + { + throw std::runtime_error{"error trans type."}; + } + } + else + { + throw std::runtime_error{"error major type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_nn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_nt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_tn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_tt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_nn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_nt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_tn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_tt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +} // namespace srcn +} // namespace nnfw |