summaryrefslogtreecommitdiff
path: root/compute/ncnn/src/srcn/conv_winograd_batch.cc
diff options
context:
space:
mode:
Diffstat (limited to 'compute/ncnn/src/srcn/conv_winograd_batch.cc')
-rw-r--r--compute/ncnn/src/srcn/conv_winograd_batch.cc304
1 files changed, 0 insertions, 304 deletions
diff --git a/compute/ncnn/src/srcn/conv_winograd_batch.cc b/compute/ncnn/src/srcn/conv_winograd_batch.cc
deleted file mode 100644
index cba45c648..000000000
--- a/compute/ncnn/src/srcn/conv_winograd_batch.cc
+++ /dev/null
@@ -1,304 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "common.h"
-#include "conv_winograd_batch.h"
-
-namespace std
-{
-template <typename Dtype> static inline Dtype max(Dtype a, Dtype b)
-{
- if (a > b)
- return a;
- else
- return b;
-}
-}
-
-namespace nnfw
-{
-namespace srcn
-{
-
-void conv_winograd_batch::param_init()
-{
- if ((in_param_.kernel_w != in_param_.kernel_h) || (in_param_.stride_w != in_param_.stride_h) ||
- (in_param_.kernel_w != 3 && in_param_.kernel_w != 5) || (in_param_.stride_w != 1) ||
- (!winograd_weight_))
- {
- error_ = 1;
- return;
- }
-
- int M, N;
- const int w = in_mat_.w;
- const int h = in_mat_.h;
- const int outw = out_mat_.w;
- const int outh = out_mat_.h;
- const int pad_w = in_param_.pad_w;
- const int pad_h = in_param_.pad_h;
-
- if (in_param_.kernel_w == 3)
- {
- if (w == 4)
- {
- M = winograd_para_3x3s1_2::M;
- N = winograd_para_3x3s1_2::N;
- }
- else
- {
- M = winograd_para_3x3s1::M;
- N = winograd_para_3x3s1::N;
- }
- }
- else
- {
- M = winograd_para_5x5s1::M;
- N = winograd_para_5x5s1::N;
- }
-
- tile_h_in_ = tile_w_in_ = M;
- tile_h_out_ = tile_h_in_ - N + 1;
- tile_w_out_ = tile_w_in_ - N + 1;
- ntiles_h_ = (std::max(h + pad_h - tile_h_in_ + 1, outh) + tile_h_out_ - 1) / tile_h_out_;
- ntiles_w_ = (std::max(w + pad_w - tile_w_in_ + 1, outw) + tile_w_out_ - 1) / tile_w_out_;
-
- error_ = 0;
-}
-
-conv_winograd_batch::conv_winograd_batch(const convMat_t &in_mat, convMat_t &out_mat,
- const convParams_t &in_param, convType_t conv_type,
- const float *winograd_weight, int num_threads)
- : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), conv_type_(conv_type),
- winograd_weight_(winograd_weight), num_threads_(num_threads)
-{
- param_init();
-}
-
-conv_winograd_batch::~conv_winograd_batch() {}
-
-void conv_winograd_batch::compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans,
- sgemmTrans_t rtrans, const int m, const int n, const int k,
- const float *lhs_data, const float *rhs_data,
- float *res_data)
-{
- class sgemm_singlethread sgemm(major_type, ltrans, rtrans, m, n, k, lhs_data, rhs_data, res_data,
- num_threads_);
-
- sgemm.run();
-}
-
-void conv_winograd_batch::winograd_input_im2col(float *col_buff)
-{
- const int w = in_mat_.w;
- const int h = in_mat_.h;
- const float *data = in_mat_.data;
- const int channels = in_mat_.c;
- const int batch = in_mat_.n;
- const int pad_w = in_param_.pad_w;
- const int pad_h = in_param_.pad_h;
-
- // TODO: row_major
- if (conv_type_ == col_major)
- {
- for (int n = 0; n < batch; n++)
- {
- for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h)
- {
- for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w)
- {
- for (int y = 0; y < tile_h_in_; ++y)
- {
- for (int x = 0; x < tile_w_in_; ++x)
- {
- for (int c = 0; c < channels; ++c)
- {
- int in_y = tile_h * tile_h_out_ + y - pad_h;
- int in_x = tile_w * tile_w_out_ + x - pad_w;
-
- if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w)
- {
- col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) *
- tile_h_in_ +
- y) *
- tile_w_in_ +
- x] = 0;
- }
- else
- {
- col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) *
- tile_h_in_ +
- y) *
- tile_w_in_ +
- x] = data[((n * h + in_y) * w + in_x) * channels + c];
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
-void conv_winograd_batch::winograd_output_col2im(const float *col_buff)
-{
- int outh = out_mat_.h;
- int outw = out_mat_.w;
- float *data = out_mat_.data;
- int channels = out_mat_.c;
- int batch = out_mat_.n;
-
- // TODO: row_major
- if (conv_type_ == col_major)
- {
- for (int n = 0; n < batch; n++)
- {
- for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h)
- {
- for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w)
- {
- for (int y = 0; y < tile_h_out_; ++y)
- {
- for (int x = 0; x < tile_w_out_; ++x)
- {
- for (int c = 0; c < channels; ++c)
- {
- int out_y = tile_h * tile_h_out_ + y;
- int out_x = tile_w * tile_w_out_ + x;
- if (out_y < outh && out_x < outw)
- {
- data[((n * outh + out_y) * outw + out_x) * channels + c] =
- col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) *
- tile_h_out_ +
- y) *
- tile_w_out_ +
- x];
- }
- }
- }
- }
- }
- }
- }
- }
-}
-
-void conv_winograd_batch::compute_winograd()
-{
- const int w = in_mat_.w;
- // const int h = in_mat_.h;
- const int inch = in_mat_.c;
- // const int outw = out_mat_.w;
- // const int outh = out_mat_.h;
- const int outch = out_mat_.c;
- const int kernel_size = in_param_.kernel_w;
- const int batch = in_mat_.n;
-
- int M, N;
- const double *A;
- const double *B;
-
- if (kernel_size == 3)
- {
- if (w == 4)
- {
- M = winograd_para_3x3s1_2::M;
- N = winograd_para_3x3s1_2::N;
- B = winograd_para_3x3s1_2::getB();
- A = winograd_para_3x3s1_2::getA();
- }
- else
- {
- M = winograd_para_3x3s1::M;
- N = winograd_para_3x3s1::N;
- B = winograd_para_3x3s1::getB();
- A = winograd_para_3x3s1::getA();
- }
- }
- else
- {
- M = winograd_para_5x5s1::M;
- N = winograd_para_5x5s1::N;
- B = winograd_para_5x5s1::getB();
- A = winograd_para_5x5s1::getA();
- }
-
- /*Step 2: transfer image to winograd domain*/
- float *col_buff =
- new float[std::max(outch, inch) * batch * ntiles_h_ * ntiles_w_ * tile_h_in_ * tile_w_in_];
-
- int temp1_n = batch * inch * ntiles_h_ * ntiles_w_;
- float *temp1_ =
- new float[batch * tile_h_in_ * tile_w_in_ * std::max(outch, inch) * ntiles_h_ * ntiles_w_];
-
- float *winograd_b = new float[M * M * M * M];
-
- if ((NULL == col_buff) || (NULL == temp1_) || (NULL == winograd_b))
- {
- delete[] col_buff;
- delete[] temp1_;
- delete[] winograd_b;
- return;
- }
-
- winograd_input_im2col(col_buff);
-
- kronecker_product(winograd_b, B, B, M, M, M, M);
-
- compute_sgemm(rowMajor, trans, trans, tile_h_in_ * tile_w_in_, temp1_n, tile_h_in_ * tile_w_in_,
- winograd_b, col_buff, temp1_);
- delete[] winograd_b;
-
- /*Step 3: convolution in winograd domain*/
- for (int j = 0; j < tile_h_in_ * tile_w_in_; ++j)
- {
- compute_sgemm(rowMajor, notrans, notrans, outch, batch * ntiles_h_ * ntiles_w_, inch,
- winograd_weight_ + j * outch * inch,
- temp1_ + j * batch * inch * ntiles_h_ * ntiles_w_,
- col_buff + j * batch * outch * ntiles_h_ * ntiles_w_);
- }
-
- /*Step 4: transfer back to time domain*/
- float *winograd_a = new float[M * (M - N + 1) * M * (M - N + 1)];
- if (NULL == winograd_a)
- {
- delete[] col_buff;
- delete[] temp1_;
- return;
- }
-
- kronecker_product(winograd_a, A, A, M, M - N + 1, M, M - N + 1);
- compute_sgemm(rowMajor, trans, notrans, batch * outch * ntiles_h_ * ntiles_w_,
- tile_h_out_ * tile_w_out_, tile_h_in_ * tile_w_in_, col_buff, winograd_a, temp1_);
- delete[] winograd_a;
- delete[] col_buff;
-
- winograd_output_col2im(temp1_);
-
- delete[] temp1_;
-}
-
-void conv_winograd_batch::run()
-{
- if (error_)
- return;
-
- compute_winograd();
-}
-
-} // namespace srcn
-} // namespace nnfw