author     Jonathan L Long <jonlong@cs.berkeley.edu>  2014-12-21 19:42:29 -0800
committer  Jonathan L Long <jonlong@cs.berkeley.edu>  2015-01-11 00:28:44 -0800
commit     e3e2f2d3139880f77355e6837e72ad6c2848b448 (patch)
tree       0430b5a351a3f81f85912872392f1ceef89c1ac5 /src/caffe/layers
parent     a0e9db1347c325ff007166e79d1ca693e2e5de18 (diff)
rewrite ConvolutionLayer to use BaseConvolutionLayer helpers
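This moves the im2col/GEMM plumbing that Forward and Backward previously carried inline into shared BaseConvolutionLayer helpers (forward_cpu_gemm, forward_cpu_bias, weight_cpu_gemm, backward_cpu_gemm, backward_cpu_bias, plus their GPU counterparts). ConvolutionLayer now supplies only its output geometry via compute_output_shape() and thin per-image loops over those helpers.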
Diffstat (limited to 'src/caffe/layers')
-rw-r--r--  src/caffe/layers/conv_layer.cpp  243
-rw-r--r--  src/caffe/layers/conv_layer.cu   117
2 files changed, 52 insertions(+), 308 deletions(-)
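For orientation before reading the diff, here is a sketch of the helper interface the new code calls into. The method names and argument order are taken directly from the call sites below; everything else (the protected access, the pure-virtual geometry hook, the exact header) is an assumption, since the base-class change itself lives outside this diffstat.

// Sketch only: inferred from the call sites in this commit, not the
// actual header. Assumes caffe's Layer<Dtype> base class is in scope.
template <typename Dtype>
class BaseConvolutionLayer : public Layer<Dtype> {
 protected:
  // im2col (skipped for 1x1) followed by one GEMM per group.
  void forward_cpu_gemm(const Dtype* input, const Dtype* weights,
      Dtype* output);
  // Adds biases through the all-ones "bias multiplier" GEMM.
  void forward_cpu_bias(Dtype* output, const Dtype* bias);
  // Accumulates the weight gradient contribution of one image.
  void weight_cpu_gemm(const Dtype* input, const Dtype* output_diff,
      Dtype* weight_diff);
  // Input gradient: GEMM per group, then col2im (skipped for 1x1).
  void backward_cpu_gemm(const Dtype* output_diff, const Dtype* weights,
      Dtype* input_diff);
  // Accumulates the bias gradient against the ones vector.
  void backward_cpu_bias(Dtype* bias_diff, const Dtype* output_diff);
  // GPU twins used by conv_layer.cu: forward_gpu_gemm, forward_gpu_bias,
  // weight_gpu_gemm, backward_gpu_gemm, backward_gpu_bias.
  // Geometry hook each derived layer implements:
  virtual void compute_output_shape() = 0;
};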
diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp
index 0a032025..9fd2fc6a 100644
--- a/src/caffe/layers/conv_layer.cpp
+++ b/src/caffe/layers/conv_layer.cpp
@@ -9,166 +9,26 @@
namespace caffe {
template <typename Dtype>
-void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
- const vector<Blob<Dtype>*>& top) {
- // Configure the kernel size, padding, stride, and inputs.
- ConvolutionParameter conv_param = this->layer_param_.convolution_param();
- CHECK(!conv_param.has_kernel_size() !=
- !(conv_param.has_kernel_h() && conv_param.has_kernel_w()))
- << "Filter size is kernel_size OR kernel_h and kernel_w; not both";
- CHECK(conv_param.has_kernel_size() ||
- (conv_param.has_kernel_h() && conv_param.has_kernel_w()))
- << "For non-square filters both kernel_h and kernel_w are required.";
- CHECK((!conv_param.has_pad() && conv_param.has_pad_h()
- && conv_param.has_pad_w())
- || (!conv_param.has_pad_h() && !conv_param.has_pad_w()))
- << "pad is pad OR pad_h and pad_w are required.";
- CHECK((!conv_param.has_stride() && conv_param.has_stride_h()
- && conv_param.has_stride_w())
- || (!conv_param.has_stride_h() && !conv_param.has_stride_w()))
- << "Stride is stride OR stride_h and stride_w are required.";
- if (conv_param.has_kernel_size()) {
- kernel_h_ = kernel_w_ = conv_param.kernel_size();
- } else {
- kernel_h_ = conv_param.kernel_h();
- kernel_w_ = conv_param.kernel_w();
- }
- CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
- CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
- if (!conv_param.has_pad_h()) {
- pad_h_ = pad_w_ = conv_param.pad();
- } else {
- pad_h_ = conv_param.pad_h();
- pad_w_ = conv_param.pad_w();
- }
- if (!conv_param.has_stride_h()) {
- stride_h_ = stride_w_ = conv_param.stride();
- } else {
- stride_h_ = conv_param.stride_h();
- stride_w_ = conv_param.stride_w();
- }
- // Special case: im2col is the identity for 1x1 convolution with stride 1
- // and no padding, so flag for skipping the buffer and transformation.
- is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1
- && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0;
- // Configure output channels and groups.
- channels_ = bottom[0]->channels();
- num_output_ = this->layer_param_.convolution_param().num_output();
- CHECK_GT(num_output_, 0);
- group_ = this->layer_param_.convolution_param().group();
- CHECK_EQ(channels_ % group_, 0);
- CHECK_EQ(num_output_ % group_, 0)
- << "Number of output should be multiples of group.";
- // Handle the parameters: weights and biases.
- // - blobs_[0] holds the filter weights
- // - blobs_[1] holds the biases (optional)
- bias_term_ = this->layer_param_.convolution_param().bias_term();
- if (this->blobs_.size() > 0) {
- LOG(INFO) << "Skipping parameter initialization";
- } else {
- if (bias_term_) {
- this->blobs_.resize(2);
- } else {
- this->blobs_.resize(1);
- }
- // Initialize and fill the weights:
- // output channels x input channels per-group x kernel height x kernel width
- this->blobs_[0].reset(new Blob<Dtype>(
- num_output_, channels_ / group_, kernel_h_, kernel_w_));
- shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
- this->layer_param_.convolution_param().weight_filler()));
- weight_filler->Fill(this->blobs_[0].get());
- // If necessary, initialize and fill the biases:
- // 1 x 1 x 1 x output channels
- if (bias_term_) {
- this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, num_output_));
- shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
- this->layer_param_.convolution_param().bias_filler()));
- bias_filler->Fill(this->blobs_[1].get());
- }
- }
- // Propagate gradients to the parameters (as directed by backward pass).
- this->param_propagate_down_.resize(this->blobs_.size(), true);
-}
-
-template <typename Dtype>
-void ConvolutionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
- const vector<Blob<Dtype>*>& top) {
- num_ = bottom[0]->num();
- height_ = bottom[0]->height();
- width_ = bottom[0]->width();
- CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with"
- " convolution kernel.";
- // TODO: generalize to handle inputs of different shapes.
- for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) {
- CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num.";
- CHECK_EQ(channels_, bottom[bottom_id]->channels())
- << "Inputs must have same channels.";
- CHECK_EQ(height_, bottom[bottom_id]->height())
- << "Inputs must have same height.";
- CHECK_EQ(width_, bottom[bottom_id]->width())
- << "Inputs must have same width.";
- }
- // Shape the tops.
- height_out_ =
- (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
- width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
- for (int top_id = 0; top_id < top.size(); ++top_id) {
- top[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
- }
- // Prepare the matrix multiplication computation.
- // Each input will be convolved as a single GEMM.
- M_ = num_output_ / group_;
- K_ = channels_ * kernel_h_ * kernel_w_ / group_;
- N_ = height_out_ * width_out_;
- // The im2col result buffer will only hold one image at a time to avoid
- // overly large memory usage. In the special case of 1x1 convolution
- // it goes lazily unused to save memory.
- col_buffer_.Reshape(
- 1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
- // Set up the all ones "bias multiplier" for adding biases by BLAS
- if (bias_term_) {
- bias_multiplier_.Reshape(1, 1, 1, N_);
- caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data());
- }
+void ConvolutionLayer<Dtype>::compute_output_shape() {
+ this->height_out_ = (this->height_ + 2 * this->pad_h_ - this->kernel_h_)
+ / this->stride_h_ + 1;
+ this->width_out_ = (this->width_ + 2 * this->pad_w_ - this->kernel_w_)
+ / this->stride_w_ + 1;
}
template <typename Dtype>
void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
+ const Dtype* weight = this->blobs_[0]->cpu_data();
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->cpu_data();
Dtype* top_data = top[i]->mutable_cpu_data();
- Dtype* col_buff = NULL;
- if (!is_1x1_) {
- col_buff = col_buffer_.mutable_cpu_data();
- }
- const Dtype* weight = this->blobs_[0]->cpu_data();
- int weight_offset = M_ * K_; // number of filter parameters in a group
- int col_offset = K_ * N_; // number of values in an input region / column
- int top_offset = M_ * N_; // number of values in an output region / column
- for (int n = 0; n < num_; ++n) {
- // im2col transformation: unroll input regions for filtering
- // into column matrix for multiplication.
- if (!is_1x1_) {
- im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_,
- width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
- col_buff);
- } else { // special case for 1x1 convolution
- col_buff = bottom[i]->mutable_cpu_data() + bottom[i]->offset(n);
- }
- // Take inner products for groups.
- for (int g = 0; g < group_; ++g) {
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
- (Dtype)1., weight + weight_offset * g, col_buff + col_offset * g,
- (Dtype)0., top_data + top[i]->offset(n) + top_offset * g);
- }
- // Add bias.
- if (bias_term_) {
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
- N_, 1, (Dtype)1., this->blobs_[1]->cpu_data(),
- bias_multiplier_.cpu_data(),
- (Dtype)1., top_data + top[i]->offset(n));
+ for (int n = 0; n < this->num_; ++n) {
+ this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight,
+ top_data + top[i]->offset(n));
+ if (this->bias_term_) {
+ const Dtype* bias = this->blobs_[1]->cpu_data();
+ this->forward_cpu_bias(top_data + top[i]->offset(n), bias);
}
}
}
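The only geometry ConvolutionLayer keeps is the standard cross-correlation shape rule above: output = (input + 2 * pad - kernel) / stride + 1 per spatial axis. As a quick sanity check, a 227x227 input with an 11x11 kernel, stride 4, and no padding yields (227 + 0 - 11) / 4 + 1 = 55, i.e. a 55x55 output map.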
@@ -177,82 +37,37 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
- const Dtype* weight = NULL;
- Dtype* weight_diff = NULL;
+ const Dtype* weight = this->blobs_[0]->cpu_data();
+ Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
if (this->param_propagate_down_[0]) {
- weight = this->blobs_[0]->cpu_data();
- weight_diff = this->blobs_[0]->mutable_cpu_diff();
caffe_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
- Dtype* bias_diff = NULL;
- if (bias_term_ && this->param_propagate_down_[1]) {
- bias_diff = this->blobs_[1]->mutable_cpu_diff();
- caffe_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
+ if (this->bias_term_ && this->param_propagate_down_[1]) {
+ caffe_set(this->blobs_[1]->count(), Dtype(0),
+ this->blobs_[1]->mutable_cpu_diff());
}
- const int weight_offset = M_ * K_;
- const int col_offset = K_ * N_;
- const int top_offset = M_ * N_;
for (int i = 0; i < top.size(); ++i) {
- const Dtype* top_diff = NULL;
+ const Dtype* top_diff = top[i]->cpu_diff();
+ const Dtype* bottom_data = bottom[i]->cpu_data();
+ Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
// Bias gradient, if necessary.
- if (bias_term_ && this->param_propagate_down_[1]) {
- top_diff = top[i]->cpu_diff();
- for (int n = 0; n < num_; ++n) {
- caffe_cpu_gemv<Dtype>(CblasNoTrans, num_output_, N_,
- 1., top_diff + top[0]->offset(n),
- bias_multiplier_.cpu_data(), 1.,
- bias_diff);
+ if (this->bias_term_ && this->param_propagate_down_[1]) {
+ Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+ for (int n = 0; n < this->num_; ++n) {
+ this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n));
}
}
if (this->param_propagate_down_[0] || propagate_down[i]) {
- if (!top_diff) {
- top_diff = top[i]->cpu_diff();
- }
- Dtype* col_buff = NULL;
- if (!is_1x1_) {
- col_buff = col_buffer_.mutable_cpu_data();
- }
- const Dtype* bottom_data = bottom[i]->cpu_data();
- Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
- for (int n = 0; n < num_; ++n) {
- // Since we saved memory in the forward pass by not storing all col
- // data, we will need to recompute them.
- if (!is_1x1_) {
- im2col_cpu(bottom_data + bottom[i]->offset(n), channels_, height_,
- width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
- stride_h_, stride_w_, col_buff);
- } else {
- col_buff = bottom[i]->mutable_cpu_data() + bottom[i]->offset(n);
- }
+ for (int n = 0; n < this->num_; ++n) {
// gradient w.r.t. weight. Note that we will accumulate diffs.
if (this->param_propagate_down_[0]) {
- for (int g = 0; g < group_; ++g) {
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
- (Dtype)1., top_diff + top[i]->offset(n) + top_offset * g,
- col_buff + col_offset * g, (Dtype)1.,
- weight_diff + weight_offset * g);
- }
+ this->weight_cpu_gemm(bottom_data + bottom[i]->offset(n),
+ top_diff + top[i]->offset(n), weight_diff);
}
// gradient w.r.t. bottom data, if necessary.
if (propagate_down[i]) {
- if (weight == NULL) {
- weight = this->blobs_[0]->cpu_data();
- }
- if (is_1x1_) {
- col_buff = bottom[i]->mutable_cpu_diff() + bottom[i]->offset(n);
- }
- for (int g = 0; g < group_; ++g) {
- caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
- (Dtype)1., weight + weight_offset * g,
- top_diff + top[i]->offset(n) + top_offset * g,
- (Dtype)0., col_buff + col_offset * g);
- }
- // col2im back to the data
- if (!is_1x1_) {
- col2im_cpu(col_buff, channels_, height_, width_,
- kernel_h_, kernel_w_, pad_h_, pad_w_,
- stride_h_, stride_w_, bottom_diff + bottom[i]->offset(n));
- }
+ this->backward_cpu_gemm(top_diff + top[i]->offset(n), weight,
+ bottom_diff + bottom[i]->offset(n));
}
}
}
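The deleted bodies above spell out what each helper has to do, so the forward helper can be reconstructed almost mechanically. A minimal sketch of forward_cpu_gemm under that assumption (M_, N_, K_, is_1x1_, col_buffer_ and the per-group offsets are lifted from the removed code; their exact placement in the base class is not shown in this diff):

// Sketch reconstructed from the removed Forward_cpu loop; assumes the
// listed members moved into BaseConvolutionLayer unchanged.
template <typename Dtype>
void BaseConvolutionLayer<Dtype>::forward_cpu_gemm(const Dtype* input,
    const Dtype* weights, Dtype* output) {
  const Dtype* col_buff = input;  // 1x1 convolution: im2col is the identity
  if (!is_1x1_) {
    im2col_cpu(input, channels_, height_, width_, kernel_h_, kernel_w_,
        pad_h_, pad_w_, stride_h_, stride_w_, col_buffer_.mutable_cpu_data());
    col_buff = col_buffer_.cpu_data();
  }
  const int weight_offset = M_ * K_;  // filter parameters per group
  const int col_offset = K_ * N_;     // column values per group
  const int output_offset = M_ * N_;  // output values per group
  for (int g = 0; g < group_; ++g) {
    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
        (Dtype)1., weights + weight_offset * g, col_buff + col_offset * g,
        (Dtype)0., output + output_offset * g);
  }
}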
diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu
index af14facb..3902fdf3 100644
--- a/src/caffe/layers/conv_layer.cu
+++ b/src/caffe/layers/conv_layer.cu
@@ -8,135 +8,64 @@
namespace caffe {
-/// @brief refer to CPU forward -- the BLAS implementation is the same.
template <typename Dtype>
void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
+ const Dtype* weight = this->blobs_[0]->gpu_data();
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->gpu_data();
Dtype* top_data = top[i]->mutable_gpu_data();
- Dtype* col_buff = NULL;
- if (!is_1x1_) {
- col_buff = col_buffer_.mutable_gpu_data();
- }
- const Dtype* weight = this->blobs_[0]->gpu_data();
- int weight_offset = M_ * K_;
- int col_offset = K_ * N_;
- int top_offset = M_ * N_;
- for (int n = 0; n < num_; ++n) {
- // im2col transformation: unroll input regions for filtering
- // into column matrix for multiplication.
- if (!is_1x1_) {
- im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_,
- width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
- col_buff);
- } else {
- col_buff = bottom[i]->mutable_gpu_data() + bottom[i]->offset(n);
- }
- // Take inner products for groups.
- for (int g = 0; g < group_; ++g) {
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
- (Dtype)1., weight + weight_offset * g, col_buff + col_offset * g,
- (Dtype)0., top_data + top[i]->offset(n) + top_offset * g);
- }
- // Add bias.
- if (bias_term_) {
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_output_,
- N_, 1, (Dtype)1., this->blobs_[1]->gpu_data(),
- bias_multiplier_.gpu_data(),
- (Dtype)1., top_data + top[i]->offset(n));
+ for (int n = 0; n < this->num_; ++n) {
+ this->forward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight,
+ top_data + top[i]->offset(n));
+ if (this->bias_term_) {
+ const Dtype* bias = this->blobs_[1]->gpu_data();
+ this->forward_gpu_bias(top_data + top[i]->offset(n), bias);
}
}
}
}
-/// @brief refer to CPU backward -- the BLAS implementation is the same.
template <typename Dtype>
void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
- const Dtype* weight = NULL;
- Dtype* weight_diff = NULL;
+ const Dtype* weight = this->blobs_[0]->gpu_data();
+ Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
if (this->param_propagate_down_[0]) {
- weight = this->blobs_[0]->gpu_data();
- weight_diff = this->blobs_[0]->mutable_gpu_diff();
caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
}
- Dtype* bias_diff = NULL;
- if (bias_term_ && this->param_propagate_down_[1]) {
- bias_diff = this->blobs_[1]->mutable_gpu_diff();
- caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
+ if (this->bias_term_ && this->param_propagate_down_[1]) {
+ caffe_gpu_set(this->blobs_[1]->count(), Dtype(0),
+ this->blobs_[1]->mutable_gpu_diff());
}
- const int weight_offset = M_ * K_;
- const int col_offset = K_ * N_;
- const int top_offset = M_ * N_;
for (int i = 0; i < top.size(); ++i) {
- const Dtype* top_diff = NULL;
+ const Dtype* top_diff = top[i]->gpu_diff();
// Bias gradient, if necessary.
- if (bias_term_ && this->param_propagate_down_[1]) {
- top_diff = top[i]->gpu_diff();
- for (int n = 0; n < num_; ++n) {
- caffe_gpu_gemv<Dtype>(CblasNoTrans, num_output_, N_,
- 1., top_diff + top[0]->offset(n),
- bias_multiplier_.gpu_data(), 1.,
- bias_diff);
+ if (this->bias_term_ && this->param_propagate_down_[1]) {
+ Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+ for (int n = 0; n < this->num_; ++n) {
+ this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n));
}
}
if (this->param_propagate_down_[0] || propagate_down[i]) {
- if (!top_diff) {
- top_diff = top[i]->gpu_diff();
- }
- Dtype* col_buff = NULL;
- if (!is_1x1_) {
- col_buff = col_buffer_.mutable_gpu_data();
- }
const Dtype* bottom_data = bottom[i]->gpu_data();
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
- for (int n = 0; n < num_; ++n) {
- // Since we saved memory in the forward pass by not storing all col
- // data, we will need to recompute them.
- if (!is_1x1_) {
- im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_,
- width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
- stride_h_, stride_w_, col_buff);
- } else {
- col_buff = bottom[i]->mutable_gpu_data() + bottom[i]->offset(n);
- }
+ for (int n = 0; n < this->num_; ++n) {
// gradient w.r.t. weight. Note that we will accumulate diffs.
if (this->param_propagate_down_[0]) {
- for (int g = 0; g < group_; ++g) {
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
- (Dtype)1., top_diff + top[i]->offset(n) + top_offset * g,
- col_buff + col_offset * g, (Dtype)1.,
- weight_diff + weight_offset * g);
- }
+ this->weight_gpu_gemm(bottom_data + bottom[i]->offset(n),
+ top_diff + top[i]->offset(n), weight_diff);
}
- // gradient w.r.t. bottom data, if necessary
+ // gradient w.r.t. bottom data, if necessary.
if (propagate_down[i]) {
- if (weight == NULL) {
- weight = this->blobs_[0]->gpu_data();
- }
- if (is_1x1_) {
- col_buff = bottom[i]->mutable_gpu_diff() + bottom[i]->offset(n);
- }
- for (int g = 0; g < group_; ++g) {
- caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
- (Dtype)1., weight + weight_offset * g,
- top_diff + top[i]->offset(n) + top_offset * g,
- (Dtype)0., col_buff + col_offset * g);
- }
- // col2im back to the data
- if (!is_1x1_) {
- col2im_gpu(col_buff, channels_, height_, width_,
- kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
- bottom_diff + bottom[i]->offset(n));
- }
+ this->backward_gpu_gemm(top_diff + top[i]->offset(n), weight,
+ bottom_diff + bottom[i]->offset(n));
}
}
}
}
}
-
INSTANTIATE_LAYER_GPU_FUNCS(ConvolutionLayer);
} // namespace caffe
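A closing note on the design: with buffers, parameter setup, and the GEMM plumbing hoisted into the base class, compute_output_shape() is the single piece of geometry a derived layer must provide. As a hypothetical illustration (not part of this commit), a backwards-strided layer would override it with the inverted formula:

// Hypothetical override shown only to illustrate the extension point;
// DeconvolutionLayer is not introduced by this commit.
template <typename Dtype>
void DeconvolutionLayer<Dtype>::compute_output_shape() {
  this->height_out_ = this->stride_h_ * (this->height_ - 1)
      + this->kernel_h_ - 2 * this->pad_h_;
  this->width_out_ = this->stride_w_ * (this->width_ - 1)
      + this->kernel_w_ - 2 * this->pad_w_;
}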