From 9d8206e0f906069e7c04f08dfddefa1357f3915c Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 4 Mar 2015 19:30:17 -0800 Subject: Im2col and Convolution layers support N spatial axes --- include/caffe/util/im2col.hpp | 24 ++ include/caffe/vision_layers.hpp | 108 ++++++-- src/caffe/layers/base_conv_layer.cpp | 241 +++++++++++----- src/caffe/layers/conv_layer.cpp | 32 ++- src/caffe/layers/conv_layer.cu | 16 +- src/caffe/layers/cudnn_conv_layer.cpp | 46 ++-- src/caffe/layers/cudnn_conv_layer.cu | 18 +- src/caffe/layers/deconv_layer.cpp | 32 ++- src/caffe/layers/deconv_layer.cu | 16 +- src/caffe/layers/im2col_layer.cpp | 171 +++++++++--- src/caffe/layers/im2col_layer.cu | 41 ++- src/caffe/proto/caffe.proto | 7 + src/caffe/test/test_convolution_layer.cpp | 409 +++++++++++++++++++++++----- src/caffe/test/test_deconvolution_layer.cpp | 159 ++++++++++- src/caffe/test/test_im2col_kernel.cu | 87 +++++- src/caffe/test/test_im2col_layer.cpp | 30 +- src/caffe/util/im2col.cpp | 116 ++++++++ src/caffe/util/im2col.cu | 306 ++++++++++++++++++++- src/caffe/util/upgrade_proto.cpp | 6 +- 19 files changed, 1554 insertions(+), 311 deletions(-) diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp index 0051e2fa..531fd29c 100644 --- a/include/caffe/util/im2col.hpp +++ b/include/caffe/util/im2col.hpp @@ -3,24 +3,48 @@ namespace caffe { +template +void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col); + template void im2col_cpu(const Dtype* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, Dtype* data_col); +template +void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im); + template void col2im_cpu(const Dtype* data_col, const int channels, const int height, const int width, const int patch_h, const int patch_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, Dtype* data_im); +template +void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes, + const int col_size, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col); + template void im2col_gpu(const Dtype* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, Dtype* data_col); +template +void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes, + const int im_size, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im); + template void col2im_gpu(const Dtype* data_col, const int channels, const int height, const int width, const int patch_h, const int patch_w, diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 211e3d90..eae65820 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -64,46 +64,101 @@ class BaseConvolutionLayer : public Layer { // Compute height_out_ and width_out_ from other parameters. virtual void compute_output_shape() = 0; - int kernel_h_, kernel_w_; - int stride_h_, stride_w_; + /// @brief The spatial dimensions of a filter kernel. + Blob kernel_shape_; + /// @brief The spatial dimensions of the stride. + Blob stride_; + /// @brief The spatial dimensions of the padding. + Blob pad_; + /// @brief The spatial dimensions of the convolution input. + Blob conv_input_shape_; + /// @brief The spatial dimensions of the input. + Blob input_shape_; + /// @brief The spatial dimensions of the col_buffer. + vector col_buffer_shape_; + /// @brief The spatial dimensions of the output. + vector output_shape_; + + int num_spatial_axes_; + int bottom_dim_; + int top_dim_; + + int channel_axis_; int num_; int channels_; - int pad_h_, pad_w_; - int height_, width_; int group_; + int out_spatial_dim_; + int weight_offset_; int num_output_; - int height_out_, width_out_; bool bias_term_; bool is_1x1_; + bool force_nd_im2col_; private: // wrap im2col/col2im so we don't have to remember the (long) argument lists inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) { - im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + im2col_cpu(data, conv_in_channels_, + conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff); + } else { + im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(), + col_buffer_shape_.data(), kernel_shape_.cpu_data(), + pad_.cpu_data(), stride_.cpu_data(), col_buff); + } } inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) { - col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + col2im_cpu(col_buff, conv_in_channels_, + conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], data); + } else { + col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(), + col_buffer_shape_.data(), kernel_shape_.cpu_data(), + pad_.cpu_data(), stride_.cpu_data(), data); + } } #ifndef CPU_ONLY inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) { - im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + im2col_gpu(data, conv_in_channels_, + conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], col_buff); + } else { + im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_, + conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(), + kernel_shape_.gpu_data(), pad_.gpu_data(), + stride_.gpu_data(), col_buff); + } } inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) { - col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + col2im_gpu(col_buff, conv_in_channels_, + conv_input_shape_.cpu_data()[1], conv_input_shape_.cpu_data()[2], + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], data); + } else { + col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_, + conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(), + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + data); + } } #endif + int num_kernels_im2col_; + int num_kernels_col2im_; int conv_out_channels_; int conv_in_channels_; int conv_out_spatial_dim_; - int conv_in_height_; - int conv_in_width_; int kernel_dim_; - int weight_offset_; int col_offset_; int output_offset_; @@ -250,7 +305,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { cudnnTensorDescriptor_t bias_desc_; cudnnFilterDescriptor_t filter_desc_; vector conv_descs_; - int bottom_offset_, top_offset_, weight_offset_, bias_offset_; + int bottom_offset_, top_offset_, bias_offset_; size_t workspaceSizeInBytes; void *workspace; }; @@ -287,11 +342,22 @@ class Im2colLayer : public Layer { virtual void Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); - int kernel_h_, kernel_w_; - int stride_h_, stride_w_; + /// @brief The spatial dimensions of a filter kernel. + Blob kernel_shape_; + /// @brief The spatial dimensions of the stride. + Blob stride_; + /// @brief The spatial dimensions of the padding. + Blob pad_; + + int num_spatial_axes_; + int bottom_dim_; + int top_dim_; + + int channel_axis_; + int num_; int channels_; - int height_, width_; - int pad_h_, pad_w_; + + bool force_nd_im2col_; }; // Forward declare PoolingLayer and SplitLayer for use in LRNLayer. diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index ccb3adc7..a5b90a54 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -1,3 +1,4 @@ +#include #include #include "caffe/filler.hpp" @@ -11,50 +12,103 @@ namespace caffe { template void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; // Configure the kernel size, padding, stride, and inputs. ConvolutionParameter conv_param = this->layer_param_.convolution_param(); - CHECK(!conv_param.has_kernel_size() != - !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; - CHECK(conv_param.has_kernel_size() || - (conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "For non-square filters both kernel_h and kernel_w are required."; - CHECK((!conv_param.has_pad() && conv_param.has_pad_h() - && conv_param.has_pad_w()) - || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) - << "pad is pad OR pad_h and pad_w are required."; - CHECK((!conv_param.has_stride() && conv_param.has_stride_h() - && conv_param.has_stride_w()) - || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) - << "Stride is stride OR stride_h and stride_w are required."; - if (conv_param.has_kernel_size()) { - kernel_h_ = kernel_w_ = conv_param.kernel_size(); + force_nd_im2col_ = conv_param.force_nd_im2col(); + channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis()); + const int first_spatial_axis = channel_axis_ + 1; + const int num_axes = bottom[0]->num_axes(); + num_spatial_axes_ = num_axes - first_spatial_axis; + CHECK_GE(num_spatial_axes_, 0); + // Setup input dimensions (input_shape_). + vector bottom_dim_blob_shape(1, num_spatial_axes_ + 1); + input_shape_.Reshape(bottom_dim_blob_shape); + int* input_shape_data = input_shape_.mutable_cpu_data(); + for (int i = 0; i < num_spatial_axes_ + 1; ++i) { + input_shape_data[i] = bottom[0]->shape(channel_axis_ + i); + } + vector spatial_dim_blob_shape(1, std::max(num_spatial_axes_, 1)); + // Setup filter kernel dimensions (kernel_shape_). + kernel_shape_.Reshape(spatial_dim_blob_shape); + int* kernel_shape_data = kernel_shape_.mutable_cpu_data(); + if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "kernel_h & kernel_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.kernel_size_size()) + << "Either kernel_size or kernel_h/w should be specified; not both."; + kernel_shape_data[0] = conv_param.kernel_h(); + kernel_shape_data[1] = conv_param.kernel_w(); } else { - kernel_h_ = conv_param.kernel_h(); - kernel_w_ = conv_param.kernel_w(); + const int num_kernel_dims = conv_param.kernel_size_size(); + CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_) + << "kernel_size must be specified once, or once per spatial dimension " + << "(kernel_size specified " << num_kernel_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + for (int i = 0; i < num_spatial_axes_; ++i) { + kernel_shape_data[i] = + conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i); + } + } + for (int i = 0; i < num_spatial_axes_; ++i) { + CHECK_GT(kernel_shape_data[i], 0) << "Filter dimensions must be nonzero."; } - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - if (!conv_param.has_pad_h()) { - pad_h_ = pad_w_ = conv_param.pad(); + // Setup stride dimensions (stride_). + stride_.Reshape(spatial_dim_blob_shape); + int* stride_data = stride_.mutable_cpu_data(); + if (conv_param.has_stride_h() || conv_param.has_stride_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "stride_h & stride_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.stride_size()) + << "Either stride or stride_h/w should be specified; not both."; + stride_data[0] = conv_param.stride_h(); + stride_data[1] = conv_param.stride_w(); } else { - pad_h_ = conv_param.pad_h(); - pad_w_ = conv_param.pad_w(); + const int num_stride_dims = conv_param.stride_size(); + CHECK(num_stride_dims == 0 || num_stride_dims == 1 || + num_stride_dims == num_spatial_axes_) + << "stride must be specified once, or once per spatial dimension " + << "(stride specified " << num_stride_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultStride = 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + stride_data[i] = (num_stride_dims == 0) ? kDefaultStride : + conv_param.stride((num_stride_dims == 1) ? 0 : i); + CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero."; + } } - if (!conv_param.has_stride_h()) { - stride_h_ = stride_w_ = conv_param.stride(); + // Setup pad dimensions (pad_). + pad_.Reshape(spatial_dim_blob_shape); + int* pad_data = pad_.mutable_cpu_data(); + if (conv_param.has_pad_h() || conv_param.has_pad_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "pad_h & pad_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.pad_size()) + << "Either pad or pad_h/w should be specified; not both."; + pad_data[0] = conv_param.pad_h(); + pad_data[1] = conv_param.pad_w(); } else { - stride_h_ = conv_param.stride_h(); - stride_w_ = conv_param.stride_w(); + const int num_pad_dims = conv_param.pad_size(); + CHECK(num_pad_dims == 0 || num_pad_dims == 1 || + num_pad_dims == num_spatial_axes_) + << "pad must be specified once, or once per spatial dimension " + << "(pad specified " << num_pad_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultPad = 0; + for (int i = 0; i < num_spatial_axes_; ++i) { + pad_data[i] = (num_pad_dims == 0) ? kDefaultPad : + conv_param.pad((num_pad_dims == 1) ? 0 : i); + } } // Special case: im2col is the identity for 1x1 convolution with stride 1 // and no padding, so flag for skipping the buffer and transformation. - is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1 - && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0; + is_1x1_ = true; + for (int i = 0; i < num_spatial_axes_; ++i) { + is_1x1_ &= + kernel_shape_data[i] == 1 && stride_data[i] == 1 && pad_data[i] == 0; + if (!is_1x1_) { break; } + } // Configure output channels and groups. - channels_ = bottom[0]->channels(); + channels_ = bottom[0]->shape(channel_axis_); num_output_ = this->layer_param_.convolution_param().num_output(); CHECK_GT(num_output_, 0); group_ = this->layer_param_.convolution_param().group(); @@ -71,8 +125,29 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, // Handle the parameters: weights and biases. // - blobs_[0] holds the filter weights // - blobs_[1] holds the biases (optional) + vector weight_shape(2); + weight_shape[0] = conv_out_channels_; + weight_shape[1] = conv_in_channels_ / group_; + for (int i = 0; i < num_spatial_axes_; ++i) { + weight_shape.push_back(kernel_shape_data[i]); + } bias_term_ = this->layer_param_.convolution_param().bias_term(); + vector bias_shape(bias_term_, num_output_); if (this->blobs_.size() > 0) { + CHECK_EQ(1 + bias_term_, this->blobs_.size()) + << "Incorrect number of weight blobs."; + if (weight_shape != this->blobs_[0]->shape()) { + Blob weight_shaped_blob(weight_shape); + LOG(FATAL) << "Incorrect weight shape: expected shape " + << weight_shaped_blob.shape_string() << "; instead, shape was " + << this->blobs_[0]->shape_string(); + } + if (bias_term_ && bias_shape != this->blobs_[1]->shape()) { + Blob bias_shaped_blob(bias_shape); + LOG(FATAL) << "Incorrect bias shape: expected shape " + << bias_shaped_blob.shape_string() << "; instead, shape was " + << this->blobs_[1]->shape_string(); + } LOG(INFO) << "Skipping parameter initialization"; } else { if (bias_term_) { @@ -82,20 +157,20 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, } // Initialize and fill the weights: // output channels x input channels per-group x kernel height x kernel width - this->blobs_[0].reset(new Blob( - conv_out_channels_, conv_in_channels_ / group_, kernel_h_, kernel_w_)); + this->blobs_[0].reset(new Blob(weight_shape)); shared_ptr > weight_filler(GetFiller( this->layer_param_.convolution_param().weight_filler())); weight_filler->Fill(this->blobs_[0].get()); // If necessary, initialize and fill the biases. if (bias_term_) { - vector bias_shape(1, num_output_); this->blobs_[1].reset(new Blob(bias_shape)); shared_ptr > bias_filler(GetFiller( this->layer_param_.convolution_param().bias_filler())); bias_filler->Fill(this->blobs_[1].get()); } } + kernel_dim_ = this->blobs_[0]->count(1); + weight_offset_ = conv_out_channels_ * kernel_dim_ / group_; // Propagate gradients to the parameters (as directed by backward pass). this->param_propagate_down_.resize(this->blobs_.size(), true); } @@ -103,52 +178,68 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, template void BaseConvolutionLayer::Reshape(const vector*>& bottom, const vector*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; - num_ = bottom[0]->num(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with" - " convolution kernel."; + const int first_spatial_axis = channel_axis_ + 1; + CHECK_EQ(bottom[0]->num_axes(), first_spatial_axis + num_spatial_axes_) + << "bottom num_axes may not change."; + num_ = bottom[0]->count(0, channel_axis_); + CHECK_EQ(bottom[0]->shape(channel_axis_), channels_) + << "Input size incompatible with convolution kernel."; // TODO: generalize to handle inputs of different shapes. for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { - CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; - CHECK_EQ(channels_, bottom[bottom_id]->channels()) - << "Inputs must have same channels."; - CHECK_EQ(height_, bottom[bottom_id]->height()) - << "Inputs must have same height."; - CHECK_EQ(width_, bottom[bottom_id]->width()) - << "Inputs must have same width."; + CHECK(bottom[0]->shape() == bottom[bottom_id]->shape()) + << "All inputs must have the same shape."; } // Shape the tops. compute_output_shape(); + vector top_shape(bottom[0]->shape().begin(), + bottom[0]->shape().begin() + channel_axis_); + top_shape.push_back(num_output_); + for (int i = 0; i < num_spatial_axes_; ++i) { + top_shape.push_back(output_shape_[i]); + } for (int top_id = 0; top_id < top.size(); ++top_id) { - top[top_id]->Reshape(num_, num_output_, height_out_, width_out_); + top[top_id]->Reshape(top_shape); } if (reverse_dimensions()) { - conv_in_height_ = height_out_; - conv_in_width_ = width_out_; - conv_out_spatial_dim_ = height_ * width_; + conv_out_spatial_dim_ = bottom[0]->count(first_spatial_axis); } else { - conv_in_height_ = height_; - conv_in_width_ = width_; - conv_out_spatial_dim_ = height_out_ * width_out_; + conv_out_spatial_dim_ = top[0]->count(first_spatial_axis); } - kernel_dim_ = conv_in_channels_ * kernel_h_ * kernel_w_; - weight_offset_ = conv_out_channels_ * kernel_dim_ / group_ / group_; - col_offset_ = kernel_dim_ * conv_out_spatial_dim_ / group_; + col_offset_ = kernel_dim_ * conv_out_spatial_dim_; output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_; + // Setup input dimensions (conv_input_shape_). + vector bottom_dim_blob_shape(1, num_spatial_axes_ + 1); + conv_input_shape_.Reshape(bottom_dim_blob_shape); + int* conv_input_shape_data = conv_input_shape_.mutable_cpu_data(); + for (int i = 0; i < num_spatial_axes_ + 1; ++i) { + if (reverse_dimensions()) { + conv_input_shape_data[i] = top[0]->shape(channel_axis_ + i); + } else { + conv_input_shape_data[i] = bottom[0]->shape(channel_axis_ + i); + } + } // The im2col result buffer will only hold one image at a time to avoid // overly large memory usage. In the special case of 1x1 convolution // it goes lazily unused to save memory. - if (reverse_dimensions()) { - col_buffer_.Reshape(1, kernel_dim_, height_, width_); - } else { - col_buffer_.Reshape(1, kernel_dim_, height_out_, width_out_); + col_buffer_shape_.clear(); + col_buffer_shape_.push_back(kernel_dim_ * group_); + const int* input_shape_data = input_shape_.cpu_data() + 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + if (reverse_dimensions()) { + col_buffer_shape_.push_back(input_shape_data[i]); + } else { + col_buffer_shape_.push_back(output_shape_[i]); + } } + col_buffer_.Reshape(col_buffer_shape_); + bottom_dim_ = bottom[0]->count(channel_axis_); + top_dim_ = top[0]->count(channel_axis_); + num_kernels_im2col_ = conv_in_channels_ * conv_out_spatial_dim_; + num_kernels_col2im_ = reverse_dimensions() ? top_dim_ : bottom_dim_; // Set up the all ones "bias multiplier" for adding biases by BLAS + out_spatial_dim_ = top[0]->count(first_spatial_axis); if (bias_term_) { - vector bias_multiplier_shape(1, height_out_ * width_out_); + vector bias_multiplier_shape(1, out_spatial_dim_); bias_multiplier_.Reshape(bias_multiplier_shape); caffe_set(bias_multiplier_.count(), Dtype(1), bias_multiplier_.mutable_cpu_data()); @@ -167,7 +258,7 @@ void BaseConvolutionLayer::forward_cpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_ / - group_, conv_out_spatial_dim_, kernel_dim_ / group_, + group_, conv_out_spatial_dim_, kernel_dim_, (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g, (Dtype)0., output + output_offset_ * g); } @@ -177,7 +268,7 @@ template void BaseConvolutionLayer::forward_cpu_bias(Dtype* output, const Dtype* bias) { caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, - height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), + out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), (Dtype)1., output); } @@ -189,7 +280,7 @@ void BaseConvolutionLayer::backward_cpu_gemm(const Dtype* output, col_buff = input; } for (int g = 0; g < group_; ++g) { - caffe_cpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + caffe_cpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_, conv_out_spatial_dim_, conv_out_channels_ / group_, (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, (Dtype)0., col_buff + col_offset_ * g); @@ -209,7 +300,7 @@ void BaseConvolutionLayer::weight_cpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_cpu_gemm(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, - kernel_dim_ / group_, conv_out_spatial_dim_, + kernel_dim_, conv_out_spatial_dim_, (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, (Dtype)1., weights + weight_offset_ * g); } @@ -218,7 +309,7 @@ void BaseConvolutionLayer::weight_cpu_gemm(const Dtype* input, template void BaseConvolutionLayer::backward_cpu_bias(Dtype* bias, const Dtype* input) { - caffe_cpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + caffe_cpu_gemv(CblasNoTrans, num_output_, out_spatial_dim_, 1., input, bias_multiplier_.cpu_data(), 1., bias); } @@ -236,7 +327,7 @@ void BaseConvolutionLayer::forward_gpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_ / - group_, conv_out_spatial_dim_, kernel_dim_ / group_, + group_, conv_out_spatial_dim_, kernel_dim_, (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g, (Dtype)0., output + output_offset_ * g); } @@ -246,7 +337,7 @@ template void BaseConvolutionLayer::forward_gpu_bias(Dtype* output, const Dtype* bias) { caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, - height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), + out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), (Dtype)1., output); } @@ -258,7 +349,7 @@ void BaseConvolutionLayer::backward_gpu_gemm(const Dtype* output, col_buff = input; } for (int g = 0; g < group_; ++g) { - caffe_gpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + caffe_gpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_, conv_out_spatial_dim_, conv_out_channels_ / group_, (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, (Dtype)0., col_buff + col_offset_ * g); @@ -278,7 +369,7 @@ void BaseConvolutionLayer::weight_gpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_gpu_gemm(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, - kernel_dim_ / group_, conv_out_spatial_dim_, + kernel_dim_, conv_out_spatial_dim_, (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, (Dtype)1., weights + weight_offset_ * g); } @@ -287,7 +378,7 @@ void BaseConvolutionLayer::weight_gpu_gemm(const Dtype* input, template void BaseConvolutionLayer::backward_gpu_bias(Dtype* bias, const Dtype* input) { - caffe_gpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + caffe_gpu_gemv(CblasNoTrans, num_output_, out_spatial_dim_, 1., input, bias_multiplier_.gpu_data(), 1., bias); } diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index 928ef5ee..5cf26970 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -10,10 +10,18 @@ namespace caffe { template void ConvolutionLayer::compute_output_shape() { - this->height_out_ = (this->height_ + 2 * this->pad_h_ - this->kernel_h_) - / this->stride_h_ + 1; - this->width_out_ = (this->width_ + 2 * this->pad_w_ - this->kernel_w_) - / this->stride_w_ + 1; + // input_shape_ + 1 to skip channel axis + const int* input_shape_data = this->input_shape_.cpu_data() + 1; + const int* kernel_shape_data = this->kernel_shape_.cpu_data(); + const int* stride_data = this->stride_.cpu_data(); + const int* pad_data = this->pad_.cpu_data(); + this->output_shape_.clear(); + for (int i = 0; i < this->num_spatial_axes_; ++i) { + const int input_dim = input_shape_data[i]; + const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i]) + / stride_data[i] + 1; + this->output_shape_.push_back(output_dim); + } } template @@ -24,11 +32,11 @@ void ConvolutionLayer::Forward_cpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = top[i]->mutable_cpu_data(); for (int n = 0; n < this->num_; ++n) { - this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->forward_cpu_gemm(bottom_data + n * this->bottom_dim_, weight, + top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->cpu_data(); - this->forward_cpu_bias(top_data + top[i]->offset(n), bias); + this->forward_cpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -47,20 +55,20 @@ void ConvolutionLayer::Backward_cpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_cpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - this->weight_cpu_gemm(bottom_data + bottom[i]->offset(n), - top_diff + top[i]->offset(n), weight_diff); + this->weight_cpu_gemm(bottom_data + n * this->bottom_dim_, + top_diff + n * this->top_dim_, weight_diff); } // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { - this->backward_cpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n)); + this->backward_cpu_gemm(top_diff + n * this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_); } } } diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu index b8a98ff7..b429d2b4 100644 --- a/src/caffe/layers/conv_layer.cu +++ b/src/caffe/layers/conv_layer.cu @@ -16,11 +16,11 @@ void ConvolutionLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->gpu_data(); Dtype* top_data = top[i]->mutable_gpu_data(); for (int n = 0; n < this->num_; ++n) { - this->forward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->forward_gpu_gemm(bottom_data + n * this->bottom_dim_, weight, + top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->gpu_data(); - this->forward_gpu_bias(top_data + top[i]->offset(n), bias); + this->forward_gpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -37,7 +37,7 @@ void ConvolutionLayer::Backward_gpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_gpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { @@ -46,13 +46,13 @@ void ConvolutionLayer::Backward_gpu(const vector*>& top, for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - this->weight_gpu_gemm(bottom_data + bottom[i]->offset(n), - top_diff + top[i]->offset(n), weight_diff); + this->weight_gpu_gemm(bottom_data + n * this->bottom_dim_, + top_diff + n * this->top_dim_, weight_diff); } // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { - this->backward_gpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n)); + this->backward_gpu_gemm(top_diff + n * this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_); } } } diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 104d2b9d..3514fe2a 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -34,14 +34,15 @@ void CuDNNConvolutionLayer::LayerSetUp( } // Set the indexing parameters. - weight_offset_ = (this->num_output_ / this->group_) - * (this->channels_ / this->group_) * this->kernel_h_ * this->kernel_w_; bias_offset_ = (this->num_output_ / this->group_); // Create filter descriptor. + const int* kernel_shape_data = this->kernel_shape_.cpu_data(); + const int kernel_h = kernel_shape_data[0]; + const int kernel_w = kernel_shape_data[1]; cudnn::createFilterDesc(&filter_desc_, this->num_output_ / this->group_, this->channels_ / this->group_, - this->kernel_h_, this->kernel_w_); + kernel_h, kernel_w); // Create tensor descriptor(s) for data and corresponding convolution(s). for (int i = 0; i < bottom.size(); i++) { @@ -68,29 +69,36 @@ template void CuDNNConvolutionLayer::Reshape( const vector*>& bottom, const vector*>& top) { ConvolutionLayer::Reshape(bottom, top); - bottom_offset_ = (this->channels_ / this->group_) - * this->height_ * this->width_; - top_offset_ = (this->num_output_ / this->group_) - * this->height_out_ * this->width_out_; + CHECK_EQ(2, this->num_spatial_axes_) + << "CuDNNConvolution input must have 2 spatial axes " + << "(e.g., height and width). " + << "Use 'engine: CAFFE' for general ND convolution."; + bottom_offset_ = this->bottom_dim_ / this->group_; + top_offset_ = this->top_dim_ / this->group_; + const int height = bottom[0]->shape(this->channel_axis_ + 1); + const int width = bottom[0]->shape(this->channel_axis_ + 2); + const int height_out = top[0]->shape(this->channel_axis_ + 1); + const int width_out = top[0]->shape(this->channel_axis_ + 2); + const int* pad_data = this->pad_.cpu_data(); + const int pad_h = pad_data[0]; + const int pad_w = pad_data[1]; + const int* stride_data = this->stride_.cpu_data(); + const int stride_h = stride_data[0]; + const int stride_w = stride_data[1]; for (int i = 0; i < bottom.size(); i++) { cudnn::setTensor4dDesc(&bottom_descs_[i], this->num_, - this->channels_ / this->group_, - this->height_, this->width_, - this->channels_ * this->height_ * this->width_, - this->height_ * this->width_, - this->width_, 1); + this->channels_ / this->group_, height, width, + this->channels_ * height * width, + height * width, width, 1); cudnn::setTensor4dDesc(&top_descs_[i], this->num_, - this->num_output_ / this->group_, - this->height_out_, this->width_out_, - this->num_output_ * this->height_out_ * this->width_out_, - this->height_out_ * this->width_out_, - this->width_out_, 1); + this->num_output_ / this->group_, height_out, width_out, + this->num_output_ * this->out_spatial_dim_, + this->out_spatial_dim_, width_out, 1); cudnn::setConvolutionDesc(&conv_descs_[i], bottom_descs_[i], - filter_desc_, this->pad_h_, this->pad_w_, - this->stride_h_, this->stride_w_); + filter_desc_, pad_h, pad_w, stride_h, stride_w); } // Tensor descriptor for bias. diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu index b4e802e1..69115202 100644 --- a/src/caffe/layers/cudnn_conv_layer.cu +++ b/src/caffe/layers/cudnn_conv_layer.cu @@ -14,15 +14,15 @@ __global__ void sync_conv_groups() { } template void CuDNNConvolutionLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { + const int* kernel_shape_data = this->kernel_shape_.cpu_data(); + const int kernel_h = kernel_shape_data[0]; + const int kernel_w = kernel_shape_data[1]; + const size_t workspace_limit_bytes = + kernel_h * kernel_w * this->channels_ * sizeof(int) + 1; + const Dtype* weight = this->blobs_[0]->gpu_data(); for (int i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->gpu_data(); Dtype* top_data = top[i]->mutable_gpu_data(); - const Dtype* weight = this->blobs_[0]->gpu_data(); - - size_t workspace_limit_bytes = this->kernel_h_ * - this->kernel_w_ * - this->channels_ * - sizeof(int) + 1; // Forward through cuDNN in parallel over groups. for (int g = 0; g < this->group_; g++) { @@ -69,7 +69,7 @@ void CuDNNConvolutionLayer::Forward_gpu( CUDNN_CHECK(cudnnConvolutionForward(handle_[g], cudnn::dataType::one, bottom_descs_[i], bottom_data + bottom_offset_ * g, - filter_desc_, weight + weight_offset_ * g, + filter_desc_, weight + this->weight_offset_ * g, conv_descs_[i], algo, workspace, workspaceSizeInBytes, cudnn::dataType::zero, @@ -128,7 +128,7 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, top_descs_[i], top_diff + top_offset_ * g, conv_descs_[i], cudnn::dataType::one, - filter_desc_, weight_diff + weight_offset_ * g)); + filter_desc_, weight_diff + this->weight_offset_ * g)); } // Gradient w.r.t. bottom data. @@ -139,7 +139,7 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g], cudnn::dataType::one, - filter_desc_, weight + weight_offset_ * g, + filter_desc_, weight + this->weight_offset_ * g, top_descs_[i], top_diff + top_offset_ * g, conv_descs_[i], cudnn::dataType::zero, diff --git a/src/caffe/layers/deconv_layer.cpp b/src/caffe/layers/deconv_layer.cpp index a4612963..f1d1abf2 100644 --- a/src/caffe/layers/deconv_layer.cpp +++ b/src/caffe/layers/deconv_layer.cpp @@ -10,10 +10,18 @@ namespace caffe { template void DeconvolutionLayer::compute_output_shape() { - this->height_out_ = this->stride_h_ * (this->height_ - 1) + this->kernel_h_ - - 2 * this->pad_h_; - this->width_out_ = this->stride_w_ * (this->width_ - 1) + this->kernel_w_ - - 2 * this->pad_w_; + // input_shape_ + 1 to skip channel axis + const int* input_shape_data = this->input_shape_.cpu_data() + 1; + const int* kernel_shape_data = this->kernel_shape_.cpu_data(); + const int* stride_data = this->stride_.cpu_data(); + const int* pad_data = this->pad_.cpu_data(); + this->output_shape_.clear(); + for (int i = 0; i < this->num_spatial_axes_; ++i) { + const int input_dim = input_shape_data[i]; + const int output_dim = stride_data[i] * (input_dim - 1) + + kernel_shape_data[i] - 2 * pad_data[i]; + this->output_shape_.push_back(output_dim); + } } template @@ -24,11 +32,11 @@ void DeconvolutionLayer::Forward_cpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->cpu_data(); Dtype* top_data = top[i]->mutable_cpu_data(); for (int n = 0; n < this->num_; ++n) { - this->backward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->backward_cpu_gemm(bottom_data + n * this->bottom_dim_, weight, + top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->cpu_data(); - this->forward_cpu_bias(top_data + top[i]->offset(n), bias); + this->forward_cpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -47,21 +55,21 @@ void DeconvolutionLayer::Backward_cpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_cpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { for (int n = 0; n < this->num_; ++n) { // Gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - this->weight_cpu_gemm(top_diff + top[i]->offset(n), - bottom_data + bottom[i]->offset(n), weight_diff); + this->weight_cpu_gemm(top_diff + n * this->top_dim_, + bottom_data + n * this->bottom_dim_, weight_diff); } // Gradient w.r.t. bottom data, if necessary, reusing the column buffer // we might have just computed above. if (propagate_down[i]) { - this->forward_cpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n), + this->forward_cpu_gemm(top_diff + n * this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_, this->param_propagate_down_[0]); } } diff --git a/src/caffe/layers/deconv_layer.cu b/src/caffe/layers/deconv_layer.cu index 8a1eed8a..ea83f56f 100644 --- a/src/caffe/layers/deconv_layer.cu +++ b/src/caffe/layers/deconv_layer.cu @@ -16,11 +16,11 @@ void DeconvolutionLayer::Forward_gpu(const vector*>& bottom, const Dtype* bottom_data = bottom[i]->gpu_data(); Dtype* top_data = top[i]->mutable_gpu_data(); for (int n = 0; n < this->num_; ++n) { - this->backward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, - top_data + top[i]->offset(n)); + this->backward_gpu_gemm(bottom_data + n * this->bottom_dim_, weight, + top_data + n * this->top_dim_); if (this->bias_term_) { const Dtype* bias = this->blobs_[1]->gpu_data(); - this->forward_gpu_bias(top_data + top[i]->offset(n), bias); + this->forward_gpu_bias(top_data + n * this->top_dim_, bias); } } } @@ -39,20 +39,20 @@ void DeconvolutionLayer::Backward_gpu(const vector*>& top, if (this->bias_term_ && this->param_propagate_down_[1]) { Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); for (int n = 0; n < this->num_; ++n) { - this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); + this->backward_gpu_bias(bias_diff, top_diff + n * this->top_dim_); } } if (this->param_propagate_down_[0] || propagate_down[i]) { for (int n = 0; n < this->num_; ++n) { // gradient w.r.t. weight. Note that we will accumulate diffs. if (this->param_propagate_down_[0]) { - this->weight_gpu_gemm(top_diff + top[i]->offset(n), - bottom_data + bottom[i]->offset(n), weight_diff); + this->weight_gpu_gemm(top_diff + n * this->top_dim_, + bottom_data + n * this->bottom_dim_, weight_diff); } // gradient w.r.t. bottom data, if necessary. if (propagate_down[i]) { - this->forward_gpu_gemm(top_diff + top[i]->offset(n), weight, - bottom_diff + bottom[i]->offset(n), + this->forward_gpu_gemm(top_diff + this->top_dim_, weight, + bottom_diff + n * this->bottom_dim_, this->param_propagate_down_[0]); } } diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp index 1c802714..595c9dbb 100644 --- a/src/caffe/layers/im2col_layer.cpp +++ b/src/caffe/layers/im2col_layer.cpp @@ -11,54 +11,106 @@ template void Im2colLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { ConvolutionParameter conv_param = this->layer_param_.convolution_param(); - CHECK(!conv_param.has_kernel_size() != - !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; - CHECK(conv_param.has_kernel_size() || - (conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "For non-square filters both kernel_h and kernel_w are required."; - CHECK((!conv_param.has_pad() && conv_param.has_pad_h() - && conv_param.has_pad_w()) - || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) - << "pad is pad OR pad_h and pad_w are required."; - CHECK((!conv_param.has_stride() && conv_param.has_stride_h() - && conv_param.has_stride_w()) - || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) - << "Stride is stride OR stride_h and stride_w are required."; - if (conv_param.has_kernel_size()) { - kernel_h_ = kernel_w_ = conv_param.kernel_size(); + force_nd_im2col_ = conv_param.force_nd_im2col(); + const int input_num_dims = bottom[0]->shape().size(); + channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis()); + const int first_spatial_dim = channel_axis_ + 1; + num_spatial_axes_ = input_num_dims - first_spatial_dim; + CHECK_GE(num_spatial_axes_, 1); + vector dim_blob_shape(1, num_spatial_axes_); + // Setup filter kernel dimensions (kernel_shape_). + kernel_shape_.Reshape(dim_blob_shape); + int* kernel_shape_data = kernel_shape_.mutable_cpu_data(); + if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "kernel_h & kernel_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.kernel_size_size()) + << "Either kernel_size or kernel_h/w should be specified; not both."; + kernel_shape_data[0] = conv_param.kernel_h(); + kernel_shape_data[1] = conv_param.kernel_w(); } else { - kernel_h_ = conv_param.kernel_h(); - kernel_w_ = conv_param.kernel_w(); + const int num_kernel_dims = conv_param.kernel_size_size(); + CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_) + << "kernel_size must be specified once, or once per spatial dimension " + << "(kernel_size specified " << num_kernel_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + for (int i = 0; i < num_spatial_axes_; ++i) { + kernel_shape_data[i] = + conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i); + } } - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - if (!conv_param.has_pad_h()) { - pad_h_ = pad_w_ = conv_param.pad(); + for (int i = 0; i < num_spatial_axes_; ++i) { + CHECK_GT(kernel_shape_data[i], 0) << "Filter dimensions must be nonzero."; + } + // Setup stride dimensions (stride_). + stride_.Reshape(dim_blob_shape); + int* stride_data = stride_.mutable_cpu_data(); + if (conv_param.has_stride_h() || conv_param.has_stride_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "stride_h & stride_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.stride_size()) + << "Either stride or stride_h/w should be specified; not both."; + stride_data[0] = conv_param.stride_h(); + stride_data[1] = conv_param.stride_w(); } else { - pad_h_ = conv_param.pad_h(); - pad_w_ = conv_param.pad_w(); + const int num_stride_dims = conv_param.stride_size(); + CHECK(num_stride_dims == 0 || num_stride_dims == 1 || + num_stride_dims == num_spatial_axes_) + << "stride must be specified once, or once per spatial dimension " + << "(stride specified " << num_stride_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultStride = 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + stride_data[i] = (num_stride_dims == 0) ? kDefaultStride : + conv_param.stride((num_stride_dims == 1) ? 0 : i); + CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero."; + } } - if (!conv_param.has_stride_h()) { - stride_h_ = stride_w_ = conv_param.stride(); + // Setup pad dimensions (pad_). + pad_.Reshape(dim_blob_shape); + int* pad_data = pad_.mutable_cpu_data(); + if (conv_param.has_pad_h() || conv_param.has_pad_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "pad_h & pad_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.pad_size()) + << "Either pad or pad_h/w should be specified; not both."; + pad_data[0] = conv_param.pad_h(); + pad_data[1] = conv_param.pad_w(); } else { - stride_h_ = conv_param.stride_h(); - stride_w_ = conv_param.stride_w(); + const int num_pad_dims = conv_param.pad_size(); + CHECK(num_pad_dims == 0 || num_pad_dims == 1 || + num_pad_dims == num_spatial_axes_) + << "pad must be specified once, or once per spatial dimension " + << "(pad specified " << num_pad_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultPad = 0; + for (int i = 0; i < num_spatial_axes_; ++i) { + pad_data[i] = (num_pad_dims == 0) ? kDefaultPad : + conv_param.pad((num_pad_dims == 1) ? 0 : i); + } } } template void Im2colLayer::Reshape(const vector*>& bottom, const vector*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; - channels_ = bottom[0]->channels(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - top[0]->Reshape( - bottom[0]->num(), channels_ * kernel_h_ * kernel_w_, - (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1, - (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1); + vector top_shape = bottom[0]->shape(); + const int* kernel_shape_data = kernel_shape_.cpu_data(); + const int* stride_data = stride_.cpu_data(); + const int* pad_data = pad_.cpu_data(); + for (int i = 0; i < num_spatial_axes_; ++i) { + top_shape[channel_axis_] *= kernel_shape_data[i]; + const int input_dim = bottom[0]->shape(channel_axis_ + i + 1); + const int output_dim = (input_dim + 2 * pad_data[i] - kernel_shape_data[i]) + / stride_data[i] + 1; + top_shape[channel_axis_ + i + 1] = output_dim; + } + top[0]->Reshape(top_shape); + num_ = bottom[0]->count(0, channel_axis_); + bottom_dim_ = bottom[0]->count(channel_axis_); + top_dim_ = top[0]->count(channel_axis_); + + channels_ = bottom[0]->shape(channel_axis_); } template @@ -66,10 +118,27 @@ void Im2colLayer::Forward_cpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->cpu_data(); Dtype* top_data = top[0]->mutable_cpu_data(); - for (int n = 0; n < bottom[0]->num(); ++n) { - im2col_cpu(bottom_data + bottom[0]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, top_data + top[0]->offset(n)); + for (int n = 0; n < num_; ++n) { + DCHECK_EQ(bottom[0]->shape().size() - channel_axis_, num_spatial_axes_ + 1); + DCHECK_EQ(top[0]->shape().size() - channel_axis_, num_spatial_axes_ + 1); + DCHECK_EQ(kernel_shape_.count(), num_spatial_axes_); + DCHECK_EQ(pad_.count(), num_spatial_axes_); + DCHECK_EQ(stride_.count(), num_spatial_axes_); + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + im2col_cpu(bottom_data + n * bottom_dim_, channels_, + bottom[0]->shape(channel_axis_ + 1), + bottom[0]->shape(channel_axis_ + 2), + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], + top_data + n * top_dim_); + } else { + im2col_nd_cpu(bottom_data + n * bottom_dim_, num_spatial_axes_, + bottom[0]->shape().data() + channel_axis_, + top[0]->shape().data() + channel_axis_, + kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(), + top_data + n * top_dim_); + } } } @@ -78,10 +147,22 @@ void Im2colLayer::Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); - for (int n = 0; n < top[0]->num(); ++n) { - col2im_cpu(top_diff + top[0]->offset(n), channels_, height_, width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n)); + for (int n = 0; n < num_; ++n) { + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + col2im_cpu(top_diff + n * top_dim_, channels_, + bottom[0]->shape(channel_axis_ + 1), + bottom[0]->shape(channel_axis_ + 2), + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], + bottom_diff + n * bottom_dim_); + } else { + col2im_nd_cpu(top_diff + n * top_dim_, num_spatial_axes_, + bottom[0]->shape().data() + channel_axis_, + top[0]->shape().data() + channel_axis_, + kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(), + bottom_diff + n * bottom_dim_); + } } } diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu index 9c338b14..cd507623 100644 --- a/src/caffe/layers/im2col_layer.cu +++ b/src/caffe/layers/im2col_layer.cu @@ -12,10 +12,23 @@ void Im2colLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); - for (int n = 0; n < bottom[0]->num(); ++n) { - im2col_gpu(bottom_data + bottom[0]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, top_data + top[0]->offset(n)); + const int num_kernels = channels_ * top[0]->count(channel_axis_ + 1); + for (int n = 0; n < num_; ++n) { + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + im2col_gpu(bottom_data + n * bottom_dim_, channels_, + bottom[0]->shape(channel_axis_ + 1), + bottom[0]->shape(channel_axis_ + 2), + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], + top_data + n * top_dim_); + } else { + im2col_nd_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_, + num_kernels, bottom[0]->gpu_shape() + channel_axis_, + top[0]->gpu_shape() + channel_axis_, + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + top_data + n * top_dim_); + } } } @@ -24,10 +37,22 @@ void Im2colLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); - for (int n = 0; n < top[0]->num(); ++n) { - col2im_gpu(top_diff + top[0]->offset(n), channels_, height_, width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n)); + for (int n = 0; n < num_; ++n) { + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + col2im_gpu(top_diff + n * top_dim_, channels_, + bottom[0]->shape(channel_axis_ + 1), + bottom[0]->shape(channel_axis_ + 2), + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], + bottom_diff + n * bottom_dim_); + } else { + col2im_nd_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_, + bottom[0]->gpu_shape() + channel_axis_, + top[0]->gpu_shape() + channel_axis_, + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + bottom_diff + n * bottom_dim_); + } } } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 86683eb4..f52c941b 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -508,6 +508,13 @@ message ConvolutionParameter { // N independent 3D convolutions, sliding (C/g)-channels // filters across the spatial axes (D, H, W) of the input. optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; } message DataParameter { diff --git a/src/caffe/test/test_convolution_layer.cpp b/src/caffe/test/test_convolution_layer.cpp index 67d41fff..9df979a2 100644 --- a/src/caffe/test/test_convolution_layer.cpp +++ b/src/caffe/test/test_convolution_layer.cpp @@ -19,54 +19,87 @@ template void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, const vector > >& weights, Blob* out) { + const bool has_depth = (out->num_axes() == 5); + if (!has_depth) { CHECK_EQ(4, out->num_axes()); } // Kernel size, stride, and pad int kernel_h, kernel_w; - if (conv_param->has_kernel_size()) { - kernel_h = kernel_w = conv_param->kernel_size(); - } else { + if (conv_param->has_kernel_h() || conv_param->has_kernel_w()) { kernel_h = conv_param->kernel_h(); kernel_w = conv_param->kernel_w(); + } else { + kernel_h = kernel_w = conv_param->kernel_size(0); } int pad_h, pad_w; - if (!conv_param->has_pad_h()) { - pad_h = pad_w = conv_param->pad(); - } else { + if (conv_param->has_pad_h() || conv_param->has_pad_w()) { pad_h = conv_param->pad_h(); pad_w = conv_param->pad_w(); + } else { + pad_h = pad_w = conv_param->pad_size() ? conv_param->pad(0) : 0; } int stride_h, stride_w; - if (!conv_param->has_stride_h()) { - stride_h = stride_w = conv_param->stride(); - } else { + if (conv_param->has_stride_h() || conv_param->has_stride_w()) { stride_h = conv_param->stride_h(); stride_w = conv_param->stride_w(); + } else { + stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1; + } + int kernel_d, pad_d, stride_d; + if (has_depth) { + kernel_d = kernel_h; + stride_d = stride_h; + pad_d = pad_h; + } else { + kernel_d = stride_d = 1; + pad_d = 0; } // Groups int groups = conv_param->group(); - int o_g = out->channels() / groups; - int k_g = in->channels() / groups; + int o_g = out->shape(1) / groups; + int k_g = in->shape(1) / groups; int o_head, k_head; // Convolution - const Dtype* in_data = in->cpu_data(); - const Dtype* weight_data = weights[0]->cpu_data(); + vector weight_offset(4 + has_depth); + vector in_offset(4 + has_depth); + vector out_offset(4 + has_depth); Dtype* out_data = out->mutable_cpu_data(); - for (int n = 0; n < out->num(); n++) { + for (int n = 0; n < out->shape(0); n++) { for (int g = 0; g < groups; g++) { o_head = o_g * g; k_head = k_g * g; for (int o = 0; o < o_g; o++) { for (int k = 0; k < k_g; k++) { - for (int y = 0; y < out->height(); y++) { - for (int x = 0; x < out->width(); x++) { - for (int p = 0; p < kernel_h; p++) { - for (int q = 0; q < kernel_w; q++) { - int in_y = y * stride_h - pad_h + p; - int in_x = x * stride_w - pad_w + q; - if (in_y >= 0 && in_y < in->height() - && in_x >= 0 && in_x < in->width()) { - out_data[out->offset(n, o + o_head, y, x)] += - in_data[in->offset(n, k + k_head, in_y, in_x)] - * weight_data[weights[0]->offset(o + o_head, k, p, q)]; + for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) { + for (int y = 0; y < out->shape(2 + has_depth); y++) { + for (int x = 0; x < out->shape(3 + has_depth); x++) { + for (int r = 0; r < kernel_d; r++) { + for (int p = 0; p < kernel_h; p++) { + for (int q = 0; q < kernel_w; q++) { + int in_z = z * stride_d - pad_d + r; + int in_y = y * stride_h - pad_h + p; + int in_x = x * stride_w - pad_w + q; + if (in_z >= 0 && in_z < (has_depth ? in->shape(2) : 1) + && in_y >= 0 && in_y < in->shape(2 + has_depth) + && in_x >= 0 && in_x < in->shape(3 + has_depth)) { + weight_offset[0] = o + o_head; + weight_offset[1] = k; + if (has_depth) { weight_offset[2] = r; } + weight_offset[2 + has_depth] = p; + weight_offset[3 + has_depth] = q; + in_offset[0] = n; + in_offset[1] = k + k_head; + if (has_depth) { in_offset[2] = in_z; } + in_offset[2 + has_depth] = in_y; + in_offset[3 + has_depth] = in_x; + out_offset[0] = n; + out_offset[1] = o + o_head; + if (has_depth) { out_offset[2] = z; } + out_offset[2 + has_depth] = y; + out_offset[3 + has_depth] = x; + out_data[out->offset(out_offset)] += + in->data_at(in_offset) + * weights[0]->data_at(weight_offset); + } + } } } } @@ -79,11 +112,18 @@ void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, // Bias if (conv_param->bias_term()) { const Dtype* bias_data = weights[1]->cpu_data(); - for (int n = 0; n < out->num(); n++) { - for (int o = 0; o < out->channels(); o++) { - for (int y = 0; y < out->height(); y++) { - for (int x = 0; x < out->width(); x++) { - out_data[out->offset(n, o, y, x)] += bias_data[o]; + for (int n = 0; n < out->shape(0); n++) { + for (int o = 0; o < out->shape(1); o++) { + for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) { + for (int y = 0; y < out->shape(2 + has_depth); y++) { + for (int x = 0; x < out->shape(3 + has_depth); x++) { + out_offset[0] = n; + out_offset[1] = o; + if (has_depth) { out_offset[2] = z; } + out_offset[2 + has_depth] = y; + out_offset[3 + has_depth] = x; + out_data[out->offset(out_offset)] += bias_data[o]; + } } } } @@ -150,8 +190,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSetup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -188,8 +228,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -217,13 +257,98 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { } } +TYPED_TEST(ConvolutionLayerTest, Test0DConvolution) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + const int kNumOutput = 3; + convolution_param->set_num_output(kNumOutput); + convolution_param->set_axis(3); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + shared_ptr > layer( + new ConvolutionLayer(layer_param)); + vector top_shape = this->blob_bottom_->shape(); + top_shape[3] = kNumOutput; + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(top_shape, this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. + vector weight_offset(2); + const Blob* weight = layer->blobs()[0].get(); + const Blob* bias = layer->blobs()[1].get(); + const int num = this->blob_top_->count(3); + const int dim = this->blob_top_->shape(3); + const int bottom_dim = this->blob_bottom_->shape(3); + for (int n = 0; n < num; ++n) { + for (int d = 0; d < dim; ++d) { + weight_offset[0] = d; + Dtype value = bias->cpu_data()[d]; + for (int bottom_d = 0; bottom_d < bottom_dim; ++bottom_d) { + weight_offset[1] = bottom_d; + value += weight->data_at(weight_offset) * + this->blob_bottom_->cpu_data()[n * bottom_dim + bottom_d]; + } + EXPECT_NEAR(value, this->blob_top_->cpu_data()[n * dim + d], 1e-4); + } + } +} + +TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + vector bottom_shape(5); + bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0); + bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1); + bottom_shape[2] = 5; + bottom_shape[3] = this->blob_bottom_vec_[0]->shape(2); + bottom_shape[4] = this->blob_bottom_vec_[0]->shape(3); + FillerParameter filler_param; + GaussianFiller filler(filler_param); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + filler.Fill(this->blob_bottom_vec_[i]); + } + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); + convolution_param->set_num_output(4); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + shared_ptr > layer( + new ConvolutionLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. + const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } + caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_2_)); + top_data = this->blob_top_2_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(1); - convolution_param->set_stride(1); + convolution_param->add_kernel_size(1); + convolution_param->add_stride(1); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -249,8 +374,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -288,8 +413,8 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(1); convolution_param->set_bias_term(false); shared_ptr > layer( @@ -350,14 +475,11 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { convolution_param->set_bias_term(false); layer.reset(new ConvolutionLayer(layer_param)); layer->blobs().resize(1); - layer->blobs()[0].reset(new Blob(1, 3, 1, 3)); + layer->blobs()[0].reset(new Blob(1, 1, 1, 3)); Dtype* weights_2 = layer->blobs()[0]->mutable_cpu_data(); - for (int c = 0; c < 3; ++c) { - int i = c * 3; // 1 x 3 filter - weights_2[i + 0] = -1; - weights_2[i + 1] = 0; - weights_2[i + 2] = 1; - } + weights_2[0] = -1; + weights_2[1] = 0; + weights_2[2] = 1; layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec); layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec); // Test equivalence of full and separable filters. @@ -368,6 +490,124 @@ TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { } } +TYPED_TEST(ConvolutionLayerTest, TestNDAgainst2D) { + typedef typename TypeParam::Dtype Dtype; + const int kernel_h = 11; + const int kernel_w = 13; + vector bottom_shape(4); + bottom_shape[0] = 15; + bottom_shape[1] = 18; + bottom_shape[2] = kernel_h * 2; + bottom_shape[3] = kernel_w * 2; + FillerParameter filler_param; + GaussianFiller filler(filler_param); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + filler.Fill(this->blob_bottom_vec_[i]); + } + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_num_output(12); + convolution_param->set_bias_term(false); + convolution_param->set_group(6); + convolution_param->set_kernel_h(kernel_h); + convolution_param->set_kernel_w(kernel_w); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + Blob weights; + Blob top_diff; + // Shape and fill weights and top_diff. + bool copy_diff; + bool reshape; + { + ConvolutionLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + top_diff.ReshapeLike(*this->blob_top_); + filler.Fill(&top_diff); + ASSERT_EQ(1, layer.blobs().size()); + copy_diff = false; reshape = true; + weights.CopyFrom(*layer.blobs()[0], copy_diff, reshape); + } + vector propagate_down(1, true); + Blob result_2d; + Blob backward_result_2d; + Blob backward_weight_result_2d; + // Test with 2D im2col + { + caffe_set(this->blob_top_->count(), Dtype(0), + this->blob_top_->mutable_cpu_data()); + caffe_set(this->blob_bottom_->count(), Dtype(0), + this->blob_bottom_->mutable_cpu_diff()); + caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff()); + // Do SetUp and Forward; save Forward result in result_2d. + convolution_param->set_force_nd_im2col(false); + ConvolutionLayer layer_2d(layer_param); + layer_2d.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(1, layer_2d.blobs().size()); + copy_diff = false; reshape = false; + layer_2d.blobs()[0]->CopyFrom(weights, copy_diff, reshape); + layer_2d.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + copy_diff = false; reshape = true; + result_2d.CopyFrom(*this->blob_top_, copy_diff, reshape); + // Copy pre-generated top diff into actual top diff; + // do Backward and save result in backward_result_2d. + ASSERT_EQ(this->blob_top_->shape(), top_diff.shape()); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer_2d.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + copy_diff = true; reshape = true; + backward_result_2d.CopyFrom(*this->blob_bottom_, copy_diff, reshape); + backward_weight_result_2d.CopyFrom(weights, copy_diff, reshape); + } + Blob result_nd; + Blob backward_result_nd; + Blob backward_weight_result_nd; + // Test with ND im2col + { + caffe_set(this->blob_top_->count(), Dtype(0), + this->blob_top_->mutable_cpu_data()); + caffe_set(this->blob_bottom_->count(), Dtype(0), + this->blob_bottom_->mutable_cpu_diff()); + caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff()); + // Do SetUp and Forward; save Forward result in result_nd. + convolution_param->set_force_nd_im2col(true); + ConvolutionLayer layer_nd(layer_param); + layer_nd.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(1, layer_nd.blobs().size()); + copy_diff = false; reshape = false; + layer_nd.blobs()[0]->CopyFrom(weights, copy_diff, reshape); + layer_nd.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + copy_diff = false; reshape = true; + result_nd.CopyFrom(*this->blob_top_, copy_diff, reshape); + // Copy pre-generated top diff into actual top diff; + // do Backward and save result in backward_result_nd. + ASSERT_EQ(this->blob_top_->shape(), top_diff.shape()); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer_nd.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + copy_diff = true; reshape = true; + backward_result_nd.CopyFrom(*this->blob_bottom_, copy_diff, reshape); + backward_weight_result_nd.CopyFrom(weights, copy_diff, reshape); + } + ASSERT_EQ(result_nd.count(), result_2d.count()); + for (int i = 0; i < result_2d.count(); ++i) { + EXPECT_EQ(result_2d.cpu_data()[i], result_nd.cpu_data()[i]); + } + ASSERT_EQ(backward_result_nd.count(), backward_result_2d.count()); + for (int i = 0; i < backward_result_2d.count(); ++i) { + EXPECT_EQ(backward_result_2d.cpu_diff()[i], + backward_result_nd.cpu_diff()[i]); + } + ASSERT_EQ(backward_weight_result_nd.count(), + backward_weight_result_2d.count()); + for (int i = 0; i < backward_weight_result_2d.count(); ++i) { + EXPECT_EQ(backward_weight_result_2d.cpu_diff()[i], + backward_weight_result_nd.cpu_diff()[i]); + } +} + TYPED_TEST(ConvolutionLayerTest, TestGradient) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; @@ -375,8 +615,36 @@ TYPED_TEST(ConvolutionLayerTest, TestGradient) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); + convolution_param->set_num_output(2); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + ConvolutionLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ConvolutionLayerTest, TestGradient3D) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + vector bottom_shape(5); + bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0); + bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1); + bottom_shape[2] = 5; + bottom_shape[3] = this->blob_bottom_vec_[0]->shape(2); + bottom_shape[4] = this->blob_bottom_vec_[0]->shape(3); + FillerParameter filler_param; + GaussianFiller filler(filler_param); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + filler.Fill(this->blob_bottom_vec_[i]); + } + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -393,8 +661,8 @@ TYPED_TEST(ConvolutionLayerTest, Test1x1Gradient) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(1); - convolution_param->set_stride(1); + convolution_param->add_kernel_size(1); + convolution_param->add_stride(1); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -409,8 +677,8 @@ TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -472,8 +740,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -509,8 +777,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("constant"); @@ -542,8 +810,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); @@ -581,8 +849,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(1); convolution_param->set_bias_term(false); shared_ptr > layer( @@ -643,14 +911,11 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) { convolution_param->set_bias_term(false); layer.reset(new CuDNNConvolutionLayer(layer_param)); layer->blobs().resize(1); - layer->blobs()[0].reset(new Blob(1, 3, 1, 3)); + layer->blobs()[0].reset(new Blob(1, 1, 1, 3)); TypeParam* weights_2 = layer->blobs()[0]->mutable_cpu_data(); - for (int c = 0; c < 3; ++c) { - int i = c * 3; // 1 x 3 filter - weights_2[i + 0] = -1; - weights_2[i + 1] = 0; - weights_2[i + 2] = 1; - } + weights_2[0] = -1; + weights_2[1] = 0; + weights_2[2] = 1; layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec); layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec); // Test equivalence of full and separable filters. @@ -667,8 +932,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(2); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -682,8 +947,8 @@ TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientGroupCuDNN) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(3); convolution_param->set_group(3); convolution_param->mutable_weight_filler()->set_type("gaussian"); diff --git a/src/caffe/test/test_deconvolution_layer.cpp b/src/caffe/test/test_deconvolution_layer.cpp index fc63d5ef..770e7b27 100644 --- a/src/caffe/test/test_deconvolution_layer.cpp +++ b/src/caffe/test/test_deconvolution_layer.cpp @@ -58,8 +58,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestSetup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); @@ -96,8 +96,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestSimpleDeconvolution) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); convolution_param->set_num_output(4); convolution_param->mutable_weight_filler()->set_type("constant"); convolution_param->mutable_weight_filler()->set_value(1); @@ -144,8 +144,8 @@ TYPED_TEST(DeconvolutionLayerTest, TestGradient) { layer_param.mutable_convolution_param(); this->blob_bottom_vec_.push_back(this->blob_bottom_2_); this->blob_top_vec_.push_back(this->blob_top_2_); - convolution_param->set_kernel_size(2); - convolution_param->set_stride(1); + convolution_param->add_kernel_size(2); + convolution_param->add_stride(1); convolution_param->set_num_output(1); convolution_param->mutable_weight_filler()->set_type("gaussian"); convolution_param->mutable_bias_filler()->set_type("gaussian"); @@ -155,4 +155,151 @@ TYPED_TEST(DeconvolutionLayerTest, TestGradient) { this->blob_top_vec_); } +TYPED_TEST(DeconvolutionLayerTest, TestNDAgainst2D) { + typedef typename TypeParam::Dtype Dtype; + const int kernel_h = 11; + const int kernel_w = 13; + vector bottom_shape(4); + bottom_shape[0] = 15; + bottom_shape[1] = 12; + bottom_shape[2] = kernel_h * 2; + bottom_shape[3] = kernel_w * 2; + FillerParameter filler_param; + GaussianFiller filler(filler_param); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + filler.Fill(this->blob_bottom_vec_[i]); + } + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_num_output(18); + convolution_param->set_bias_term(false); + convolution_param->set_group(6); + convolution_param->set_kernel_h(kernel_h); + convolution_param->set_kernel_w(kernel_w); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + Blob weights; + Blob top_diff; + // Shape and fill weights and top_diff. + bool copy_diff; + bool reshape; + { + DeconvolutionLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + top_diff.ReshapeLike(*this->blob_top_); + filler.Fill(&top_diff); + ASSERT_EQ(1, layer.blobs().size()); + copy_diff = false; reshape = true; + weights.CopyFrom(*layer.blobs()[0], copy_diff, reshape); + } + vector propagate_down(1, true); + Blob result_2d; + Blob backward_result_2d; + Blob backward_weight_result_2d; + // Test with 2D im2col + { + caffe_set(this->blob_top_->count(), Dtype(0), + this->blob_top_->mutable_cpu_data()); + caffe_set(this->blob_bottom_->count(), Dtype(0), + this->blob_bottom_->mutable_cpu_diff()); + caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff()); + // Do SetUp and Forward; save Forward result in result_2d. + convolution_param->set_force_nd_im2col(false); + DeconvolutionLayer layer_2d(layer_param); + layer_2d.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(1, layer_2d.blobs().size()); + copy_diff = false; reshape = false; + layer_2d.blobs()[0]->CopyFrom(weights, copy_diff, reshape); + layer_2d.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + copy_diff = false; reshape = true; + result_2d.CopyFrom(*this->blob_top_, copy_diff, reshape); + // Copy pre-generated top diff into actual top diff; + // do Backward and save result in backward_result_2d. + ASSERT_EQ(this->blob_top_->shape(), top_diff.shape()); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer_2d.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + copy_diff = true; reshape = true; + backward_result_2d.CopyFrom(*this->blob_bottom_, copy_diff, reshape); + backward_weight_result_2d.CopyFrom(weights, copy_diff, reshape); + } + Blob result_nd; + Blob backward_result_nd; + Blob backward_weight_result_nd; + // Test with ND im2col + { + caffe_set(this->blob_top_->count(), Dtype(0), + this->blob_top_->mutable_cpu_data()); + caffe_set(this->blob_bottom_->count(), Dtype(0), + this->blob_bottom_->mutable_cpu_diff()); + caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff()); + // Do SetUp and Forward; save Forward result in result_nd. + convolution_param->set_force_nd_im2col(true); + DeconvolutionLayer layer_nd(layer_param); + layer_nd.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(1, layer_nd.blobs().size()); + copy_diff = false; reshape = false; + layer_nd.blobs()[0]->CopyFrom(weights, copy_diff, reshape); + layer_nd.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + copy_diff = false; reshape = true; + result_nd.CopyFrom(*this->blob_top_, copy_diff, reshape); + // Copy pre-generated top diff into actual top diff; + // do Backward and save result in backward_result_nd. + ASSERT_EQ(this->blob_top_->shape(), top_diff.shape()); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer_nd.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + copy_diff = true; reshape = true; + backward_result_nd.CopyFrom(*this->blob_bottom_, copy_diff, reshape); + backward_weight_result_nd.CopyFrom(weights, copy_diff, reshape); + } + ASSERT_EQ(result_nd.count(), result_2d.count()); + for (int i = 0; i < result_2d.count(); ++i) { + EXPECT_EQ(result_2d.cpu_data()[i], result_nd.cpu_data()[i]); + } + ASSERT_EQ(backward_result_nd.count(), backward_result_2d.count()); + for (int i = 0; i < backward_result_2d.count(); ++i) { + EXPECT_EQ(backward_result_2d.cpu_diff()[i], + backward_result_nd.cpu_diff()[i]); + } + ASSERT_EQ(backward_weight_result_nd.count(), + backward_weight_result_2d.count()); + for (int i = 0; i < backward_weight_result_2d.count(); ++i) { + EXPECT_EQ(backward_weight_result_2d.cpu_diff()[i], + backward_weight_result_nd.cpu_diff()[i]); + } +} + +TYPED_TEST(DeconvolutionLayerTest, TestGradient3D) { + typedef typename TypeParam::Dtype Dtype; + vector bottom_shape(5); + bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0); + bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1); + bottom_shape[2] = 2; + bottom_shape[3] = 3; + bottom_shape[4] = 2; + FillerParameter filler_param; + GaussianFiller filler(filler_param); + for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) { + this->blob_bottom_vec_[i]->Reshape(bottom_shape); + filler.Fill(this->blob_bottom_vec_[i]); + } + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->add_kernel_size(2); + convolution_param->add_stride(2); + convolution_param->add_pad(1); + convolution_param->set_num_output(2); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + DeconvolutionLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + } // namespace caffe diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu index 0017ac23..f0b75fcc 100644 --- a/src/caffe/test/test_im2col_kernel.cu +++ b/src/caffe/test/test_im2col_kernel.cu @@ -22,6 +22,12 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, const int height_col, const int width_col, Dtype* data_col); +template +__global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col); + extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; template @@ -30,11 +36,18 @@ class Im2colKernelTest : public GPUDeviceTest { Im2colKernelTest() // big so launches > 1024 threads : blob_bottom_(new Blob(5, 500, 10, 10)), + blob_kernel_shape_(new Blob()), + blob_stride_(new Blob()), + blob_pad_(new Blob()), blob_top_(new Blob()), blob_top_cpu_(new Blob()) { FillerParameter filler_param; GaussianFiller filler(filler_param); filler.Fill(this->blob_bottom_); + vector dim_blob_shape(1, 2); + blob_kernel_shape_->Reshape(dim_blob_shape); + blob_stride_->Reshape(dim_blob_shape); + blob_pad_->Reshape(dim_blob_shape); height_ = blob_bottom_->height(); width_ = blob_bottom_->width(); @@ -44,14 +57,26 @@ class Im2colKernelTest : public GPUDeviceTest { kernel_size_ = 3; height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1; width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1; + + for (int i = 0; i < 2; ++i) { + blob_kernel_shape_->mutable_cpu_data()[i] = kernel_size_; + blob_stride_->mutable_cpu_data()[i] = stride_; + blob_pad_->mutable_cpu_data()[i] = pad_; + } } virtual ~Im2colKernelTest() { - delete blob_bottom_; - delete blob_top_; - delete blob_top_cpu_; + delete blob_bottom_; + delete blob_top_; + delete blob_top_cpu_; + delete blob_kernel_shape_; + delete blob_stride_; + delete blob_pad_; } + Blob* const blob_kernel_shape_; + Blob* const blob_stride_; + Blob* const blob_pad_; Blob* const blob_bottom_; Blob* const blob_top_; Blob* const blob_top_cpu_; @@ -67,7 +92,7 @@ class Im2colKernelTest : public GPUDeviceTest { TYPED_TEST_CASE(Im2colKernelTest, TestDtypes); -TYPED_TEST(Im2colKernelTest, TestGPU) { +TYPED_TEST(Im2colKernelTest, Test2D) { // Reshape the blobs to correct size for im2col output this->blob_top_->Reshape(this->blob_bottom_->num(), this->channels_ * this->kernel_size_ * this->kernel_size_, @@ -122,4 +147,58 @@ TYPED_TEST(Im2colKernelTest, TestGPU) { } } +TYPED_TEST(Im2colKernelTest, TestND) { + // Reshape the blobs to correct size for im2col output + this->blob_top_->Reshape(this->blob_bottom_->num(), + this->channels_ * this->kernel_size_ * this->kernel_size_, + this->height_col_, + this->width_col_); + + this->blob_top_cpu_->ReshapeLike(*this->blob_top_); + + const TypeParam* bottom_data_cpu = this->blob_bottom_->cpu_data(); + TypeParam* top_data_cpu = this->blob_top_cpu_->mutable_cpu_data(); + + // CPU Version + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + im2col_nd_cpu(bottom_data_cpu + this->blob_bottom_->offset(n), 2, + this->blob_bottom_->shape().data() + 1, + this->blob_top_cpu_->shape().data() + 1, + this->blob_kernel_shape_->cpu_data(), + this->blob_pad_->cpu_data(), this->blob_stride_->cpu_data(), + top_data_cpu + this->blob_top_cpu_->offset(n)); + } + + // GPU version + int num_kernels = this->channels_ * this->height_col_ * this->width_col_; + int default_grid_dim = CAFFE_GET_BLOCKS(num_kernels); + const TypeParam* bottom_data_gpu = this->blob_bottom_->gpu_data(); + + // Launch with different grid sizes + for (int grid_div = 2; grid_div <= 8; grid_div++) { + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + const int grid_dim = default_grid_dim / grid_div; + TypeParam* top_data_gpu = this->blob_top_->mutable_gpu_data(); + // NOLINT_NEXT_LINE(whitespace/operators) + im2col_nd_gpu_kernel<<>>( + num_kernels, bottom_data_gpu + this->blob_bottom_->offset(n), + this->blob_bottom_->gpu_shape() + 1, this->blob_top_->gpu_shape() + 1, + this->blob_kernel_shape_->gpu_data(), this->blob_pad_->gpu_data(), + this->blob_stride_->gpu_data(), + top_data_gpu + this->blob_top_->offset(n)); + CUDA_POST_KERNEL_CHECK; + } + + // Compare results against CPU version + for (int i = 0; i < this->blob_top_->count(); ++i) { + TypeParam cpuval = top_data_cpu[i]; + TypeParam gpuval = this->blob_top_->cpu_data()[i]; + EXPECT_EQ(cpuval, gpuval); + if (cpuval != gpuval) { + break; + } + } + } +} + } // namespace caffe diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp index f50abe10..293aa262 100644 --- a/src/caffe/test/test_im2col_layer.cpp +++ b/src/caffe/test/test_im2col_layer.cpp @@ -21,6 +21,7 @@ class Im2colLayerTest : public MultiDeviceTest { : blob_bottom_(new Blob(2, 3, 6, 5)), blob_top_(new Blob()) { // fill the values + Caffe::set_random_seed(1701); FillerParameter filler_param; GaussianFiller filler(filler_param); filler.Fill(this->blob_bottom_); @@ -41,8 +42,8 @@ TYPED_TEST(Im2colLayerTest, TestSetup) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); EXPECT_EQ(this->blob_top_->num(), 2); @@ -56,8 +57,8 @@ TYPED_TEST(Im2colLayerTest, TestForward) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); @@ -73,14 +74,27 @@ TYPED_TEST(Im2colLayerTest, TestGradient) { LayerParameter layer_param; ConvolutionParameter* convolution_param = layer_param.mutable_convolution_param(); - convolution_param->set_kernel_size(3); - convolution_param->set_stride(2); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, this->blob_top_vec_); } +TYPED_TEST(Im2colLayerTest, TestGradientForceND) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->add_kernel_size(3); + convolution_param->add_stride(2); + convolution_param->set_force_nd_im2col(true); + Im2colLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} TYPED_TEST(Im2colLayerTest, TestRect) { typedef typename TypeParam::Dtype Dtype; @@ -89,7 +103,7 @@ TYPED_TEST(Im2colLayerTest, TestRect) { layer_param.mutable_convolution_param(); convolution_param->set_kernel_h(5); convolution_param->set_kernel_w(3); - convolution_param->set_stride(2); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); @@ -108,7 +122,7 @@ TYPED_TEST(Im2colLayerTest, TestRectGradient) { layer_param.mutable_convolution_param(); convolution_param->set_kernel_h(5); convolution_param->set_kernel_w(3); - convolution_param->set_stride(2); + convolution_param->add_stride(2); Im2colLayer layer(layer_param); GradientChecker checker(1e-2, 1e-2); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index c48f31f3..b0a7be50 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "caffe/util/im2col.hpp" #include "caffe/util/math_functions.hpp" @@ -44,6 +45,98 @@ template void im2col_cpu(const double* data_im, const int channels, const int pad_h, const int pad_w, const int stride_h, const int stride_w, double* data_col); +template +inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col, + const int num_spatial_axes, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_output) { + if (!im2col) { + int im_size = im_shape[0]; + for (int i = 0; i < num_spatial_axes; ++i) { + im_size *= im_shape[1 + i]; + } + caffe_set(im_size, Dtype(0), data_output); + } + int kernel_size = 1; + for (int i = 0; i < num_spatial_axes; ++i) { + kernel_size *= kernel_shape[i]; + } + const int channels_col = col_shape[0]; + vector d_offset(num_spatial_axes, 0); + vector d_iter(num_spatial_axes, 0); + for (int c = 0; c < channels_col; ++c) { + // Loop over spatial axes in reverse order to compute a per-axis offset. + int offset = c; + for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) { + if (d_i < num_spatial_axes - 1) { + offset /= kernel_shape[d_i + 1]; + } + d_offset[d_i] = offset % kernel_shape[d_i]; + } + for (bool incremented = true; incremented; ) { + // Loop over spatial axes in forward order to compute the indices in the + // image and column, and whether the index lies in the padding. + int index_col = c; + int index_im = c / kernel_size; + bool is_padding = false; + for (int d_i = 0; d_i < num_spatial_axes; ++d_i) { + const int d = d_iter[d_i]; + const int d_pad = d * stride[d_i] - pad[d_i] + d_offset[d_i]; + is_padding |= d_pad < 0 || d_pad >= im_shape[d_i + 1]; + index_col *= col_shape[d_i + 1]; + index_col += d; + index_im *= im_shape[d_i + 1]; + index_im += d_pad; + } + if (im2col) { + if (is_padding) { + data_output[index_col] = 0; + } else { + data_output[index_col] = data_input[index_im]; + } + } else if (!is_padding) { // col2im + data_output[index_im] += data_input[index_col]; + } + // Loop over spatial axes in reverse order to choose an index, + // like counting. + incremented = false; + for (int d_i = num_spatial_axes - 1; d_i >= 0; --d_i) { + const int d_max = col_shape[d_i + 1]; + DCHECK_LT(d_iter[d_i], d_max); + if (d_iter[d_i] == d_max - 1) { + d_iter[d_i] = 0; + } else { // d_iter[d_i] < d_max - 1 + ++d_iter[d_i]; + incremented = true; + break; + } + } + } // while(incremented) { + } // for (int c = 0; c < channels_col; ++c) { +} + +template +void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col) { + const bool kIm2Col = true; + im2col_nd_core_cpu(data_im, kIm2Col, num_spatial_axes, im_shape, col_shape, + kernel_shape, pad, stride, data_col); +} + +// Explicit instantiation +template void im2col_nd_cpu(const float* data_im, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_col); +template void im2col_nd_cpu(const double* data_im, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_col); + template void col2im_cpu(const Dtype* data_col, const int channels, const int height, const int width, const int patch_h, const int patch_w, @@ -80,4 +173,27 @@ template void col2im_cpu(const double* data_col, const int channels, const int pad_h, const int pad_w, const int stride_h, const int stride_w, double* data_im); +template +void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im) { + const bool kIm2Col = false; + im2col_nd_core_cpu(data_col, kIm2Col, num_spatial_axes, im_shape, col_shape, + kernel_shape, pad, stride, data_im); +} + +// Explicit instantiation +template void col2im_nd_cpu(const float* data_col, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_im); +template void col2im_nd_cpu(const double* data_col, + const int num_spatial_axes, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_im); + + } // namespace caffe diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index c90f93eb..5a478ba6 100644 --- a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -59,7 +59,6 @@ void im2col_gpu(const Dtype* data_im, const int channels, CUDA_POST_KERNEL_CHECK; } - // Explicit instantiation template void im2col_gpu(const float* data_im, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, @@ -70,6 +69,156 @@ template void im2col_gpu(const double* data_im, const int channels, const int pad_h, const int pad_w, const int stride_h, const int stride_w, double* data_col); +template +__global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col) { + int d_temp[num_axes]; // NOLINT(runtime/arrays) + int d_iter[num_axes]; // NOLINT(runtime/arrays) + int i; + CUDA_KERNEL_LOOP(index, n) { + // Initialize channel_in, computed in the loop below, with intermediate + // computations used to compute the spatial indices. + int channel_in = index; + int channel_out = 1; + for (i = num_axes - 1; i >= 0; --i) { + d_temp[i] = channel_in % col_shape[i + 1]; + channel_in /= col_shape[i + 1]; + channel_out *= kernel_shape[i]; + } + channel_out *= channel_in; + int data_col_inc = 1; + for (i = 0; i < num_axes; ++i) { + channel_out *= col_shape[i + 1]; + channel_out += d_temp[i]; + d_temp[i] = d_temp[i] * stride[i] - pad[i]; + channel_in *= im_shape[i + 1]; + channel_in += d_temp[i]; + data_col_inc *= col_shape[i + 1]; + d_iter[i] = 0; + } + Dtype* data_col_ptr = data_col + channel_out; + const Dtype* data_im_ptr = data_im + channel_in; + bool incremented; + do { + bool in_range = true; + for (i = 0; i < num_axes; ++i) { + const int d_iter_im = d_iter[i] + d_temp[i]; + in_range &= d_iter_im >= 0 && d_iter_im < im_shape[i + 1]; + if (!in_range) { break; } + } + if (in_range) { + int data_im_offset = d_iter[0]; + for (i = 1; i < num_axes; ++i) { + data_im_offset *= im_shape[i + 1]; + data_im_offset += d_iter[i]; + } + *data_col_ptr = data_im_ptr[data_im_offset]; + } else { + *data_col_ptr = 0; + } + data_col_ptr += data_col_inc; + incremented = false; + for (i = num_axes - 1; i >= 0; --i) { + const int d_max = kernel_shape[i]; + if (d_iter[i] == d_max - 1) { + d_iter[i] = 0; + } else { // d_iter[i] < d_max - 1 + ++d_iter[i]; + incremented = true; + break; + } + } // for (int i = num_axes - 1; i >= 0; --i) + } while (incremented); // do + } // CUDA_KERNEL_LOOP(index, n) +} + +template +void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes, + const int num_kernels, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_col) { + switch (num_spatial_axes) { + case 1: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 2: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 3: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 4: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 5: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 6: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 7: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 8: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 9: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + case 10: + im2col_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + num_kernels, data_im, im_shape, col_shape, + kernel_shape, pad, stride, data_col); + break; + default: + LOG(FATAL) << "im2col_nd_gpu does not support computation with " + << num_spatial_axes << " spatial axes"; + } + CUDA_POST_KERNEL_CHECK; +} + +// Explicit instantiation +template void im2col_nd_gpu(const float* data_im, + const int num_spatial_axes, const int col_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_col); +template void im2col_nd_gpu(const double* data_im, + const int num_spatial_axes, const int col_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_col); + template __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, const int height, const int width, const int channels, @@ -141,4 +290,159 @@ template void col2im_gpu(const double* data_col, const int channels, const int pad_h, const int pad_w, const int stride_h, const int stride_w, double* data_im); +template +__global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im) { + int d_im[num_axes]; // NOLINT(runtime/arrays) + int d_col_iter[num_axes]; // NOLINT(runtime/arrays) + int d_col_start[num_axes]; // NOLINT(runtime/arrays) + int d_col_end[num_axes]; // NOLINT(runtime/arrays) + CUDA_KERNEL_LOOP(index, n) { + // Initialize channel_in, computed in the loop below, with intermediate + // computations used to compute the spatial indices. + int channel_im = index; + // Calculate d_im (image dimensions). + for (int i = num_axes - 1; i >= 0; --i) { + d_im[i] = channel_im % im_shape[i + 1] + pad[i]; + channel_im /= im_shape[i + 1]; + } + // Calculate col start/end indices. + bool done = false; + for (int i = 0; i < num_axes; ++i) { + d_col_start[i] = d_col_iter[i] = + (d_im[i] < kernel_shape[i]) ? + 0 : (d_im[i] - kernel_shape[i]) / stride[i] + 1; + d_col_end[i] = min(d_im[i] / stride[i] + 1, col_shape[i + 1]); + if (d_col_start[i] >= d_col_end[i]) { + // Skip computation if the dimension is 0 at any spatial axis -- + // final val will be 0. + data_im[index] = 0; + done = true; + break; // for (int i = 0; i < num_axes; ++i) + } + } + if (done) { + continue; // CUDA_KERNEL_LOOP(index, n) + } + // Loop over the col to compute the output val. + Dtype val = 0; + bool incremented = true; + do { + // Compute the final offset. + int final_offset = 0; + int kernel_shape_prod = 1; + for (int i = num_axes - 1; i >= 0; --i) { + final_offset += + (d_im[i] - d_col_iter[i] * stride[i]) * kernel_shape_prod; + kernel_shape_prod *= kernel_shape[i]; + } + final_offset += kernel_shape_prod * channel_im; + for (int i = 0; i < num_axes; ++i) { + final_offset *= col_shape[i + 1]; + final_offset += d_col_iter[i]; + } + val += data_col[final_offset]; + incremented = false; + for (int i = num_axes - 1; i >= 0; --i) { + const int d_max = d_col_end[i]; + if (d_col_iter[i] == d_max - 1) { + d_col_iter[i] = d_col_start[i]; + } else { // d_col_iter[i] < d_max - 1 + ++d_col_iter[i]; + incremented = true; + break; // for (int i = num_axes - 1; i >= 0; --i) + } + } // for (int i = num_axes - 1; i >= 0; --i) + } while (incremented); + data_im[index] = val; + } // CUDA_KERNEL_LOOP(index, n) +} + +template +void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes, + const int im_size, const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + Dtype* data_im) { + switch (num_spatial_axes) { + case 1: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 2: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 3: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 4: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 5: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 6: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 7: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 8: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 9: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + case 10: + col2im_nd_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + im_size, data_col, im_shape, col_shape, + kernel_shape, pad, stride, data_im); + break; + default: + LOG(FATAL) << "col2im_nd_gpu does not support computation with " + << num_spatial_axes << " spatial axes"; + } + CUDA_POST_KERNEL_CHECK; +} + +// Explicit instantiation +template void col2im_nd_gpu(const float* data_col, + const int num_spatial_axes, const int im_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + float* data_im); +template void col2im_nd_gpu(const double* data_col, + const int num_spatial_axes, const int im_size, + const int* im_shape, const int* col_shape, + const int* kernel_shape, const int* pad, const int* stride, + double* data_im); + } // namespace caffe diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 92e5cf55..ac379e50 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -193,7 +193,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection, } if (v0_layer_param.has_pad()) { if (type == "conv") { - layer_param->mutable_convolution_param()->set_pad(v0_layer_param.pad()); + layer_param->mutable_convolution_param()->add_pad(v0_layer_param.pad()); } else if (type == "pool") { layer_param->mutable_pooling_param()->set_pad(v0_layer_param.pad()); } else { @@ -203,7 +203,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection, } if (v0_layer_param.has_kernelsize()) { if (type == "conv") { - layer_param->mutable_convolution_param()->set_kernel_size( + layer_param->mutable_convolution_param()->add_kernel_size( v0_layer_param.kernelsize()); } else if (type == "pool") { layer_param->mutable_pooling_param()->set_kernel_size( @@ -224,7 +224,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection, } if (v0_layer_param.has_stride()) { if (type == "conv") { - layer_param->mutable_convolution_param()->set_stride( + layer_param->mutable_convolution_param()->add_stride( v0_layer_param.stride()); } else if (type == "pool") { layer_param->mutable_pooling_param()->set_stride( -- cgit v1.2.3