Diffstat (limited to 'src/caffe')
-rw-r--r-- | src/caffe/layer_factory.cpp               | 17
-rw-r--r-- | src/caffe/layers/base_conv_layer.cpp      | 20
-rw-r--r-- | src/caffe/layers/conv_layer.cpp           |  4
-rw-r--r-- | src/caffe/layers/im2col_layer.cpp         | 21
-rw-r--r-- | src/caffe/layers/im2col_layer.cu          |  2
-rw-r--r-- | src/caffe/proto/caffe.proto               |  1
-rw-r--r-- | src/caffe/test/test_convolution_layer.cpp | 14
-rw-r--r-- | src/caffe/test/test_im2col_kernel.cu      | 17
-rw-r--r-- | src/caffe/test/test_im2col_layer.cpp      |  3
-rw-r--r-- | src/caffe/util/im2col.cpp                 | 34
-rw-r--r-- | src/caffe/util/im2col.cu                  | 80
11 files changed, 147 insertions, 66 deletions
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 76d851af..6b1d1c1a 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -37,17 +37,30 @@ namespace caffe {
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetConvolutionLayer(
     const LayerParameter& param) {
-  ConvolutionParameter_Engine engine = param.convolution_param().engine();
+  ConvolutionParameter conv_param = param.convolution_param();
+  ConvolutionParameter_Engine engine = conv_param.engine();
+  bool use_dilation = false;
+  for (int i = 0; i < conv_param.dilation_size(); ++i) {
+    if (conv_param.dilation(i) > 1) {
+      use_dilation = true;
+    }
+  }
   if (engine == ConvolutionParameter_Engine_DEFAULT) {
     engine = ConvolutionParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
-    engine = ConvolutionParameter_Engine_CUDNN;
+    if (!use_dilation) {
+      engine = ConvolutionParameter_Engine_CUDNN;
+    }
 #endif
   }
   if (engine == ConvolutionParameter_Engine_CAFFE) {
     return shared_ptr<Layer<Dtype> >(new ConvolutionLayer<Dtype>(param));
 #ifdef USE_CUDNN
   } else if (engine == ConvolutionParameter_Engine_CUDNN) {
+    if (use_dilation) {
+      LOG(FATAL) << "CuDNN doesn't support the dilated convolution at Layer "
+                 << param.name();
+    }
     return shared_ptr<Layer<Dtype> >(new CuDNNConvolutionLayer<Dtype>(param));
 #endif
   } else {
diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp
index f6f14cd0..4a4c68e0 100644
--- a/src/caffe/layers/base_conv_layer.cpp
+++ b/src/caffe/layers/base_conv_layer.cpp
@@ -36,7 +36,7 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_)
       << "kernel_size must be specified once, or once per spatial dimension "
       << "(kernel_size specified " << num_kernel_dims << " times; "
-      << num_spatial_axes_ << " spatial dims);";
+      << num_spatial_axes_ << " spatial dims).";
   for (int i = 0; i < num_spatial_axes_; ++i) {
     kernel_shape_data[i] =
         conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i);
@@ -61,7 +61,7 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       num_stride_dims == num_spatial_axes_)
       << "stride must be specified once, or once per spatial dimension "
       << "(stride specified " << num_stride_dims << " times; "
-      << num_spatial_axes_ << " spatial dims);";
+      << num_spatial_axes_ << " spatial dims).";
   const int kDefaultStride = 1;
   for (int i = 0; i < num_spatial_axes_; ++i) {
     stride_data[i] = (num_stride_dims == 0) ? kDefaultStride :
@@ -85,13 +85,27 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
         num_pad_dims == num_spatial_axes_)
         << "pad must be specified once, or once per spatial dimension "
         << "(pad specified " << num_pad_dims << " times; "
-        << num_spatial_axes_ << " spatial dims);";
+        << num_spatial_axes_ << " spatial dims).";
     const int kDefaultPad = 0;
     for (int i = 0; i < num_spatial_axes_; ++i) {
       pad_data[i] = (num_pad_dims == 0) ? kDefaultPad :
           conv_param.pad((num_pad_dims == 1) ? 0 : i);
     }
   }
+  // Setup dilation dimensions (dilation_).
+  dilation_.Reshape(spatial_dim_blob_shape);
+  int* dilation_data = dilation_.mutable_cpu_data();
+  const int num_dilation_dims = conv_param.dilation_size();
+  CHECK(num_dilation_dims == 0 || num_dilation_dims == 1 ||
+        num_dilation_dims == num_spatial_axes_)
+      << "dilation must be specified once, or once per spatial dimension "
+      << "(dilation specified " << num_dilation_dims << " times; "
+      << num_spatial_axes_ << " spatial dims).";
+  const int kDefaultDilation = 1;
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    dilation_data[i] = (num_dilation_dims == 0) ? kDefaultDilation :
+        conv_param.dilation((num_dilation_dims == 1) ? 0 : i);
+  }
   // Special case: im2col is the identity for 1x1 convolution with stride 1
   // and no padding, so flag for skipping the buffer and transformation.
   is_1x1_ = true;
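Side note on the LayerSetUp code above: kernel_size, stride, pad, and the new dilation all follow the same broadcast rule for repeated proto fields — zero entries fall back to the per-parameter default, a single entry applies to every spatial axis, and otherwise exactly one entry per axis is required. A minimal standalone sketch of that rule (BroadcastParam is a hypothetical helper name, not part of this patch):

#include <cassert>
#include <vector>

// Mirrors the repeated-field rule used for pad, stride, and dilation in
// LayerSetUp: 0 values -> default, 1 value -> broadcast to all spatial
// axes, N values -> one per axis (anything else fails the CHECK).
std::vector<int> BroadcastParam(const std::vector<int>& values,
                                int num_spatial_axes, int default_value) {
  assert(values.empty() || values.size() == 1 ||
         static_cast<int>(values.size()) == num_spatial_axes);
  std::vector<int> out(num_spatial_axes);
  for (int i = 0; i < num_spatial_axes; ++i) {
    out[i] = values.empty() ? default_value
                            : values[(values.size() == 1) ? 0 : i];
  }
  return out;
}

// BroadcastParam({}, 2, 1)     -> {1, 1}  (kDefaultDilation everywhere)
// BroadcastParam({2}, 2, 1)    -> {2, 2}  (one dilation value, broadcast)
// BroadcastParam({2, 3}, 2, 1) -> {2, 3}  (per-axis dilation)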
diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp
index cff09783..5d522ab3 100644
--- a/src/caffe/layers/conv_layer.cpp
+++ b/src/caffe/layers/conv_layer.cpp
@@ -9,11 +9,13 @@ void ConvolutionLayer<Dtype>::compute_output_shape() {
   const int* kernel_shape_data = this->kernel_shape_.cpu_data();
   const int* stride_data = this->stride_.cpu_data();
   const int* pad_data = this->pad_.cpu_data();
+  const int* dilation_data = this->dilation_.cpu_data();
   this->output_shape_.clear();
   for (int i = 0; i < this->num_spatial_axes_; ++i) {
     // i + 1 to skip channel axis
     const int input_dim = this->input_shape(i + 1);
-    const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i])
+    const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1;
+    const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent)
         / stride_data[i] + 1;
     this->output_shape_.push_back(output_dim);
   }
diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp
index c12e4f52..19ae3019 100644
--- a/src/caffe/layers/im2col_layer.cpp
+++ b/src/caffe/layers/im2col_layer.cpp
@@ -87,6 +87,20 @@ void Im2colLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
           conv_param.pad((num_pad_dims == 1) ? 0 : i);
     }
   }
+  // Setup dilation dimensions (dilation_).
+  dilation_.Reshape(dim_blob_shape);
+  int* dilation_data = dilation_.mutable_cpu_data();
+  const int num_dilation_dims = conv_param.dilation_size();
+  CHECK(num_dilation_dims == 0 || num_dilation_dims == 1 ||
+        num_dilation_dims == num_spatial_axes_)
+      << "dilation must be specified once, or once per spatial dimension "
+      << "(dilation specified " << num_dilation_dims << " times; "
+      << num_spatial_axes_ << " spatial dims).";
+  const int kDefaultDilation = 1;
+  for (int i = 0; i < num_spatial_axes_; ++i) {
+    dilation_data[i] = (num_dilation_dims == 0) ? kDefaultDilation :
+        conv_param.dilation((num_dilation_dims == 1) ? 0 : i);
+  }
 }
 
 template <typename Dtype>
@@ -96,10 +110,12 @@ void Im2colLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
   const int* kernel_shape_data = kernel_shape_.cpu_data();
   const int* stride_data = stride_.cpu_data();
   const int* pad_data = pad_.cpu_data();
+  const int* dilation_data = dilation_.cpu_data();
   for (int i = 0; i < num_spatial_axes_; ++i) {
     top_shape[channel_axis_] *= kernel_shape_data[i];
     const int input_dim = bottom[0]->shape(channel_axis_ + i + 1);
-    const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i])
+    const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1;
+    const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent)
         / stride_data[i] + 1;
     top_shape[channel_axis_ + i + 1] = output_dim;
   }
@@ -122,6 +138,7 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     DCHECK_EQ(kernel_shape_.count(), num_spatial_axes_);
     DCHECK_EQ(pad_.count(), num_spatial_axes_);
     DCHECK_EQ(stride_.count(), num_spatial_axes_);
+    DCHECK_EQ(dilation_.count(), num_spatial_axes_);
     if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
       im2col_cpu(bottom_data + n * bottom_dim_, channels_,
           bottom[0]->shape(channel_axis_ + 1),
@@ -129,6 +146,7 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
           kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
           pad_.cpu_data()[0], pad_.cpu_data()[1],
           stride_.cpu_data()[0], stride_.cpu_data()[1],
+          dilation_.cpu_data()[0], dilation_.cpu_data()[1],
           top_data + n * top_dim_);
     } else {
       im2col_nd_cpu(bottom_data + n * bottom_dim_, num_spatial_axes_,
@@ -153,6 +171,7 @@ void Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
           kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
           pad_.cpu_data()[0], pad_.cpu_data()[1],
           stride_.cpu_data()[0], stride_.cpu_data()[1],
+          dilation_.cpu_data()[0], dilation_.cpu_data()[1],
           bottom_diff + n * bottom_dim_);
     } else {
       col2im_nd_cpu(top_diff + n * top_dim_, num_spatial_axes_,
diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu
index 517b4220..d90075d4 100644
--- a/src/caffe/layers/im2col_layer.cu
+++ b/src/caffe/layers/im2col_layer.cu
@@ -19,6 +19,7 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
         kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
         pad_.cpu_data()[0], pad_.cpu_data()[1],
         stride_.cpu_data()[0], stride_.cpu_data()[1],
+        dilation_.cpu_data()[0], dilation_.cpu_data()[1],
         top_data + n * top_dim_);
   } else {
     im2col_nd_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_,
@@ -43,6 +44,7 @@ void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
         pad_.cpu_data()[0], pad_.cpu_data()[1],
         stride_.cpu_data()[0], stride_.cpu_data()[1],
+        dilation_.cpu_data()[0], dilation_.cpu_data()[1],
         bottom_diff + n * bottom_dim_);
   } else {
     col2im_nd_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_,
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 787369f7..87c46629 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -518,6 +518,7 @@ message ConvolutionParameter {
   repeated uint32 pad = 3; // The padding size; defaults to 0
   repeated uint32 kernel_size = 4; // The kernel size
   repeated uint32 stride = 6; // The stride; defaults to 1
+  repeated uint32 dilation = 18; // The dilation; defaults to 1
 
   // For 2D convolution only, the *_h and *_w versions may also be used to
   // specify both spatial dimensions.
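The dilation field above feeds the output-shape computations in conv_layer.cpp and im2col_layer.cpp: a kernel of size k with dilation d covers an extent of d * (k - 1) + 1 input pixels, and the ordinary output-size formula is applied to that extent. A small self-contained sketch with made-up numbers (conv_output_dim is a hypothetical name, not from the patch):

#include <cstdio>

// Output size along one spatial axis, as in compute_output_shape(): the
// dilated kernel covers dilation * (kernel - 1) + 1 input pixels.
int conv_output_dim(int input_dim, int kernel, int pad, int stride,
                    int dilation) {
  const int kernel_extent = dilation * (kernel - 1) + 1;
  return (input_dim + 2 * pad - kernel_extent) / stride + 1;
}

int main() {
  // A 3x3 kernel with dilation 2 covers a 5x5 extent, so on a 10-pixel axis
  // with no padding and stride 1 it fits in 10 - 5 + 1 = 6 positions.
  printf("%d\n", conv_output_dim(10, 3, 0, 1, 2));  // prints 6
  // With dilation 1 this reduces to the familiar 10 - 3 + 1 = 8.
  printf("%d\n", conv_output_dim(10, 3, 0, 1, 1));  // prints 8
  return 0;
}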
diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp
index e2d43f31..95c3c80c 100644
--- a/src/caffe/test/test_convolution_layer.cpp
+++ b/src/caffe/test/test_convolution_layer.cpp
@@ -46,13 +46,17 @@ void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
   } else {
     stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1;
   }
-  int kernel_d, pad_d, stride_d;
+  int dilation_h, dilation_w;
+  dilation_h = dilation_w = conv_param->dilation_size() ?
+      conv_param->dilation(0) : 1;
+  int kernel_d, pad_d, stride_d, dilation_d;
   if (has_depth) {
     kernel_d = kernel_h;
     stride_d = stride_h;
     pad_d = pad_h;
+    dilation_d = dilation_h;
   } else {
-    kernel_d = stride_d = 1;
+    kernel_d = stride_d = dilation_d = 1;
     pad_d = 0;
   }
   // Groups
@@ -77,9 +81,9 @@ void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
           for (int r = 0; r < kernel_d; r++) {
             for (int p = 0; p < kernel_h; p++) {
               for (int q = 0; q < kernel_w; q++) {
-                int in_z = z * stride_d - pad_d + r;
-                int in_y = y * stride_h - pad_h + p;
-                int in_x = x * stride_w - pad_w + q;
+                int in_z = z * stride_d - pad_d + r * dilation_d;
+                int in_y = y * stride_h - pad_h + p * dilation_h;
+                int in_x = x * stride_w - pad_w + q * dilation_w;
                 if (in_z >= 0 && in_z < (has_depth ? in->shape(2) : 1)
                     && in_y >= 0 && in_y < in->shape(2 + has_depth)
                     && in_x >= 0 && in_x < in->shape(3 + has_depth)) {
diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu
index 3f97cf6d..15e06aa8 100644
--- a/src/caffe/test/test_im2col_kernel.cu
+++ b/src/caffe/test/test_im2col_kernel.cu
@@ -18,6 +18,7 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w,
     const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
     const int height_col, const int width_col,
     Dtype* data_col);
 
@@ -38,6 +39,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
         blob_kernel_shape_(new Blob<int>()),
         blob_stride_(new Blob<int>()),
         blob_pad_(new Blob<int>()),
+        blob_dilation_(new Blob<int>()),
         blob_top_(new Blob<Dtype>()),
         blob_top_cpu_(new Blob<Dtype>()) {
     FillerParameter filler_param;
@@ -47,20 +49,25 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
     blob_kernel_shape_->Reshape(dim_blob_shape);
     blob_stride_->Reshape(dim_blob_shape);
     blob_pad_->Reshape(dim_blob_shape);
+    blob_dilation_->Reshape(dim_blob_shape);
 
     height_ = blob_bottom_->height();
     width_ = blob_bottom_->width();
     channels_ = blob_bottom_->channels();
     pad_ = 0;
     stride_ = 2;
+    dilation_ = 1;
     kernel_size_ = 3;
-    height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1;
-    width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1;
+    height_col_ = (height_ + 2 * pad_ -
+        (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1;
+    width_col_ = (width_ + 2 * pad_ -
+        (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1;
 
     for (int i = 0; i < 2; ++i) {
       blob_kernel_shape_->mutable_cpu_data()[i] = kernel_size_;
       blob_stride_->mutable_cpu_data()[i] = stride_;
      blob_pad_->mutable_cpu_data()[i] = pad_;
+      blob_dilation_->mutable_cpu_data()[i] = dilation_;
     }
   }
@@ -71,11 +78,13 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
     delete blob_kernel_shape_;
     delete blob_stride_;
     delete blob_pad_;
+    delete blob_dilation_;
   }
 
   Blob<int>* const blob_kernel_shape_;
   Blob<int>* const blob_stride_;
   Blob<int>* const blob_pad_;
+  Blob<int>* const blob_dilation_;
   Blob<Dtype>* const blob_bottom_;
   Blob<Dtype>* const blob_top_;
   Blob<Dtype>* const blob_top_cpu_;
@@ -84,6 +93,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
   int channels_;
   int pad_;
   int stride_;
+  int dilation_;
   int kernel_size_;
   int height_col_;
   int width_col_;
@@ -112,7 +122,7 @@ TYPED_TEST(Im2colKernelTest, Test2D) {
     im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n),
         this->channels_, this->height_, this->width_,
         this->kernel_size_, this->kernel_size_, this->pad_, this->pad_,
-        this->stride_, this->stride_,
+        this->stride_, this->stride_, this->dilation_, this->dilation_,
         cpu_data + this->blob_top_cpu_->offset(n));
   }
 
@@ -129,6 +139,7 @@ TYPED_TEST(Im2colKernelTest, Test2D) {
         num_kernels, bottom_data + this->blob_bottom_->offset(n),
         this->height_, this->width_, this->kernel_size_, this->kernel_size_,
         this->pad_, this->pad_, this->stride_, this->stride_,
+        this->dilation_, this->dilation_,
         this->height_col_, this->width_col_,
         top_data + this->blob_top_->offset(n));
     CUDA_POST_KERNEL_CHECK;
diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp
index 8274dd48..932d3f21 100644
--- a/src/caffe/test/test_im2col_layer.cpp
+++ b/src/caffe/test/test_im2col_layer.cpp
@@ -17,7 +17,7 @@ class Im2colLayerTest : public MultiDeviceTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;
  protected:
   Im2colLayerTest()
-      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
+      : blob_bottom_(new Blob<Dtype>(2, 3, 10, 9)),
         blob_top_(new Blob<Dtype>()) {
     // fill the values
     Caffe::set_random_seed(1701);
@@ -75,6 +75,7 @@ TYPED_TEST(Im2colLayerTest, TestGradient) {
       layer_param.mutable_convolution_param();
   convolution_param->add_kernel_size(3);
   convolution_param->add_stride(2);
+  convolution_param->add_dilation(3);
   Im2colLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-2);
   checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp
index 27e5b7c0..1e578e7c 100644
--- a/src/caffe/util/im2col.cpp
+++ b/src/caffe/util/im2col.cpp
@@ -10,9 +10,12 @@ void im2col_cpu(const Dtype* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w,
     const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
     Dtype* data_col) {
-  const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
-  const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+  const int height_col = (height + 2 * pad_h -
+      (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_col = (width + 2 * pad_w -
+      (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
   const int channels_col = channels * kernel_h * kernel_w;
   for (int c_col = 0; c_col < channels_col; ++c_col) {
     int w_offset = c_col % kernel_w;
@@ -20,8 +23,8 @@ void im2col_cpu(const Dtype* data_im, const int channels,
     int c_im = c_col / kernel_h / kernel_w;
     for (int h_col = 0; h_col < height_col; ++h_col) {
       for (int w_col = 0; w_col < width_col; ++w_col) {
-        int h_im = h_col * stride_h - pad_h + h_offset;
-        int w_im = w_col * stride_w - pad_w + w_offset;
+        int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+        int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
         data_col[(c_col * height_col + h_col) * width_col + w_col] =
             (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
            data_im[(c_im * height + h_im) * width + w_im] : 0;
@@ -34,11 +37,13 @@ void im2col_cpu(const Dtype* data_im, const int channels,
 template void im2col_cpu<float>(const float* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
-    const int stride_w, float* data_col);
+    const int stride_w, const int dilation_h, const int dilation_w,
+    float* data_col);
 template void im2col_cpu<double>(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
-    const int stride_w, double* data_col);
+    const int stride_w, const int dilation_h, const int dilation_w,
+    double* data_col);
 
 template <typename Dtype>
 inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
@@ -137,10 +142,13 @@ void col2im_cpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w,
     const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
     Dtype* data_im) {
   caffe_set(height * width * channels, Dtype(0), data_im);
-  const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
-  const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+  const int height_col = (height + 2 * pad_h -
+      (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_col = (width + 2 * pad_w -
+      (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
   const int channels_col = channels * kernel_h * kernel_w;
   for (int c_col = 0; c_col < channels_col; ++c_col) {
     int w_offset = c_col % kernel_w;
@@ -148,8 +156,8 @@ void col2im_cpu(const Dtype* data_col, const int channels,
     int c_im = c_col / kernel_h / kernel_w;
     for (int h_col = 0; h_col < height_col; ++h_col) {
       for (int w_col = 0; w_col < width_col; ++w_col) {
-        int h_im = h_col * stride_h - pad_h + h_offset;
-        int w_im = w_col * stride_w - pad_w + w_offset;
+        int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+        int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
         if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
           data_im[(c_im * height + h_im) * width + w_im] +=
               data_col[(c_col * height_col + h_col) * width_col + w_col];
@@ -162,11 +170,13 @@ void col2im_cpu(const Dtype* data_col, const int channels,
 template void col2im_cpu<float>(const float* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
-    const int stride_w, float* data_im);
+    const int stride_w, const int dilation_h, const int dilation_w,
+    float* data_im);
 template void col2im_cpu<double>(const double* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
-    const int stride_w, double* data_im);
+    const int stride_w, const int dilation_h, const int dilation_w,
+    double* data_im);
 
 template <typename Dtype>
 void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
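A note on im2col_cpu and col2im_cpu above: dilation only changes which input pixel each kernel offset touches — the offset is scaled by the dilation before being added to the window origin, so a dilated 3x3 kernel samples a sparse 5x5 footprint. An illustrative standalone sketch (not part of the patch) that prints the coordinates read for one output position:

#include <cstdio>

int main() {
  // Input coordinates sampled by a 3x3 kernel for output position
  // (h_col, w_col) = (0, 0) with pad 0, stride 1, dilation 2, following
  // h_im = h_col * stride - pad + h_offset * dilation from im2col_cpu.
  const int kernel = 3, stride = 1, pad = 0, dilation = 2;
  const int h_col = 0, w_col = 0;
  for (int h_offset = 0; h_offset < kernel; ++h_offset) {
    for (int w_offset = 0; w_offset < kernel; ++w_offset) {
      const int h_im = h_col * stride - pad + h_offset * dilation;
      const int w_im = w_col * stride - pad + w_offset * dilation;
      printf("(%d,%d) ", h_im, w_im);  // rows/cols 0, 2, 4: a 5x5 footprint
    }
    printf("\n");
  }
  return 0;
}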
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu
index 49354ab7..cdcaac5b 100644
--- a/src/caffe/util/im2col.cu
+++ b/src/caffe/util/im2col.cu
@@ -10,6 +10,7 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w,
     const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
     const int height_col, const int width_col,
     Dtype* data_col) {
   CUDA_KERNEL_LOOP(index, n) {
@@ -26,11 +27,11 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
     data_im_ptr += (c_im * height + h_offset) * width + w_offset;
     for (int i = 0; i < kernel_h; ++i) {
       for (int j = 0; j < kernel_w; ++j) {
-        int h_im = h_offset + i;
-        int w_im = w_offset + j;
+        int h_im = h_offset + i * dilation_h;
+        int w_im = w_offset + j * dilation_w;
         *data_col_ptr =
             (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
-            data_im_ptr[i * width + j] : 0;
+            data_im_ptr[i * dilation_h * width + j * dilation_w] : 0;
         data_col_ptr += height_col * width_col;
       }
     }
@@ -42,17 +43,20 @@ void im2col_gpu(const Dtype* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w,
     const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
     Dtype* data_col) {
   // We are going to launch channels * height_col * width_col kernels, each
   // kernel responsible for copying a single-channel grid.
-  int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
-  int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+  int height_col = (height + 2 * pad_h -
+      (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  int width_col = (width + 2 * pad_w -
+      (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
   int num_kernels = channels * height_col * width_col;
   // NOLINT_NEXT_LINE(whitespace/operators)
   im2col_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels),
                              CAFFE_CUDA_NUM_THREADS>>>(
       num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h,
-      pad_w, stride_h, stride_w, height_col,
+      pad_w, stride_h, stride_w, dilation_h, dilation_w, height_col,
       width_col, data_col);
   CUDA_POST_KERNEL_CHECK;
 }
@@ -61,11 +65,11 @@ void im2col_gpu(const Dtype* data_im, const int channels,
 template void im2col_gpu<float>(const float* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
-    float* data_col);
+    const int dilation_h, const int dilation_w, float* data_col);
 template void im2col_gpu<double>(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
-    double* data_col);
+    const int dilation_h, const int dilation_w, double* data_col);
 
 template <typename Dtype, int num_axes>
 __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
@@ -223,6 +227,7 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
     const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w,
     const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
     const int height_col, const int width_col,
     Dtype* data_im) {
   CUDA_KERNEL_LOOP(index, n) {
@@ -230,33 +235,27 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
     const int w_im = index % width + pad_w;
     const int h_im = (index / width) % height + pad_h;
     const int c_im = index / (width * height);
+    int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
+    int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
     // compute the start and end of the output
     const int w_col_start =
-        (w_im < kernel_w) ? 0 : (w_im - kernel_w) / stride_w + 1;
-    const int w_col_end =
-        min(w_im / stride_w + 1, width_col);
+        (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
+    const int w_col_end = min(w_im / stride_w + 1, width_col);
     const int h_col_start =
-        (h_im < kernel_h) ? 0 : (h_im - kernel_h) / stride_h + 1;
-    const int h_col_end =
-        min(h_im / stride_h + 1, height_col);
-    /*
-    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
-      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
-        // the col location: [c * width * height + h_out, w_out]
-        int c_col = c_im * kernel_h * kernel_w
-            + (h_im - h_col * stride_h) * kernel_w + (w_im - w_col * stride_w);
-        val += data_col[(c_col * height_col + h_col) * width_col + w_col];
-      }
-    }
-    */
-    // equivalent implementation
-    int offset = (c_im * kernel_h * kernel_w + h_im * kernel_w + w_im)
-        * height_col * width_col;
-    int coeff_h_col = (1 - stride_h * kernel_w * height_col) * width_col;
-    int coeff_w_col = (1 - stride_w * height_col * width_col);
-    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
-      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
-        val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
+        (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
+    const int h_col_end = min(h_im / stride_h + 1, height_col);
+    // TODO: use LCM of stride and dilation to avoid unnecessary loops
+    for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
+      for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
+        int h_k = (h_im - h_col * stride_h);
+        int w_k = (w_im - w_col * stride_w);
+        if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
+          h_k /= dilation_h;
+          w_k /= dilation_w;
+          int data_col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) *
+                                height_col + h_col) * width_col + w_col;
+          val += data_col[data_col_index];
+        }
       }
     }
     data_im[index] = val;
@@ -267,9 +266,12 @@ template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
-    const int stride_w, Dtype* data_im) {
-  int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
-  int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+    const int stride_w, const int dilation_h, const int dilation_w,
+    Dtype* data_im) {
+  int height_col = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) /
+      stride_h + 1;
+  int width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) /
+      stride_w + 1;
   int num_kernels = channels * height * width;
   // To avoid involving atomic operations, we will launch one kernel per
   // bottom dimension, and then in the kernel add up the top dimensions.
@@ -277,7 +279,7 @@ void col2im_gpu(const Dtype* data_col, const int channels,
   col2im_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels),
                              CAFFE_CUDA_NUM_THREADS>>>(
       num_kernels, data_col, height, width, channels, kernel_h, kernel_w,
-      pad_h, pad_w, stride_h, stride_w,
+      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
       height_col, width_col, data_im);
   CUDA_POST_KERNEL_CHECK;
 }
@@ -286,11 +288,13 @@ void col2im_gpu(const Dtype* data_col, const int channels,
 template void col2im_gpu<float>(const float* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
-    const int stride_w, float* data_im);
+    const int stride_w, const int dilation_h, const int dilation_w,
+    float* data_im);
 template void col2im_gpu<double>(const double* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int pad_h, const int pad_w, const int stride_h,
-    const int stride_w, double* data_im);
+    const int stride_w, const int dilation_h, const int dilation_w,
+    double* data_im);
 
 template <typename Dtype, int num_axes>
 __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
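A closing note on the col2im_gpu_kernel rewrite above: the old closed-form offset/coefficient walk assumed a dense kernel, so the patch replaces it with an explicit gather. For each input pixel the kernel visits every output position whose (dilated) window could cover that pixel and accumulates only the column entries whose kernel offset h_im - h_col * stride_h (and likewise for width) is an exact multiple of the dilation. The same gather expressed as a standalone CPU-style sketch (gather_one_pixel is a hypothetical name; the real code is the CUDA kernel above):

#include <algorithm>

// Sum all column-buffer entries that map onto input pixel (h_im, w_im) of
// channel c_im, mirroring the gather in col2im_gpu_kernel. Illustrative
// sketch only, not the CUDA code itself.
float gather_one_pixel(const float* data_col, int c_im, int h_im, int w_im,
                       int kernel_h, int kernel_w, int stride_h, int stride_w,
                       int dilation_h, int dilation_w,
                       int height_col, int width_col) {
  const int extent_h = (kernel_h - 1) * dilation_h + 1;
  const int extent_w = (kernel_w - 1) * dilation_w + 1;
  const int h_start = (h_im < extent_h) ? 0 : (h_im - extent_h) / stride_h + 1;
  const int w_start = (w_im < extent_w) ? 0 : (w_im - extent_w) / stride_w + 1;
  const int h_end = std::min(h_im / stride_h + 1, height_col);
  const int w_end = std::min(w_im / stride_w + 1, width_col);
  float val = 0;
  for (int h_col = h_start; h_col < h_end; ++h_col) {
    for (int w_col = w_start; w_col < w_end; ++w_col) {
      int h_k = h_im - h_col * stride_h;  // kernel offset in dilated units
      int w_k = w_im - w_col * stride_w;
      // Only offsets landing exactly on the dilated grid contribute.
      if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
        h_k /= dilation_h;
        w_k /= dilation_w;
        val += data_col[(((c_im * kernel_h + h_k) * kernel_w + w_k) *
                         height_col + h_col) * width_col + w_col];
      }
    }
  }
  return val;
}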