diff options
Diffstat (limited to 'src/caffe/layers/base_conv_layer.cpp')
-rw-r--r-- | src/caffe/layers/base_conv_layer.cpp | 241 |
1 files changed, 166 insertions, 75 deletions
diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index ccb3adc7..a5b90a54 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -1,3 +1,4 @@ +#include <algorithm> #include <vector> #include "caffe/filler.hpp" @@ -11,50 +12,103 @@ namespace caffe { template <typename Dtype> void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; // Configure the kernel size, padding, stride, and inputs. ConvolutionParameter conv_param = this->layer_param_.convolution_param(); - CHECK(!conv_param.has_kernel_size() != - !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; - CHECK(conv_param.has_kernel_size() || - (conv_param.has_kernel_h() && conv_param.has_kernel_w())) - << "For non-square filters both kernel_h and kernel_w are required."; - CHECK((!conv_param.has_pad() && conv_param.has_pad_h() - && conv_param.has_pad_w()) - || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) - << "pad is pad OR pad_h and pad_w are required."; - CHECK((!conv_param.has_stride() && conv_param.has_stride_h() - && conv_param.has_stride_w()) - || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) - << "Stride is stride OR stride_h and stride_w are required."; - if (conv_param.has_kernel_size()) { - kernel_h_ = kernel_w_ = conv_param.kernel_size(); + force_nd_im2col_ = conv_param.force_nd_im2col(); + channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis()); + const int first_spatial_axis = channel_axis_ + 1; + const int num_axes = bottom[0]->num_axes(); + num_spatial_axes_ = num_axes - first_spatial_axis; + CHECK_GE(num_spatial_axes_, 0); + // Setup input dimensions (input_shape_). + vector<int> bottom_dim_blob_shape(1, num_spatial_axes_ + 1); + input_shape_.Reshape(bottom_dim_blob_shape); + int* input_shape_data = input_shape_.mutable_cpu_data(); + for (int i = 0; i < num_spatial_axes_ + 1; ++i) { + input_shape_data[i] = bottom[0]->shape(channel_axis_ + i); + } + vector<int> spatial_dim_blob_shape(1, std::max(num_spatial_axes_, 1)); + // Setup filter kernel dimensions (kernel_shape_). + kernel_shape_.Reshape(spatial_dim_blob_shape); + int* kernel_shape_data = kernel_shape_.mutable_cpu_data(); + if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "kernel_h & kernel_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.kernel_size_size()) + << "Either kernel_size or kernel_h/w should be specified; not both."; + kernel_shape_data[0] = conv_param.kernel_h(); + kernel_shape_data[1] = conv_param.kernel_w(); } else { - kernel_h_ = conv_param.kernel_h(); - kernel_w_ = conv_param.kernel_w(); + const int num_kernel_dims = conv_param.kernel_size_size(); + CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_) + << "kernel_size must be specified once, or once per spatial dimension " + << "(kernel_size specified " << num_kernel_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + for (int i = 0; i < num_spatial_axes_; ++i) { + kernel_shape_data[i] = + conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i); + } + } + for (int i = 0; i < num_spatial_axes_; ++i) { + CHECK_GT(kernel_shape_data[i], 0) << "Filter dimensions must be nonzero."; } - CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; - CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; - if (!conv_param.has_pad_h()) { - pad_h_ = pad_w_ = conv_param.pad(); + // Setup stride dimensions (stride_). + stride_.Reshape(spatial_dim_blob_shape); + int* stride_data = stride_.mutable_cpu_data(); + if (conv_param.has_stride_h() || conv_param.has_stride_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "stride_h & stride_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.stride_size()) + << "Either stride or stride_h/w should be specified; not both."; + stride_data[0] = conv_param.stride_h(); + stride_data[1] = conv_param.stride_w(); } else { - pad_h_ = conv_param.pad_h(); - pad_w_ = conv_param.pad_w(); + const int num_stride_dims = conv_param.stride_size(); + CHECK(num_stride_dims == 0 || num_stride_dims == 1 || + num_stride_dims == num_spatial_axes_) + << "stride must be specified once, or once per spatial dimension " + << "(stride specified " << num_stride_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultStride = 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + stride_data[i] = (num_stride_dims == 0) ? kDefaultStride : + conv_param.stride((num_stride_dims == 1) ? 0 : i); + CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero."; + } } - if (!conv_param.has_stride_h()) { - stride_h_ = stride_w_ = conv_param.stride(); + // Setup pad dimensions (pad_). + pad_.Reshape(spatial_dim_blob_shape); + int* pad_data = pad_.mutable_cpu_data(); + if (conv_param.has_pad_h() || conv_param.has_pad_w()) { + CHECK_EQ(num_spatial_axes_, 2) + << "pad_h & pad_w can only be used for 2D convolution."; + CHECK_EQ(0, conv_param.pad_size()) + << "Either pad or pad_h/w should be specified; not both."; + pad_data[0] = conv_param.pad_h(); + pad_data[1] = conv_param.pad_w(); } else { - stride_h_ = conv_param.stride_h(); - stride_w_ = conv_param.stride_w(); + const int num_pad_dims = conv_param.pad_size(); + CHECK(num_pad_dims == 0 || num_pad_dims == 1 || + num_pad_dims == num_spatial_axes_) + << "pad must be specified once, or once per spatial dimension " + << "(pad specified " << num_pad_dims << " times; " + << num_spatial_axes_ << " spatial dims);"; + const int kDefaultPad = 0; + for (int i = 0; i < num_spatial_axes_; ++i) { + pad_data[i] = (num_pad_dims == 0) ? kDefaultPad : + conv_param.pad((num_pad_dims == 1) ? 0 : i); + } } // Special case: im2col is the identity for 1x1 convolution with stride 1 // and no padding, so flag for skipping the buffer and transformation. - is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1 - && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0; + is_1x1_ = true; + for (int i = 0; i < num_spatial_axes_; ++i) { + is_1x1_ &= + kernel_shape_data[i] == 1 && stride_data[i] == 1 && pad_data[i] == 0; + if (!is_1x1_) { break; } + } // Configure output channels and groups. - channels_ = bottom[0]->channels(); + channels_ = bottom[0]->shape(channel_axis_); num_output_ = this->layer_param_.convolution_param().num_output(); CHECK_GT(num_output_, 0); group_ = this->layer_param_.convolution_param().group(); @@ -71,8 +125,29 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, // Handle the parameters: weights and biases. // - blobs_[0] holds the filter weights // - blobs_[1] holds the biases (optional) + vector<int> weight_shape(2); + weight_shape[0] = conv_out_channels_; + weight_shape[1] = conv_in_channels_ / group_; + for (int i = 0; i < num_spatial_axes_; ++i) { + weight_shape.push_back(kernel_shape_data[i]); + } bias_term_ = this->layer_param_.convolution_param().bias_term(); + vector<int> bias_shape(bias_term_, num_output_); if (this->blobs_.size() > 0) { + CHECK_EQ(1 + bias_term_, this->blobs_.size()) + << "Incorrect number of weight blobs."; + if (weight_shape != this->blobs_[0]->shape()) { + Blob<Dtype> weight_shaped_blob(weight_shape); + LOG(FATAL) << "Incorrect weight shape: expected shape " + << weight_shaped_blob.shape_string() << "; instead, shape was " + << this->blobs_[0]->shape_string(); + } + if (bias_term_ && bias_shape != this->blobs_[1]->shape()) { + Blob<Dtype> bias_shaped_blob(bias_shape); + LOG(FATAL) << "Incorrect bias shape: expected shape " + << bias_shaped_blob.shape_string() << "; instead, shape was " + << this->blobs_[1]->shape_string(); + } LOG(INFO) << "Skipping parameter initialization"; } else { if (bias_term_) { @@ -82,20 +157,20 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, } // Initialize and fill the weights: // output channels x input channels per-group x kernel height x kernel width - this->blobs_[0].reset(new Blob<Dtype>( - conv_out_channels_, conv_in_channels_ / group_, kernel_h_, kernel_w_)); + this->blobs_[0].reset(new Blob<Dtype>(weight_shape)); shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>( this->layer_param_.convolution_param().weight_filler())); weight_filler->Fill(this->blobs_[0].get()); // If necessary, initialize and fill the biases. if (bias_term_) { - vector<int> bias_shape(1, num_output_); this->blobs_[1].reset(new Blob<Dtype>(bias_shape)); shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>( this->layer_param_.convolution_param().bias_filler())); bias_filler->Fill(this->blobs_[1].get()); } } + kernel_dim_ = this->blobs_[0]->count(1); + weight_offset_ = conv_out_channels_ * kernel_dim_ / group_; // Propagate gradients to the parameters (as directed by backward pass). this->param_propagate_down_.resize(this->blobs_.size(), true); } @@ -103,52 +178,68 @@ void BaseConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, template <typename Dtype> void BaseConvolutionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { - CHECK_EQ(4, bottom[0]->num_axes()) << "Input must have 4 axes, " - << "corresponding to (num, channels, height, width)"; - num_ = bottom[0]->num(); - height_ = bottom[0]->height(); - width_ = bottom[0]->width(); - CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with" - " convolution kernel."; + const int first_spatial_axis = channel_axis_ + 1; + CHECK_EQ(bottom[0]->num_axes(), first_spatial_axis + num_spatial_axes_) + << "bottom num_axes may not change."; + num_ = bottom[0]->count(0, channel_axis_); + CHECK_EQ(bottom[0]->shape(channel_axis_), channels_) + << "Input size incompatible with convolution kernel."; // TODO: generalize to handle inputs of different shapes. for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { - CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; - CHECK_EQ(channels_, bottom[bottom_id]->channels()) - << "Inputs must have same channels."; - CHECK_EQ(height_, bottom[bottom_id]->height()) - << "Inputs must have same height."; - CHECK_EQ(width_, bottom[bottom_id]->width()) - << "Inputs must have same width."; + CHECK(bottom[0]->shape() == bottom[bottom_id]->shape()) + << "All inputs must have the same shape."; } // Shape the tops. compute_output_shape(); + vector<int> top_shape(bottom[0]->shape().begin(), + bottom[0]->shape().begin() + channel_axis_); + top_shape.push_back(num_output_); + for (int i = 0; i < num_spatial_axes_; ++i) { + top_shape.push_back(output_shape_[i]); + } for (int top_id = 0; top_id < top.size(); ++top_id) { - top[top_id]->Reshape(num_, num_output_, height_out_, width_out_); + top[top_id]->Reshape(top_shape); } if (reverse_dimensions()) { - conv_in_height_ = height_out_; - conv_in_width_ = width_out_; - conv_out_spatial_dim_ = height_ * width_; + conv_out_spatial_dim_ = bottom[0]->count(first_spatial_axis); } else { - conv_in_height_ = height_; - conv_in_width_ = width_; - conv_out_spatial_dim_ = height_out_ * width_out_; + conv_out_spatial_dim_ = top[0]->count(first_spatial_axis); } - kernel_dim_ = conv_in_channels_ * kernel_h_ * kernel_w_; - weight_offset_ = conv_out_channels_ * kernel_dim_ / group_ / group_; - col_offset_ = kernel_dim_ * conv_out_spatial_dim_ / group_; + col_offset_ = kernel_dim_ * conv_out_spatial_dim_; output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_; + // Setup input dimensions (conv_input_shape_). + vector<int> bottom_dim_blob_shape(1, num_spatial_axes_ + 1); + conv_input_shape_.Reshape(bottom_dim_blob_shape); + int* conv_input_shape_data = conv_input_shape_.mutable_cpu_data(); + for (int i = 0; i < num_spatial_axes_ + 1; ++i) { + if (reverse_dimensions()) { + conv_input_shape_data[i] = top[0]->shape(channel_axis_ + i); + } else { + conv_input_shape_data[i] = bottom[0]->shape(channel_axis_ + i); + } + } // The im2col result buffer will only hold one image at a time to avoid // overly large memory usage. In the special case of 1x1 convolution // it goes lazily unused to save memory. - if (reverse_dimensions()) { - col_buffer_.Reshape(1, kernel_dim_, height_, width_); - } else { - col_buffer_.Reshape(1, kernel_dim_, height_out_, width_out_); + col_buffer_shape_.clear(); + col_buffer_shape_.push_back(kernel_dim_ * group_); + const int* input_shape_data = input_shape_.cpu_data() + 1; + for (int i = 0; i < num_spatial_axes_; ++i) { + if (reverse_dimensions()) { + col_buffer_shape_.push_back(input_shape_data[i]); + } else { + col_buffer_shape_.push_back(output_shape_[i]); + } } + col_buffer_.Reshape(col_buffer_shape_); + bottom_dim_ = bottom[0]->count(channel_axis_); + top_dim_ = top[0]->count(channel_axis_); + num_kernels_im2col_ = conv_in_channels_ * conv_out_spatial_dim_; + num_kernels_col2im_ = reverse_dimensions() ? top_dim_ : bottom_dim_; // Set up the all ones "bias multiplier" for adding biases by BLAS + out_spatial_dim_ = top[0]->count(first_spatial_axis); if (bias_term_) { - vector<int> bias_multiplier_shape(1, height_out_ * width_out_); + vector<int> bias_multiplier_shape(1, out_spatial_dim_); bias_multiplier_.Reshape(bias_multiplier_shape); caffe_set(bias_multiplier_.count(), Dtype(1), bias_multiplier_.mutable_cpu_data()); @@ -167,7 +258,7 @@ void BaseConvolutionLayer<Dtype>::forward_cpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, conv_out_channels_ / - group_, conv_out_spatial_dim_, kernel_dim_ / group_, + group_, conv_out_spatial_dim_, kernel_dim_, (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g, (Dtype)0., output + output_offset_ * g); } @@ -177,7 +268,7 @@ template <typename Dtype> void BaseConvolutionLayer<Dtype>::forward_cpu_bias(Dtype* output, const Dtype* bias) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_, - height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), + out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), (Dtype)1., output); } @@ -189,7 +280,7 @@ void BaseConvolutionLayer<Dtype>::backward_cpu_gemm(const Dtype* output, col_buff = input; } for (int g = 0; g < group_; ++g) { - caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_, conv_out_spatial_dim_, conv_out_channels_ / group_, (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, (Dtype)0., col_buff + col_offset_ * g); @@ -209,7 +300,7 @@ void BaseConvolutionLayer<Dtype>::weight_cpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, - kernel_dim_ / group_, conv_out_spatial_dim_, + kernel_dim_, conv_out_spatial_dim_, (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, (Dtype)1., weights + weight_offset_ * g); } @@ -218,7 +309,7 @@ void BaseConvolutionLayer<Dtype>::weight_cpu_gemm(const Dtype* input, template <typename Dtype> void BaseConvolutionLayer<Dtype>::backward_cpu_bias(Dtype* bias, const Dtype* input) { - caffe_cpu_gemv<Dtype>(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + caffe_cpu_gemv<Dtype>(CblasNoTrans, num_output_, out_spatial_dim_, 1., input, bias_multiplier_.cpu_data(), 1., bias); } @@ -236,7 +327,7 @@ void BaseConvolutionLayer<Dtype>::forward_gpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, conv_out_channels_ / - group_, conv_out_spatial_dim_, kernel_dim_ / group_, + group_, conv_out_spatial_dim_, kernel_dim_, (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g, (Dtype)0., output + output_offset_ * g); } @@ -246,7 +337,7 @@ template <typename Dtype> void BaseConvolutionLayer<Dtype>::forward_gpu_bias(Dtype* output, const Dtype* bias) { caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_, - height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), + out_spatial_dim_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), (Dtype)1., output); } @@ -258,7 +349,7 @@ void BaseConvolutionLayer<Dtype>::backward_gpu_gemm(const Dtype* output, col_buff = input; } for (int g = 0; g < group_; ++g) { - caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, kernel_dim_, conv_out_spatial_dim_, conv_out_channels_ / group_, (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, (Dtype)0., col_buff + col_offset_ * g); @@ -278,7 +369,7 @@ void BaseConvolutionLayer<Dtype>::weight_gpu_gemm(const Dtype* input, } for (int g = 0; g < group_; ++g) { caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, - kernel_dim_ / group_, conv_out_spatial_dim_, + kernel_dim_, conv_out_spatial_dim_, (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, (Dtype)1., weights + weight_offset_ * g); } @@ -287,7 +378,7 @@ void BaseConvolutionLayer<Dtype>::weight_gpu_gemm(const Dtype* input, template <typename Dtype> void BaseConvolutionLayer<Dtype>::backward_gpu_bias(Dtype* bias, const Dtype* input) { - caffe_gpu_gemv<Dtype>(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + caffe_gpu_gemv<Dtype>(CblasNoTrans, num_output_, out_spatial_dim_, 1., input, bias_multiplier_.gpu_data(), 1., bias); } |