diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-06-30 12:06:27 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-07-29 01:20:39 -0700 |
commit | 4d44fe7b5ecf94411db34484e7e86e07f6e735ab (patch) | |
tree | 2947272c0645e4bde918333f3cc81ad3e8756044 /src/caffe/util/im2col.cu | |
parent | edf438a4a6609f2e32eeeca4ce353ff9c9c4a905 (diff) | |
download | caffeonacl-4d44fe7b5ecf94411db34484e7e86e07f6e735ab.tar.gz caffeonacl-4d44fe7b5ecf94411db34484e7e86e07f6e735ab.tar.bz2 caffeonacl-4d44fe7b5ecf94411db34484e7e86e07f6e735ab.zip |
im2col + convolve non-square filters, padding, and stride
Diffstat (limited to 'src/caffe/util/im2col.cu')
-rw-r--r-- | src/caffe/util/im2col.cu | 91 |
1 files changed, 52 insertions, 39 deletions
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index 79faa6cb..b565d2d3 100644 --- a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -12,23 +12,25 @@ namespace caffe { template <typename Dtype> __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, - const int height, const int width, const int ksize, const int pad, - const int stride, const int height_col, const int width_col, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int height_col, const int width_col, Dtype* data_col) { CUDA_KERNEL_LOOP(index, n) { int w_out = index % width_col; int h_index = index / width_col; int h_out = h_index % height_col; int channel_in = h_index / height_col; - int channel_out = channel_in * ksize * ksize; - int h_in = h_out * stride - pad; - int w_in = w_out * stride - pad; + int channel_out = channel_in * kernel_h * kernel_w; + int h_in = h_out * stride_h - pad_h; + int w_in = w_out * stride_w - pad_w; Dtype* data_col_ptr = data_col; data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out; const Dtype* data_im_ptr = data_im; data_im_ptr += (channel_in * height + h_in) * width + w_in; - for (int i = 0; i < ksize; ++i) { - for (int j = 0; j < ksize; ++j) { + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { int h = h_in + i; int w = w_in + j; *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ? @@ -41,17 +43,20 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, template <typename Dtype> void im2col_gpu(const Dtype* data_im, const int channels, - const int height, const int width, const int ksize, const int pad, - const int stride, Dtype* data_col) { + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + Dtype* data_col) { // We are going to launch channels * height_col * width_col kernels, each // kernel responsible for copying a single-channel grid. - int height_col = (height + 2 * pad - ksize) / stride + 1; - int width_col = (width + 2 * pad - ksize) / stride + 1; + int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; + int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; int num_kernels = channels * height_col * width_col; // NOLINT_NEXT_LINE(whitespace/operators) im2col_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>( - num_kernels, data_im, height, width, ksize, pad, stride, height_col, + num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h, + pad_w, stride_h, stride_w, height_col, width_col, data_col); CUDA_POST_KERNEL_CHECK; } @@ -59,40 +64,46 @@ void im2col_gpu(const Dtype* data_im, const int channels, // Explicit instantiation template void im2col_gpu<float>(const float* data_im, const int channels, - const int height, const int width, const int ksize, const int pad, - const int stride, float* data_col); + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + float* data_col); template void im2col_gpu<double>(const double* data_im, const int channels, - const int height, const int width, const int ksize, const int pad, - const int stride, double* data_col); + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + double* data_col); template <typename Dtype> __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, - const int height, const int width, const int channels, const int ksize, - const int pad, const int stride, const int height_col, const int width_col, + const int height, const int width, const int channels, + const int patch_h, const int patch_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int height_col, const int width_col, Dtype* data_im) { CUDA_KERNEL_LOOP(index, n) { Dtype val = 0; - int w = index % width + pad; - int h = (index / width) % height + pad; + int w = index % width + pad_w; + int h = (index / width) % height + pad_h; int c = index / (width * height); // compute the start and end of the output - int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1; - int w_col_end = min(w / stride + 1, width_col); - int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1; - int h_col_end = min(h / stride + 1, height_col); + int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1; + int w_col_end = min(w / stride_w + 1, width_col); + int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1; + int h_col_end = min(h / stride_h + 1, height_col); /* for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { // the col location: [c * width * height + h_out, w_out] - int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride); + int c_col = c * patch_h * patch_w + (h - h_col * stride_h) * ksize + + (w - w_col * stride_w); val += data_col[(c_col * height_col + h_col) * width_col + w_col]; } } */ // equivalent implementation - int offset = (c * ksize * ksize + h * ksize + w) * height_col * width_col; - int coeff_h_col = (1 - stride * ksize * height_col) * width_col; - int coeff_w_col = (1 - stride * height_col * width_col); + int offset = (c * patch_h * patch_w + h * patch_h + w) * height_col * width_col; + int coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col; + int coeff_w_col = (1 - stride_w * height_col * width_col); for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col]; @@ -104,29 +115,31 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, template <typename Dtype> void col2im_gpu(const Dtype* data_col, const int channels, - const int height, const int width, const int ksize, const int pad, - const int stride, Dtype* data_im) { - int height_col = (height + 2 * pad - ksize) / stride + 1; - int width_col = (width + 2 * pad - ksize) / stride + 1; + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, Dtype* data_im) { + int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1; + int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1; int num_kernels = channels * height * width; // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. // NOLINT_NEXT_LINE(whitespace/operators) col2im_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>( - num_kernels, data_col, height, width, channels, ksize, pad, stride, + num_kernels, data_col, height, width, channels, patch_h, patch_w, + pad_h, pad_w, stride_h, stride_w, height_col, width_col, data_im); CUDA_POST_KERNEL_CHECK; } - // Explicit instantiation template void col2im_gpu<float>(const float* data_col, const int channels, - const int height, const int width, const int psize, const int pad, - const int stride, float* data_im); + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, float* data_im); template void col2im_gpu<double>(const double* data_col, const int channels, - const int height, const int width, const int psize, const int pad, - const int stride, double* data_im); - + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, double* data_im); } // namespace caffe |