author    Evan Shelhamer <shelhamer@imaginarynumber.net>  2015-09-19 13:50:57 -0700
committer Evan Shelhamer <shelhamer@imaginarynumber.net>  2015-09-19 14:42:52 -0700
commit    328df2450c534119f239ce1d606f8502922c6825 (patch)
tree      82918faae59cbc3eb6a44e37e74fd07c24c1e0d9
parent    68655b55925cfa5cd6543cdfb879bf6e68bd2a3c (diff)
clarify im2col + col2im var names
- clarify indices by naming *_im for indices in the image and *_col for indices in the column buffer
- mark corresponding im2col + col2im quantities by renaming patch_* -> kernel_*
- fix out-of-date names in the equivalent col2im loop
-rw-r--r--  include/caffe/util/im2col.hpp |  4
-rw-r--r--  src/caffe/util/im2col.cpp     | 72
-rw-r--r--  src/caffe/util/im2col.cu      | 69
3 files changed, 73 insertions(+), 72 deletions(-)
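The renaming makes the index convention explicit: *_col variables index the column buffer, *_im variables index the source image, and the two are related by h_im = h_col * stride_h - pad_h + h_offset (and likewise for the width axis). Below is a minimal standalone sketch of that mapping in plain C++, with no Caffe dependencies; the function name im2col_sketch and the sizes in main are illustrative only, not part of the commit.

// Standalone sketch of the im2col index mapping with the renamed variables.
// *_im indexes the image, *_col indexes the column buffer.
#include <cstdio>
#include <vector>

void im2col_sketch(const float* data_im, int channels, int height, int width,
                   int kernel_h, int kernel_w, int pad_h, int pad_w,
                   int stride_h, int stride_w, float* data_col) {
  int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
  int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
  int channels_col = channels * kernel_h * kernel_w;
  for (int c_col = 0; c_col < channels_col; ++c_col) {
    int w_offset = c_col % kernel_w;             // position inside the kernel
    int h_offset = (c_col / kernel_w) % kernel_h;
    int c_im = c_col / kernel_h / kernel_w;      // source image channel
    for (int h_col = 0; h_col < height_col; ++h_col) {
      for (int w_col = 0; w_col < width_col; ++w_col) {
        int h_im = h_col * stride_h - pad_h + h_offset;  // image row
        int w_im = w_col * stride_w - pad_w + w_offset;  // image column
        data_col[(c_col * height_col + h_col) * width_col + w_col] =
            (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
                ? data_im[(c_im * height + h_im) * width + w_im]
                : 0.f;  // position falls in the padding
      }
    }
  }
}

int main() {
  // Illustrative sizes only: 1 channel, 4x4 image, 3x3 kernel, pad 1, stride 1.
  int channels = 1, height = 4, width = 4;
  int kernel_h = 3, kernel_w = 3, pad_h = 1, pad_w = 1, stride_h = 1, stride_w = 1;
  int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
  int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
  std::vector<float> im(channels * height * width);
  for (size_t i = 0; i < im.size(); ++i) im[i] = static_cast<float>(i);
  std::vector<float> col(channels * kernel_h * kernel_w * height_col * width_col);
  im2col_sketch(im.data(), channels, height, width, kernel_h, kernel_w,
                pad_h, pad_w, stride_h, stride_w, col.data());
  std::printf("col buffer is %d x (%d x %d)\n",
              channels * kernel_h * kernel_w, height_col, width_col);
  return 0;
}

Positions whose h_im or w_im lies outside the image are written as zero, exactly as in the patched im2col_cpu loop in the diff below.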
diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp
index 531fd29c..d3eb6ccd 100644
--- a/include/caffe/util/im2col.hpp
+++ b/include/caffe/util/im2col.hpp
@@ -23,7 +23,7 @@ void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, Dtype* data_im);
@@ -47,7 +47,7 @@ void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, Dtype* data_im);
diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp
index b0a7be50..afeb5e5d 100644
--- a/src/caffe/util/im2col.cpp
+++ b/src/caffe/util/im2col.cpp
@@ -17,19 +17,19 @@ void im2col_cpu(const Dtype* data_im, const int channels,
int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
int channels_col = channels * kernel_h * kernel_w;
- for (int c = 0; c < channels_col; ++c) {
- int w_offset = c % kernel_w;
- int h_offset = (c / kernel_w) % kernel_h;
- int c_im = c / kernel_h / kernel_w;
- for (int h = 0; h < height_col; ++h) {
- for (int w = 0; w < width_col; ++w) {
- int h_pad = h * stride_h - pad_h + h_offset;
- int w_pad = w * stride_w - pad_w + w_offset;
- if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
- data_col[(c * height_col + h) * width_col + w] =
- data_im[(c_im * height + h_pad) * width + w_pad];
+ for (int c_col = 0; c_col < channels_col; ++c_col) {
+ int w_offset = c_col % kernel_w;
+ int h_offset = (c_col / kernel_w) % kernel_h;
+ int c_im = c_col / kernel_h / kernel_w;
+ for (int h_col = 0; h_col < height_col; ++h_col) {
+ for (int w_col = 0; w_col < width_col; ++w_col) {
+ int h_im = h_col * stride_h - pad_h + h_offset;
+ int w_im = w_col * stride_w - pad_w + w_offset;
+ if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
+ data_col[(c_col * height_col + h_col) * width_col + w_col] =
+ data_im[(c_im * height + h_im) * width + w_im];
else
- data_col[(c * height_col + h) * width_col + w] = 0;
+ data_col[(c_col * height_col + h_col) * width_col + w_col] = 0;
}
}
}
@@ -64,9 +64,9 @@ inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
const int channels_col = col_shape[0];
vector<int> d_offset(num_spatial_axes, 0);
vector<int> d_iter(num_spatial_axes, 0);
- for (int c = 0; c < channels_col; ++c) {
+ for (int c_col = 0; c_col < channels_col; ++c_col) {
// Loop over spatial axes in reverse order to compute a per-axis offset.
- int offset = c;
+ int offset = c_col;
for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) {
if (d_i < num_spatial_axes - 1) {
offset /= kernel_shape[d_i + 1];
@@ -76,17 +76,17 @@ inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
for (bool incremented = true; incremented; ) {
// Loop over spatial axes in forward order to compute the indices in the
// image and column, and whether the index lies in the padding.
- int index_col = c;
- int index_im = c / kernel_size;
+ int index_col = c_col;
+ int index_im = c_col / kernel_size;
bool is_padding = false;
for (int d_i = 0; d_i < num_spatial_axes; ++d_i) {
const int d = d_iter[d_i];
- const int d_pad = d * stride[d_i] - pad[d_i] + d_offset[d_i];
- is_padding |= d_pad < 0 || d_pad >= im_shape[d_i + 1];
+ const int d_im = d * stride[d_i] - pad[d_i] + d_offset[d_i];
+ is_padding |= d_im < 0 || d_im >= im_shape[d_i + 1];
index_col *= col_shape[d_i + 1];
index_col += d;
index_im *= im_shape[d_i + 1];
- index_im += d_pad;
+ index_im += d_im;
}
if (im2col) {
if (is_padding) {
@@ -139,25 +139,25 @@ template void im2col_nd_cpu<double>(const double* data_im,
template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
Dtype* data_im) {
caffe_set(height * width * channels, Dtype(0), data_im);
- int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
- int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
- int channels_col = channels * patch_h * patch_w;
- for (int c = 0; c < channels_col; ++c) {
- int w_offset = c % patch_w;
- int h_offset = (c / patch_w) % patch_h;
- int c_im = c / patch_h / patch_w;
- for (int h = 0; h < height_col; ++h) {
- for (int w = 0; w < width_col; ++w) {
- int h_pad = h * stride_h - pad_h + h_offset;
- int w_pad = w * stride_w - pad_w + w_offset;
- if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
- data_im[(c_im * height + h_pad) * width + w_pad] +=
- data_col[(c * height_col + h) * width_col + w];
+ int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+ int channels_col = channels * kernel_h * kernel_w;
+ for (int c_col = 0; c_col < channels_col; ++c_col) {
+ int w_offset = c_col % kernel_w;
+ int h_offset = (c_col / kernel_w) % kernel_h;
+ int c_im = c_col / kernel_h / kernel_w;
+ for (int h_col = 0; h_col < height_col; ++h_col) {
+ for (int w_col = 0; w_col < width_col; ++w_col) {
+ int h_im = h_col * stride_h - pad_h + h_offset;
+ int w_im = w_col * stride_w - pad_w + w_offset;
+ if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
+ data_im[(c_im * height + h_im) * width + w_im] +=
+ data_col[(c_col * height_col + h_col) * width_col + w_col];
}
}
}
@@ -165,11 +165,11 @@ void col2im_cpu(const Dtype* data_col, const int channels,
// Explicit instantiation
template void col2im_cpu<float>(const float* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, float* data_im);
template void col2im_cpu<double>(const double* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, double* data_im);
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu
index 5a478ba6..897e3c92 100644
--- a/src/caffe/util/im2col.cu
+++ b/src/caffe/util/im2col.cu
@@ -16,22 +16,23 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
const int height_col, const int width_col,
Dtype* data_col) {
CUDA_KERNEL_LOOP(index, n) {
- int w_out = index % width_col;
int h_index = index / width_col;
- int h_out = h_index % height_col;
- int channel_in = h_index / height_col;
- int channel_out = channel_in * kernel_h * kernel_w;
- int h_in = h_out * stride_h - pad_h;
- int w_in = w_out * stride_w - pad_w;
+ int h_col = h_index % height_col;
+ int w_col = index % width_col;
+ int c_im = h_index / height_col;
+ int c_col = c_im * kernel_h * kernel_w;
+ int h_offset = h_col * stride_h - pad_h;
+ int w_offset = w_col * stride_w - pad_w;
Dtype* data_col_ptr = data_col;
- data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out;
+ data_col_ptr += (c_col * height_col + h_col) * width_col + w_col;
const Dtype* data_im_ptr = data_im;
- data_im_ptr += (channel_in * height + h_in) * width + w_in;
+ data_im_ptr += (c_im * height + h_offset) * width + w_offset;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
- int h = h_in + i;
- int w = w_in + j;
- *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
+ int h_im = h_offset + i;
+ int w_im = w_offset + j;
+ *data_col_ptr =
+ (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
data_im_ptr[i * width + j] : 0;
data_col_ptr += height_col * width_col;
}
@@ -222,35 +223,35 @@ template void im2col_nd_gpu<double>(const double* data_im,
template <typename Dtype>
__global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
const int height, const int width, const int channels,
- const int patch_h, const int patch_w,
+ const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int height_col, const int width_col,
Dtype* data_im) {
CUDA_KERNEL_LOOP(index, n) {
Dtype val = 0;
- int w = index % width + pad_w;
- int h = (index / width) % height + pad_h;
- int c = index / (width * height);
+ int w_im = index % width + pad_w;
+ int h_im = (index / width) % height + pad_h;
+ int c_im = index / (width * height);
// compute the start and end of the output
- int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1;
- int w_col_end = min(w / stride_w + 1, width_col);
- int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1;
- int h_col_end = min(h / stride_h + 1, height_col);
+ int w_col_start = (w_im < kernel_w) ? 0 : (w_im - kernel_w) / stride_w + 1;
+ int w_col_end = min(w_im / stride_w + 1, width_col);
+ int h_col_start = (h_im < kernel_h) ? 0 : (h_im - kernel_h) / stride_h + 1;
+ int h_col_end = min(h_im / stride_h + 1, height_col);
/*
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
- int c_col = c * patch_h * patch_w + (h - h_col * stride_h) * ksize
- + (w - w_col * stride_w);
+ int c_col = c_im * kernel_h * kernel_w
+ + (h_im - h_col * stride_h) * kernel_w + (w_im - w_col * stride_w);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
}
}
*/
// equivalent implementation
- int offset =
- (c * patch_h * patch_w + h * patch_w + w) * height_col * width_col;
- int coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col;
+ int offset = (c_im * kernel_h * kernel_w + h_im * kernel_w + w_im)
+ * height_col * width_col;
+ int coeff_h_col = (1 - stride_h * kernel_w * height_col) * width_col;
int coeff_w_col = (1 - stride_w * height_col * width_col);
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
@@ -263,18 +264,18 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, Dtype* data_im) {
- int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
- int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
+ int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
int num_kernels = channels * height * width;
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
// NOLINT_NEXT_LINE(whitespace/operators)
col2im_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels),
CAFFE_CUDA_NUM_THREADS>>>(
- num_kernels, data_col, height, width, channels, patch_h, patch_w,
+ num_kernels, data_col, height, width, channels, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
height_col, width_col, data_im);
CUDA_POST_KERNEL_CHECK;
@@ -282,11 +283,11 @@ void col2im_gpu(const Dtype* data_col, const int channels,
// Explicit instantiation
template void col2im_gpu<float>(const float* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, float* data_im);
template void col2im_gpu<double>(const double* data_col, const int channels,
- const int height, const int width, const int patch_h, const int patch_w,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, double* data_im);
@@ -302,11 +303,11 @@ __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
CUDA_KERNEL_LOOP(index, n) {
// Initialize channel_in, computed in the loop below, with intermediate
// computations used to compute the spatial indices.
- int channel_im = index;
+ int c_im = index;
// Calculate d_im (image dimensions).
for (int i = num_axes - 1; i >= 0; --i) {
- d_im[i] = channel_im % im_shape[i + 1] + pad[i];
- channel_im /= im_shape[i + 1];
+ d_im[i] = c_im % im_shape[i + 1] + pad[i];
+ c_im /= im_shape[i + 1];
}
// Calculate col start/end indices.
bool done = false;
@@ -338,7 +339,7 @@ __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
(d_im[i] - d_col_iter[i] * stride[i]) * kernel_shape_prod;
kernel_shape_prod *= kernel_shape[i];
}
- final_offset += kernel_shape_prod * channel_im;
+ final_offset += kernel_shape_prod * c_im;
for (int i = 0; i < num_axes; ++i) {
final_offset *= col_shape[i + 1];
final_offset += d_col_iter[i];
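With the renamed variables, the "equivalent implementation" in col2im_gpu_kernel is also easier to check against the commented-out reference loop: for every (h_col, w_col), the closed-form offset + h_col * coeff_h_col + w_col * coeff_w_col equals the directly computed (c_col * height_col + h_col) * width_col + w_col. A small standalone check in plain C++ follows; all sizes and the (c_im, h_im, w_im) coordinate below are arbitrary test values, not taken from the commit.

// Check that the offset/coefficient indexing used in col2im_gpu_kernel
// matches the direct index computation from the commented-out reference loop.
// The identity holds for every (h_col, w_col), not only the positions where
// the kernel actually overlaps the image.
#include <cassert>
#include <cstdio>

int main() {
  const int kernel_h = 3, kernel_w = 5, stride_h = 2, stride_w = 3;
  const int height_col = 7, width_col = 11;
  const int c_im = 2, h_im = 9, w_im = 12;  // padded image coordinates
  // Closed form, as in the patched kernel.
  const int offset = (c_im * kernel_h * kernel_w + h_im * kernel_w + w_im)
      * height_col * width_col;
  const int coeff_h_col = (1 - stride_h * kernel_w * height_col) * width_col;
  const int coeff_w_col = (1 - stride_w * height_col * width_col);
  for (int h_col = 0; h_col < height_col; ++h_col) {
    for (int w_col = 0; w_col < width_col; ++w_col) {
      // Direct computation, as in the commented-out reference loop.
      const int c_col = c_im * kernel_h * kernel_w
          + (h_im - h_col * stride_h) * kernel_w + (w_im - w_col * stride_w);
      const int direct = (c_col * height_col + h_col) * width_col + w_col;
      const int closed = offset + h_col * coeff_h_col + w_col * coeff_w_col;
      assert(direct == closed);
    }
  }
  std::printf("direct and closed-form col2im indices agree\n");
  return 0;
}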