author    Fisher Yu <i@yf.io>  2015-12-06 20:04:43 -0500
committer Jonathan L Long <jonlong@cs.berkeley.edu>  2015-12-28 15:07:35 -0800
commit    93bfcb53120416255d6d7261b638f0b38ff9e9bf
tree      4649f835a7b53e78995e59905cbc441953440854
parent    03a84bf464dd47bcec9ac943f0229a758c627f05
add support for 2D dilated convolution
 include/caffe/layers/base_conv_layer.hpp  | 14
 include/caffe/layers/conv_layer.hpp       |  3
 include/caffe/layers/im2col_layer.hpp     |  2
 include/caffe/util/im2col.hpp             | 12
 src/caffe/layer_factory.cpp               | 17
 src/caffe/layers/base_conv_layer.cpp      | 20
 src/caffe/layers/conv_layer.cpp           |  4
 src/caffe/layers/im2col_layer.cpp         | 21
 src/caffe/layers/im2col_layer.cu          |  2
 src/caffe/proto/caffe.proto               |  1
 src/caffe/test/test_convolution_layer.cpp | 14
 src/caffe/test/test_im2col_kernel.cu      | 17
 src/caffe/test/test_im2col_layer.cpp      |  3
 src/caffe/util/im2col.cpp                 | 34
 src/caffe/util/im2col.cu                  | 80
15 files changed, 170 insertions, 74 deletions
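This patch threads a per-axis dilation factor through im2col/col2im and every shape computation. The one identity to keep in mind while reading it: a kernel of size k with dilation d covers an extent of d*(k-1)+1 input pixels, so the output size of each spatial axis becomes (input + 2*pad - extent) / stride + 1. A minimal standalone C++ sketch of that rule (illustrative only; the patch's authoritative version is in conv_layer.cpp below):

#include <cassert>

// Spatial extent covered by a dilated kernel: d*(k-1)+1.
int kernel_extent(int kernel, int dilation) {
  return dilation * (kernel - 1) + 1;
}

// Output size of one spatial axis; mirrors compute_output_shape() below.
int output_dim(int input, int pad, int kernel, int stride, int dilation) {
  return (input + 2 * pad - kernel_extent(kernel, dilation)) / stride + 1;
}

int main() {
  // Numbers from the updated Im2colLayerTest: a 10x9 input with kernel 3,
  // stride 2, dilation 3 covers an extent of 7 and yields a 2x2 output.
  assert(kernel_extent(3, 3) == 7);
  assert(output_dim(10, 0, 3, 2, 3) == 2);
  assert(output_dim(9, 0, 3, 2, 3) == 2);
  return 0;
}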
diff --git a/include/caffe/layers/base_conv_layer.hpp b/include/caffe/layers/base_conv_layer.hpp
index f3def16c..db471b58 100644
--- a/include/caffe/layers/base_conv_layer.hpp
+++ b/include/caffe/layers/base_conv_layer.hpp
@@ -68,6 +68,8 @@ class BaseConvolutionLayer : public Layer<Dtype> {
Blob<int> stride_;
/// @brief The spatial dimensions of the padding.
Blob<int> pad_;
+ /// @brief The spatial dimensions of the dilation.
+ Blob<int> dilation_;
/// @brief The spatial dimensions of the convolution input.
Blob<int> conv_input_shape_;
/// @brief The spatial dimensions of the col_buffer.
@@ -99,7 +101,8 @@ class BaseConvolutionLayer : public Layer<Dtype> {
conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
- stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+ stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1], col_buff);
} else {
im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
col_buffer_shape_.data(), kernel_shape_.cpu_data(),
@@ -112,7 +115,8 @@ class BaseConvolutionLayer : public Layer<Dtype> {
conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
- stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+ stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1], data);
} else {
col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
col_buffer_shape_.data(), kernel_shape_.cpu_data(),
@@ -126,7 +130,8 @@ class BaseConvolutionLayer : public Layer<Dtype> {
conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
- stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff);
+ stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1], col_buff);
} else {
im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_,
conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
@@ -140,7 +145,8 @@ class BaseConvolutionLayer : public Layer<Dtype> {
conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2],
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
- stride_.cpu_data()[0], stride_.cpu_data()[1], data);
+ stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1], data);
} else {
col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_,
conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
diff --git a/include/caffe/layers/conv_layer.hpp b/include/caffe/layers/conv_layer.hpp
index 15574766..93a618dd 100644
--- a/include/caffe/layers/conv_layer.hpp
+++ b/include/caffe/layers/conv_layer.hpp
@@ -44,6 +44,9 @@ class ConvolutionLayer : public BaseConvolutionLayer<Dtype> {
* convolution, given by pad for equal dimensions or pad_h and pad_w for
* different padding. Input padding is computed implicitly instead of
* actually padding.
+ * - dilation (\b optional, default 1). The filter
+ * dilation, specified either once for all spatial dimensions or once
+ * per spatial dimension. A dilation of d samples the input every d pixels.
* - group (\b optional, default 1). The number of filter groups. Group
* convolution is a method for reducing parameterization by selectively
connecting input and output channels. The input and output channel dimensions must be divisible by the number of groups.
diff --git a/include/caffe/layers/im2col_layer.hpp b/include/caffe/layers/im2col_layer.hpp
index 1d3b2eb6..71e32f74 100644
--- a/include/caffe/layers/im2col_layer.hpp
+++ b/include/caffe/layers/im2col_layer.hpp
@@ -46,6 +46,8 @@ class Im2colLayer : public Layer<Dtype> {
Blob<int> stride_;
/// @brief The spatial dimensions of the padding.
Blob<int> pad_;
+ /// @brief The spatial dimensions of the dilation.
+ Blob<int> dilation_;
int num_spatial_axes_;
int bottom_dim_;
diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp
index d3eb6ccd..748b65c4 100644
--- a/include/caffe/util/im2col.hpp
+++ b/include/caffe/util/im2col.hpp
@@ -13,7 +13,8 @@ template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, Dtype* data_col);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ Dtype* data_col);
template <typename Dtype>
void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
@@ -25,7 +26,8 @@ template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, Dtype* data_im);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ Dtype* data_im);
template <typename Dtype>
void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
@@ -37,7 +39,8 @@ template <typename Dtype>
void im2col_gpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, Dtype* data_col);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ Dtype* data_col);
template <typename Dtype>
void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
@@ -49,7 +52,8 @@ template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, Dtype* data_im);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ Dtype* data_im);
} // namespace caffe
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 76d851af..6b1d1c1a 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -37,17 +37,30 @@ namespace caffe {
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetConvolutionLayer(
const LayerParameter& param) {
- ConvolutionParameter_Engine engine = param.convolution_param().engine();
+ ConvolutionParameter conv_param = param.convolution_param();
+ ConvolutionParameter_Engine engine = conv_param.engine();
+ bool use_dilation = false;
+ for (int i = 0; i < conv_param.dilation_size(); ++i) {
+ if (conv_param.dilation(i) > 1) {
+ use_dilation = true;
+ }
+ }
if (engine == ConvolutionParameter_Engine_DEFAULT) {
engine = ConvolutionParameter_Engine_CAFFE;
#ifdef USE_CUDNN
- engine = ConvolutionParameter_Engine_CUDNN;
+ if (!use_dilation) {
+ engine = ConvolutionParameter_Engine_CUDNN;
+ }
#endif
}
if (engine == ConvolutionParameter_Engine_CAFFE) {
return shared_ptr<Layer<Dtype> >(new ConvolutionLayer<Dtype>(param));
#ifdef USE_CUDNN
} else if (engine == ConvolutionParameter_Engine_CUDNN) {
+ if (use_dilation) {
+ LOG(FATAL) << "CuDNN doesn't support dilated convolution at layer "
+ << param.name();
+ }
return shared_ptr<Layer<Dtype> >(new CuDNNConvolutionLayer<Dtype>(param));
#endif
} else {
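In user terms, the rule above means a dilated layer with engine DEFAULT silently falls back to the CAFFE engine even when cuDNN is compiled in, while explicitly requesting CUDNN together with dilation is a fatal configuration error. A hypothetical setup sketch using the protobuf-generated setters (names follow caffe.proto; not part of the patch):

#include "caffe/proto/caffe.pb.h"

// Configure a dilated 3x3 convolution. With engine left at DEFAULT,
// GetConvolutionLayer() above selects the CAFFE engine; calling
// conv->set_engine(caffe::ConvolutionParameter_Engine_CUDNN) instead
// would hit the LOG(FATAL) path.
caffe::LayerParameter MakeDilatedConvParam() {
  caffe::LayerParameter param;
  param.set_name("conv_dilated");
  param.set_type("Convolution");
  caffe::ConvolutionParameter* conv = param.mutable_convolution_param();
  conv->set_num_output(64);
  conv->add_kernel_size(3);
  conv->add_dilation(2);  // the new repeated field; defaults to 1
  return param;
}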
diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp
index f6f14cd0..4a4c68e0 100644
--- a/src/caffe/layers/base_conv_layer.cpp
+++ b/src/caffe/layers/base_conv_layer.cpp
@@ -36,7 +36,7 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_)
<< "kernel_size must be specified once, or once per spatial dimension "
<< "(kernel_size specified " << num_kernel_dims << " times; "
- << num_spatial_axes_ << " spatial dims);";
+ << num_spatial_axes_ << " spatial dims).";
for (int i = 0; i < num_spatial_axes_; ++i) {
kernel_shape_data[i] =
conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i);
@@ -61,7 +61,7 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
num_stride_dims == num_spatial_axes_)
<< "stride must be specified once, or once per spatial dimension "
<< "(stride specified " << num_stride_dims << " times; "
- << num_spatial_axes_ << " spatial dims);";
+ << num_spatial_axes_ << " spatial dims).";
const int kDefaultStride = 1;
for (int i = 0; i < num_spatial_axes_; ++i) {
stride_data[i] = (num_stride_dims == 0) ? kDefaultStride :
@@ -85,13 +85,27 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
num_pad_dims == num_spatial_axes_)
<< "pad must be specified once, or once per spatial dimension "
<< "(pad specified " << num_pad_dims << " times; "
- << num_spatial_axes_ << " spatial dims);";
+ << num_spatial_axes_ << " spatial dims).";
const int kDefaultPad = 0;
for (int i = 0; i < num_spatial_axes_; ++i) {
pad_data[i] = (num_pad_dims == 0) ? kDefaultPad :
conv_param.pad((num_pad_dims == 1) ? 0 : i);
}
}
+ // Setup dilation dimensions (dilation_).
+ dilation_.Reshape(spatial_dim_blob_shape);
+ int* dilation_data = dilation_.mutable_cpu_data();
+ const int num_dilation_dims = conv_param.dilation_size();
+ CHECK(num_dilation_dims == 0 || num_dilation_dims == 1 ||
+ num_dilation_dims == num_spatial_axes_)
+ << "dilation must be specified once, or once per spatial dimension "
+ << "(dilation specified " << num_dilation_dims << " times; "
+ << num_spatial_axes_ << " spatial dims).";
+ const int kDefaultDilation = 1;
+ for (int i = 0; i < num_spatial_axes_; ++i) {
+ dilation_data[i] = (num_dilation_dims == 0) ? kDefaultDilation :
+ conv_param.dilation((num_dilation_dims == 1) ? 0 : i);
+ }
// Special case: im2col is the identity for 1x1 convolution with stride 1
// and no padding, so flag for skipping the buffer and transformation.
is_1x1_ = true;
diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp
index cff09783..5d522ab3 100644
--- a/src/caffe/layers/conv_layer.cpp
+++ b/src/caffe/layers/conv_layer.cpp
@@ -9,11 +9,13 @@ void ConvolutionLayer<Dtype>::compute_output_shape() {
const int* kernel_shape_data = this->kernel_shape_.cpu_data();
const int* stride_data = this->stride_.cpu_data();
const int* pad_data = this->pad_.cpu_data();
+ const int* dilation_data = this->dilation_.cpu_data();
this->output_shape_.clear();
for (int i = 0; i < this->num_spatial_axes_; ++i) {
// i + 1 to skip channel axis
const int input_dim = this->input_shape(i + 1);
- const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i])
+ const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1;
+ const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent)
/ stride_data[i] + 1;
this->output_shape_.push_back(output_dim);
}
diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp
index c12e4f52..19ae3019 100644
--- a/src/caffe/layers/im2col_layer.cpp
+++ b/src/caffe/layers/im2col_layer.cpp
@@ -87,6 +87,20 @@ void Im2colLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
conv_param.pad((num_pad_dims == 1) ? 0 : i);
}
}
+ // Setup dilation dimensions (dilation_).
+ dilation_.Reshape(dim_blob_shape);
+ int* dilation_data = dilation_.mutable_cpu_data();
+ const int num_dilation_dims = conv_param.dilation_size();
+ CHECK(num_dilation_dims == 0 || num_dilation_dims == 1 ||
+ num_dilation_dims == num_spatial_axes_)
+ << "dilation must be specified once, or once per spatial dimension "
+ << "(dilation specified " << num_dilation_dims << " times; "
+ << num_spatial_axes_ << " spatial dims).";
+ const int kDefaultDilation = 1;
+ for (int i = 0; i < num_spatial_axes_; ++i) {
+ dilation_data[i] = (num_dilation_dims == 0) ? kDefaultDilation :
+ conv_param.dilation((num_dilation_dims == 1) ? 0 : i);
+ }
}
template <typename Dtype>
@@ -96,10 +110,12 @@ void Im2colLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const int* kernel_shape_data = kernel_shape_.cpu_data();
const int* stride_data = stride_.cpu_data();
const int* pad_data = pad_.cpu_data();
+ const int* dilation_data = dilation_.cpu_data();
for (int i = 0; i < num_spatial_axes_; ++i) {
top_shape[channel_axis_] *= kernel_shape_data[i];
const int input_dim = bottom[0]->shape(channel_axis_ + i + 1);
- const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i])
+ const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1;
+ const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent)
/ stride_data[i] + 1;
top_shape[channel_axis_ + i + 1] = output_dim;
}
@@ -122,6 +138,7 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
DCHECK_EQ(kernel_shape_.count(), num_spatial_axes_);
DCHECK_EQ(pad_.count(), num_spatial_axes_);
DCHECK_EQ(stride_.count(), num_spatial_axes_);
+ DCHECK_EQ(dilation_.count(), num_spatial_axes_);
if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
im2col_cpu(bottom_data + n * bottom_dim_, channels_,
bottom[0]->shape(channel_axis_ + 1),
@@ -129,6 +146,7 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1],
top_data + n * top_dim_);
} else {
im2col_nd_cpu(bottom_data + n * bottom_dim_, num_spatial_axes_,
@@ -153,6 +171,7 @@ void Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1],
bottom_diff + n * bottom_dim_);
} else {
col2im_nd_cpu(top_diff + n * top_dim_, num_spatial_axes_,
diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu
index 517b4220..d90075d4 100644
--- a/src/caffe/layers/im2col_layer.cu
+++ b/src/caffe/layers/im2col_layer.cu
@@ -19,6 +19,7 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1],
top_data + n * top_dim_);
} else {
im2col_nd_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_,
@@ -43,6 +44,7 @@ void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
pad_.cpu_data()[0], pad_.cpu_data()[1],
stride_.cpu_data()[0], stride_.cpu_data()[1],
+ dilation_.cpu_data()[0], dilation_.cpu_data()[1],
bottom_diff + n * bottom_dim_);
} else {
col2im_nd_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_,
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 787369f7..87c46629 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -518,6 +518,7 @@ message ConvolutionParameter {
repeated uint32 pad = 3; // The padding size; defaults to 0
repeated uint32 kernel_size = 4; // The kernel size
repeated uint32 stride = 6; // The stride; defaults to 1
+ repeated uint32 dilation = 18; // The dilation; defaults to 1
// For 2D convolution only, the *_h and *_w versions may also be used to
// specify both spatial dimensions.
diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp
index e2d43f31..95c3c80c 100644
--- a/src/caffe/test/test_convolution_layer.cpp
+++ b/src/caffe/test/test_convolution_layer.cpp
@@ -46,13 +46,17 @@ void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
} else {
stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1;
}
- int kernel_d, pad_d, stride_d;
+ int dilation_h, dilation_w;
+ dilation_h = dilation_w = conv_param->dilation_size() ?
+ conv_param->dilation(0) : 1;
+ int kernel_d, pad_d, stride_d, dilation_d;
if (has_depth) {
kernel_d = kernel_h;
stride_d = stride_h;
pad_d = pad_h;
+ dilation_d = dilation_h;
} else {
- kernel_d = stride_d = 1;
+ kernel_d = stride_d = dilation_d = 1;
pad_d = 0;
}
// Groups
@@ -77,9 +81,9 @@ void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
for (int r = 0; r < kernel_d; r++) {
for (int p = 0; p < kernel_h; p++) {
for (int q = 0; q < kernel_w; q++) {
- int in_z = z * stride_d - pad_d + r;
- int in_y = y * stride_h - pad_h + p;
- int in_x = x * stride_w - pad_w + q;
+ int in_z = z * stride_d - pad_d + r * dilation_d;
+ int in_y = y * stride_h - pad_h + p * dilation_h;
+ int in_x = x * stride_w - pad_w + q * dilation_w;
if (in_z >= 0 && in_z < (has_depth ? in->shape(2) : 1)
&& in_y >= 0 && in_y < in->shape(2 + has_depth)
&& in_x >= 0 && in_x < in->shape(3 + has_depth)) {
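A useful way to read this change: tap (p, q) now lands dilation pixels apart, so a dilated convolution is exactly an ordinary convolution whose kernel has been sparsely embedded in a d*(k-1)+1 window with zeros in between. An illustrative sketch of that equivalence (assumes a square 2D kernel in row-major order; not part of the patch):

#include <vector>

// Embed a k x k kernel into the equivalent zero-stuffed extent x extent
// kernel, extent = d*(k-1)+1. Ordinary convolution with the stuffed
// kernel reproduces the dilated indexing used in caffe_conv above.
std::vector<float> ZeroStuff(const std::vector<float>& kernel,
                             int k, int dilation) {
  const int extent = dilation * (k - 1) + 1;
  std::vector<float> stuffed(extent * extent, 0.f);
  for (int p = 0; p < k; ++p) {
    for (int q = 0; q < k; ++q) {
      stuffed[p * dilation * extent + q * dilation] = kernel[p * k + q];
    }
  }
  return stuffed;
}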
diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu
index 3f97cf6d..15e06aa8 100644
--- a/src/caffe/test/test_im2col_kernel.cu
+++ b/src/caffe/test/test_im2col_kernel.cu
@@ -18,6 +18,7 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
const int height_col, const int width_col,
Dtype* data_col);
@@ -38,6 +39,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
blob_kernel_shape_(new Blob<int>()),
blob_stride_(new Blob<int>()),
blob_pad_(new Blob<int>()),
+ blob_dilation_(new Blob<int>()),
blob_top_(new Blob<Dtype>()),
blob_top_cpu_(new Blob<Dtype>()) {
FillerParameter filler_param;
@@ -47,20 +49,25 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
blob_kernel_shape_->Reshape(dim_blob_shape);
blob_stride_->Reshape(dim_blob_shape);
blob_pad_->Reshape(dim_blob_shape);
+ blob_dilation_->Reshape(dim_blob_shape);
height_ = blob_bottom_->height();
width_ = blob_bottom_->width();
channels_ = blob_bottom_->channels();
pad_ = 0;
stride_ = 2;
+ dilation_ = 1;
kernel_size_ = 3;
- height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1;
- width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1;
+ height_col_ = (height_ + 2 * pad_ -
+ (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1;
+ width_col_ = (width_ + 2 * pad_ -
+ (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1;
for (int i = 0; i < 2; ++i) {
blob_kernel_shape_->mutable_cpu_data()[i] = kernel_size_;
blob_stride_->mutable_cpu_data()[i] = stride_;
blob_pad_->mutable_cpu_data()[i] = pad_;
+ blob_dilation_->mutable_cpu_data()[i] = dilation_;
}
}
@@ -71,11 +78,13 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
delete blob_kernel_shape_;
delete blob_stride_;
delete blob_pad_;
+ delete blob_dilation_;
}
Blob<int>* const blob_kernel_shape_;
Blob<int>* const blob_stride_;
Blob<int>* const blob_pad_;
+ Blob<int>* const blob_dilation_;
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_top_;
Blob<Dtype>* const blob_top_cpu_;
@@ -84,6 +93,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
int channels_;
int pad_;
int stride_;
+ int dilation_;
int kernel_size_;
int height_col_;
int width_col_;
@@ -112,7 +122,7 @@ TYPED_TEST(Im2colKernelTest, Test2D) {
im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n),
this->channels_, this->height_, this->width_,
this->kernel_size_, this->kernel_size_, this->pad_, this->pad_,
- this->stride_, this->stride_,
+ this->stride_, this->stride_, this->dilation_, this->dilation_,
cpu_data + this->blob_top_cpu_->offset(n));
}
@@ -129,6 +139,7 @@ TYPED_TEST(Im2colKernelTest, Test2D) {
num_kernels, bottom_data + this->blob_bottom_->offset(n),
this->height_, this->width_, this->kernel_size_, this->kernel_size_,
this->pad_, this->pad_, this->stride_, this->stride_,
+ this->dilation_, this->dilation_,
this->height_col_, this->width_col_,
top_data + this->blob_top_->offset(n));
CUDA_POST_KERNEL_CHECK;
diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp
index 8274dd48..932d3f21 100644
--- a/src/caffe/test/test_im2col_layer.cpp
+++ b/src/caffe/test/test_im2col_layer.cpp
@@ -17,7 +17,7 @@ class Im2colLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
protected:
Im2colLayerTest()
- : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
+ : blob_bottom_(new Blob<Dtype>(2, 3, 10, 9)),
blob_top_(new Blob<Dtype>()) {
// fill the values
Caffe::set_random_seed(1701);
@@ -75,6 +75,7 @@ TYPED_TEST(Im2colLayerTest, TestGradient) {
layer_param.mutable_convolution_param();
convolution_param->add_kernel_size(3);
convolution_param->add_stride(2);
+ convolution_param->add_dilation(3);
Im2colLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-2);
checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
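The enlarged bottom blob is what keeps this gradient check well-posed: with kernel_size 3 and dilation 3 the kernel extent is 3*(3-1)+1 = 7, which would not fit in the old 6x5 input, while the new 10x9 input gives output dimensions (10-7)/2+1 = 2 and (9-7)/2+1 = 2.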
diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp
index 27e5b7c0..1e578e7c 100644
--- a/src/caffe/util/im2col.cpp
+++ b/src/caffe/util/im2col.cpp
@@ -10,9 +10,12 @@ void im2col_cpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
Dtype* data_col) {
- const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
- const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+ const int height_col = (height + 2 * pad_h -
+ (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_col = (width + 2 * pad_w -
+ (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
const int channels_col = channels * kernel_h * kernel_w;
for (int c_col = 0; c_col < channels_col; ++c_col) {
int w_offset = c_col % kernel_w;
@@ -20,8 +23,8 @@ void im2col_cpu(const Dtype* data_im, const int channels,
int c_im = c_col / kernel_h / kernel_w;
for (int h_col = 0; h_col < height_col; ++h_col) {
for (int w_col = 0; w_col < width_col; ++w_col) {
- int h_im = h_col * stride_h - pad_h + h_offset;
- int w_im = w_col * stride_w - pad_w + w_offset;
+ int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+ int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
data_col[(c_col * height_col + h_col) * width_col + w_col] =
(h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
data_im[(c_im * height + h_im) * width + w_im] : 0;
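Concretely, with the settings used in the tests (kernel 3, stride 2, pad 0, dilation 3), the column entry for output position (h_col, w_col) gathers input rows h_col*2 + {0, 3, 6} and columns w_col*2 + {0, 3, 6}. A tiny sketch enumerating the taps one output position reads (illustrative only, mirroring the indexing above):

#include <cstdio>

// Print the input coordinates gathered for one output position
// (h_col, w_col), using the same dilated indexing as im2col_cpu.
void PrintTaps(int h_col, int w_col, int kernel_h, int kernel_w,
               int stride_h, int stride_w, int pad_h, int pad_w,
               int dilation_h, int dilation_w) {
  for (int h_off = 0; h_off < kernel_h; ++h_off) {
    for (int w_off = 0; w_off < kernel_w; ++w_off) {
      int h_im = h_col * stride_h - pad_h + h_off * dilation_h;
      int w_im = w_col * stride_w - pad_w + w_off * dilation_w;
      printf("(%d, %d) ", h_im, w_im);  // e.g. (0,0) -> rows/cols {0,3,6}
    }
  }
  printf("\n");
}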
@@ -34,11 +37,13 @@ void im2col_cpu(const Dtype* data_im, const int channels,
template void im2col_cpu<float>(const float* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, float* data_col);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ float* data_col);
template void im2col_cpu<double>(const double* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, double* data_col);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ double* data_col);
template <typename Dtype>
inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
@@ -137,10 +142,13 @@ void col2im_cpu(const Dtype* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
Dtype* data_im) {
caffe_set(height * width * channels, Dtype(0), data_im);
- const int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
- const int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+ const int height_col = (height + 2 * pad_h -
+ (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_col = (width + 2 * pad_w -
+ (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
const int channels_col = channels * kernel_h * kernel_w;
for (int c_col = 0; c_col < channels_col; ++c_col) {
int w_offset = c_col % kernel_w;
@@ -148,8 +156,8 @@ void col2im_cpu(const Dtype* data_col, const int channels,
int c_im = c_col / kernel_h / kernel_w;
for (int h_col = 0; h_col < height_col; ++h_col) {
for (int w_col = 0; w_col < width_col; ++w_col) {
- int h_im = h_col * stride_h - pad_h + h_offset;
- int w_im = w_col * stride_w - pad_w + w_offset;
+ int h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
+ int w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
data_im[(c_im * height + h_im) * width + w_im] +=
data_col[(c_col * height_col + h_col) * width_col + w_col];
@@ -162,11 +170,13 @@ void col2im_cpu(const Dtype* data_col, const int channels,
template void col2im_cpu<float>(const float* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, float* data_im);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ float* data_im);
template void col2im_cpu<double>(const double* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, double* data_im);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ double* data_im);
template <typename Dtype>
void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu
index 49354ab7..cdcaac5b 100644
--- a/src/caffe/util/im2col.cu
+++ b/src/caffe/util/im2col.cu
@@ -10,6 +10,7 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
const int height_col, const int width_col,
Dtype* data_col) {
CUDA_KERNEL_LOOP(index, n) {
@@ -26,11 +27,11 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
data_im_ptr += (c_im * height + h_offset) * width + w_offset;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
- int h_im = h_offset + i;
- int w_im = w_offset + j;
+ int h_im = h_offset + i * dilation_h;
+ int w_im = w_offset + j * dilation_w;
*data_col_ptr =
(h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) ?
- data_im_ptr[i * width + j] : 0;
+ data_im_ptr[i * dilation_h * width + j * dilation_w] : 0;
data_col_ptr += height_col * width_col;
}
}
@@ -42,17 +43,20 @@ void im2col_gpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
Dtype* data_col) {
// We are going to launch channels * height_col * width_col kernels, each
// kernel responsible for copying a single-channel grid.
- int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
- int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+ int height_col = (height + 2 * pad_h -
+ (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w -
+ (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col;
// NOLINT_NEXT_LINE(whitespace/operators)
im2col_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels),
CAFFE_CUDA_NUM_THREADS>>>(
num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h,
- pad_w, stride_h, stride_w, height_col,
+ pad_w, stride_h, stride_w, dilation_h, dilation_w, height_col,
width_col, data_col);
CUDA_POST_KERNEL_CHECK;
}
@@ -61,11 +65,11 @@ void im2col_gpu(const Dtype* data_im, const int channels,
template void im2col_gpu<float>(const float* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
- float* data_col);
+ const int dilation_h, const int dilation_w, float* data_col);
template void im2col_gpu<double>(const double* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
- double* data_col);
+ const int dilation_h, const int dilation_w, double* data_col);
template <typename Dtype, int num_axes>
__global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
@@ -223,6 +227,7 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
const int height_col, const int width_col,
Dtype* data_im) {
CUDA_KERNEL_LOOP(index, n) {
@@ -230,33 +235,27 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
const int w_im = index % width + pad_w;
const int h_im = (index / width) % height + pad_h;
const int c_im = index / (width * height);
+ int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
+ int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
// compute the start and end of the output
const int w_col_start =
- (w_im < kernel_w) ? 0 : (w_im - kernel_w) / stride_w + 1;
- const int w_col_end =
- min(w_im / stride_w + 1, width_col);
+ (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
+ const int w_col_end = min(w_im / stride_w + 1, width_col);
const int h_col_start =
- (h_im < kernel_h) ? 0 : (h_im - kernel_h) / stride_h + 1;
- const int h_col_end =
- min(h_im / stride_h + 1, height_col);
- /*
- for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
- for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
- // the col location: [c * width * height + h_out, w_out]
- int c_col = c_im * kernel_h * kernel_w
- + (h_im - h_col * stride_h) * kernel_w + (w_im - w_col * stride_w);
- val += data_col[(c_col * height_col + h_col) * width_col + w_col];
- }
- }
- */
- // equivalent implementation
- int offset = (c_im * kernel_h * kernel_w + h_im * kernel_w + w_im)
- * height_col * width_col;
- int coeff_h_col = (1 - stride_h * kernel_w * height_col) * width_col;
- int coeff_w_col = (1 - stride_w * height_col * width_col);
- for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
- for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
- val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
+ (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
+ const int h_col_end = min(h_im / stride_h + 1, height_col);
+ // TODO: use LCM of stride and dilation to avoid unnecessary loops
+ for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
+ for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
+ int h_k = (h_im - h_col * stride_h);
+ int w_k = (w_im - w_col * stride_w);
+ if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
+ h_k /= dilation_h;
+ w_k /= dilation_w;
+ int data_col_index = (((c_im * kernel_h + h_k) * kernel_w + w_k) *
+ height_col + h_col) * width_col + w_col;
+ val += data_col[data_col_index];
+ }
}
}
data_im[index] = val;
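The rewrite replaces the old closed-form offset arithmetic, along with the commented-out reference loop it came with, by an explicit search: image pixel (h_im, w_im) receives a contribution from column position (h_col, w_col) only when the offsets h_im - h_col * stride_h and w_im - w_col * stride_w are nonnegative multiples of the dilation, because a dilated kernel touches only every dilation-th pixel of its extent. A scalar sketch of that predicate (illustrative; the CUDA kernel above is authoritative):

// Does column position h_col contribute to image row h_im? The same
// test applies independently to the width axis.
bool Contributes(int h_im, int h_col, int stride_h, int dilation_h,
                 int kernel_h) {
  int h_k = h_im - h_col * stride_h;   // offset inside the dilated window
  return h_k >= 0 &&
         h_k % dilation_h == 0 &&      // lands on an actual kernel tap
         h_k / dilation_h < kernel_h;  // one of the kernel's k taps
}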
@@ -267,9 +266,12 @@ template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, Dtype* data_im) {
- int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
- int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
+ const int stride_w, const int dilation_h, const int dilation_w,
+ Dtype* data_im) {
+ int height_col = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) /
+ stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) /
+ stride_w + 1;
int num_kernels = channels * height * width;
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
@@ -277,7 +279,7 @@ void col2im_gpu(const Dtype* data_col, const int channels,
col2im_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels),
CAFFE_CUDA_NUM_THREADS>>>(
num_kernels, data_col, height, width, channels, kernel_h, kernel_w,
- pad_h, pad_w, stride_h, stride_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
height_col, width_col, data_im);
CUDA_POST_KERNEL_CHECK;
}
@@ -286,11 +288,13 @@ void col2im_gpu(const Dtype* data_col, const int channels,
template void col2im_gpu<float>(const float* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, float* data_im);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ float* data_im);
template void col2im_gpu<double>(const double* data_col, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h,
- const int stride_w, double* data_im);
+ const int stride_w, const int dilation_h, const int dilation_w,
+ double* data_im);
template <typename Dtype, int num_axes>
__global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,