diff options
Diffstat (limited to 'compute/ncnn/src/srcn/conv_winograd_batch.cc')
-rw-r--r-- | compute/ncnn/src/srcn/conv_winograd_batch.cc | 304 |
1 files changed, 0 insertions, 304 deletions
diff --git a/compute/ncnn/src/srcn/conv_winograd_batch.cc b/compute/ncnn/src/srcn/conv_winograd_batch.cc deleted file mode 100644 index cba45c648..000000000 --- a/compute/ncnn/src/srcn/conv_winograd_batch.cc +++ /dev/null @@ -1,304 +0,0 @@ -/* - * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common.h" -#include "conv_winograd_batch.h" - -namespace std -{ -template <typename Dtype> static inline Dtype max(Dtype a, Dtype b) -{ - if (a > b) - return a; - else - return b; -} -} - -namespace nnfw -{ -namespace srcn -{ - -void conv_winograd_batch::param_init() -{ - if ((in_param_.kernel_w != in_param_.kernel_h) || (in_param_.stride_w != in_param_.stride_h) || - (in_param_.kernel_w != 3 && in_param_.kernel_w != 5) || (in_param_.stride_w != 1) || - (!winograd_weight_)) - { - error_ = 1; - return; - } - - int M, N; - const int w = in_mat_.w; - const int h = in_mat_.h; - const int outw = out_mat_.w; - const int outh = out_mat_.h; - const int pad_w = in_param_.pad_w; - const int pad_h = in_param_.pad_h; - - if (in_param_.kernel_w == 3) - { - if (w == 4) - { - M = winograd_para_3x3s1_2::M; - N = winograd_para_3x3s1_2::N; - } - else - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - } - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - } - - tile_h_in_ = tile_w_in_ = M; - tile_h_out_ = tile_h_in_ - N + 1; - tile_w_out_ = tile_w_in_ - N + 1; - ntiles_h_ = (std::max(h + pad_h - tile_h_in_ + 1, outh) + tile_h_out_ - 1) / tile_h_out_; - ntiles_w_ = (std::max(w + pad_w - tile_w_in_ + 1, outw) + tile_w_out_ - 1) / tile_w_out_; - - error_ = 0; -} - -conv_winograd_batch::conv_winograd_batch(const convMat_t &in_mat, convMat_t &out_mat, - const convParams_t &in_param, convType_t conv_type, - const float *winograd_weight, int num_threads) - : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), conv_type_(conv_type), - winograd_weight_(winograd_weight), num_threads_(num_threads) -{ - param_init(); -} - -conv_winograd_batch::~conv_winograd_batch() {} - -void conv_winograd_batch::compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, - sgemmTrans_t rtrans, const int m, const int n, const int k, - const float *lhs_data, const float *rhs_data, - float *res_data) -{ - class sgemm_singlethread sgemm(major_type, ltrans, rtrans, m, n, k, lhs_data, rhs_data, res_data, - num_threads_); - - sgemm.run(); -} - -void conv_winograd_batch::winograd_input_im2col(float *col_buff) -{ - const int w = in_mat_.w; - const int h = in_mat_.h; - const float *data = in_mat_.data; - const int channels = in_mat_.c; - const int batch = in_mat_.n; - const int pad_w = in_param_.pad_w; - const int pad_h = in_param_.pad_h; - - // TODO: row_major - if (conv_type_ == col_major) - { - for (int n = 0; n < batch; n++) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_in_; ++y) - { - for (int x = 0; x < tile_w_in_; ++x) - { - for (int c = 0; c < channels; ++c) - { - int in_y = tile_h * tile_h_out_ + y - pad_h; - int in_x = tile_w * tile_w_out_ + x - pad_w; - - if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w) - { - col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * - tile_h_in_ + - y) * - tile_w_in_ + - x] = 0; - } - else - { - col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * - tile_h_in_ + - y) * - tile_w_in_ + - x] = data[((n * h + in_y) * w + in_x) * channels + c]; - } - } - } - } - } - } - } - } -} - -void conv_winograd_batch::winograd_output_col2im(const float *col_buff) -{ - int outh = out_mat_.h; - int outw = out_mat_.w; - float *data = out_mat_.data; - int channels = out_mat_.c; - int batch = out_mat_.n; - - // TODO: row_major - if (conv_type_ == col_major) - { - for (int n = 0; n < batch; n++) - { - for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) - { - for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) - { - for (int y = 0; y < tile_h_out_; ++y) - { - for (int x = 0; x < tile_w_out_; ++x) - { - for (int c = 0; c < channels; ++c) - { - int out_y = tile_h * tile_h_out_ + y; - int out_x = tile_w * tile_w_out_ + x; - if (out_y < outh && out_x < outw) - { - data[((n * outh + out_y) * outw + out_x) * channels + c] = - col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * - tile_h_out_ + - y) * - tile_w_out_ + - x]; - } - } - } - } - } - } - } - } -} - -void conv_winograd_batch::compute_winograd() -{ - const int w = in_mat_.w; - // const int h = in_mat_.h; - const int inch = in_mat_.c; - // const int outw = out_mat_.w; - // const int outh = out_mat_.h; - const int outch = out_mat_.c; - const int kernel_size = in_param_.kernel_w; - const int batch = in_mat_.n; - - int M, N; - const double *A; - const double *B; - - if (kernel_size == 3) - { - if (w == 4) - { - M = winograd_para_3x3s1_2::M; - N = winograd_para_3x3s1_2::N; - B = winograd_para_3x3s1_2::getB(); - A = winograd_para_3x3s1_2::getA(); - } - else - { - M = winograd_para_3x3s1::M; - N = winograd_para_3x3s1::N; - B = winograd_para_3x3s1::getB(); - A = winograd_para_3x3s1::getA(); - } - } - else - { - M = winograd_para_5x5s1::M; - N = winograd_para_5x5s1::N; - B = winograd_para_5x5s1::getB(); - A = winograd_para_5x5s1::getA(); - } - - /*Step 2: transfer image to winograd domain*/ - float *col_buff = - new float[std::max(outch, inch) * batch * ntiles_h_ * ntiles_w_ * tile_h_in_ * tile_w_in_]; - - int temp1_n = batch * inch * ntiles_h_ * ntiles_w_; - float *temp1_ = - new float[batch * tile_h_in_ * tile_w_in_ * std::max(outch, inch) * ntiles_h_ * ntiles_w_]; - - float *winograd_b = new float[M * M * M * M]; - - if ((NULL == col_buff) || (NULL == temp1_) || (NULL == winograd_b)) - { - delete[] col_buff; - delete[] temp1_; - delete[] winograd_b; - return; - } - - winograd_input_im2col(col_buff); - - kronecker_product(winograd_b, B, B, M, M, M, M); - - compute_sgemm(rowMajor, trans, trans, tile_h_in_ * tile_w_in_, temp1_n, tile_h_in_ * tile_w_in_, - winograd_b, col_buff, temp1_); - delete[] winograd_b; - - /*Step 3: convolution in winograd domain*/ - for (int j = 0; j < tile_h_in_ * tile_w_in_; ++j) - { - compute_sgemm(rowMajor, notrans, notrans, outch, batch * ntiles_h_ * ntiles_w_, inch, - winograd_weight_ + j * outch * inch, - temp1_ + j * batch * inch * ntiles_h_ * ntiles_w_, - col_buff + j * batch * outch * ntiles_h_ * ntiles_w_); - } - - /*Step 4: transfer back to time domain*/ - float *winograd_a = new float[M * (M - N + 1) * M * (M - N + 1)]; - if (NULL == winograd_a) - { - delete[] col_buff; - delete[] temp1_; - return; - } - - kronecker_product(winograd_a, A, A, M, M - N + 1, M, M - N + 1); - compute_sgemm(rowMajor, trans, notrans, batch * outch * ntiles_h_ * ntiles_w_, - tile_h_out_ * tile_w_out_, tile_h_in_ * tile_w_in_, col_buff, winograd_a, temp1_); - delete[] winograd_a; - delete[] col_buff; - - winograd_output_col2im(temp1_); - - delete[] temp1_; -} - -void conv_winograd_batch::run() -{ - if (error_) - return; - - compute_winograd(); -} - -} // namespace srcn -} // namespace nnfw |