diff options
Diffstat (limited to 'src/caffe/util/im2col.cpp')
-rw-r--r-- | src/caffe/util/im2col.cpp | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index c48f31f3..b0a7be50 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -1,6 +1,7 @@ #include <cmath> #include <cstdlib> #include <cstring> +#include <vector> #include "caffe/util/im2col.hpp" #include "caffe/util/math_functions.hpp" @@ -45,6 +46,98 @@ template void im2col_cpu<double>(const double* data_im, const int channels, const int stride_w, double* data_col); template <typename Dtype> +inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col, + const int num_spatial_axes, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_output) { + if (!im2col) { + int im_size = im_shape[0]; + for (int i = 0; i < num_spatial_axes; ++i) { + im_size *= im_shape[1 + i]; + } + caffe_set(im_size, Dtype(0), data_output); + } + int kernel_size = 1; + for (int i = 0; i < num_spatial_axes; ++i) { + kernel_size *= kernel_shape[i]; + } + const int channels_col = col_shape[0]; + vector<int> d_offset(num_spatial_axes, 0); + vector<int> d_iter(num_spatial_axes, 0); + for (int c = 0; c < channels_col; ++c) { + // Loop over spatial axes in reverse order to compute a per-axis offset. + int offset = c; + for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) { + if (d_i < num_spatial_axes - 1) { + offset /= kernel_shape[d_i + 1]; + } + d_offset[d_i] = offset % kernel_shape[d_i]; + } + for (bool incremented = true; incremented; ) { + // Loop over spatial axes in forward order to compute the indices in the + // image and column, and whether the index lies in the padding. + int index_col = c; + int index_im = c / kernel_size; + bool is_padding = false; + for (int d_i = 0; d_i < num_spatial_axes; ++d_i) { + const int d = d_iter[d_i]; + const int d_pad = d * stride[d_i] - pad[d_i] + d_offset[d_i]; + is_padding |= d_pad < 0 || d_pad >= im_shape[d_i + 1]; + index_col *= col_shape[d_i + 1]; + index_col += d; + index_im *= im_shape[d_i + 1]; + index_im += d_pad; + } + if (im2col) { + if (is_padding) { + data_output[index_col] = 0; + } else { + data_output[index_col] = data_input[index_im]; + } + } else if (!is_padding) { // col2im + data_output[index_im] += data_input[index_col]; + } + // Loop over spatial axes in reverse order to choose an index, + // like counting. + incremented = false; + for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) { + const int d_max = col_shape[d_i + 1]; + DCHECK_LT(d_iter[d_i], d_max); + if (d_iter[d_i] == d_max - 1) { + d_iter[d_i] = 0; + } else { // d_iter[d_i] < d_max - 1 + ++d_iter[d_i]; + incremented = true; + break; + } + } + } // while(incremented) { + } // for (int c = 0; c < channels_col; ++c) { +} + +template <typename Dtype> +void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col) { + const bool kIm2Col = true; + im2col_nd_core_cpu(data_im, kIm2Col, num_spatial_axes, im_shape, col_shape, + kernel_shape, pad, stride, data_col); +} + +// Explicit instantiation +template void im2col_nd_cpu<float>(const float* data_im, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_col); +template void im2col_nd_cpu<double>(const double* data_im, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_col); + +template <typename Dtype> void col2im_cpu(const Dtype* data_col, const int channels, const int height, const int width, const int patch_h, const int patch_w, const int pad_h, const int pad_w, @@ -80,4 +173,27 @@ template void col2im_cpu<double>(const double* data_col, const int channels, const int pad_h, const int pad_w, const int stride_h, const int stride_w, double* data_im); +template <typename Dtype> +void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im) { + const bool kIm2Col = false; + im2col_nd_core_cpu(data_col, kIm2Col, num_spatial_axes, im_shape, col_shape, + kernel_shape, pad, stride, data_im); +} + +// Explicit instantiation +template void col2im_nd_cpu<float>(const float* data_col, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_im); +template void col2im_nd_cpu<double>(const double* data_col, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_im); + + } // namespace caffe |