diff options
Diffstat (limited to 'src/caffe/layers/im2col_layer.cu')
-rw-r--r-- | src/caffe/layers/im2col_layer.cu | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu index 9c338b14..cd507623 100644 --- a/src/caffe/layers/im2col_layer.cu +++ b/src/caffe/layers/im2col_layer.cu @@ -12,10 +12,23 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); - for (int n = 0; n < bottom[0]->num(); ++n) { - im2col_gpu(bottom_data + bottom[0]->offset(n), channels_, height_, - width_, kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, top_data + top[0]->offset(n)); + const int num_kernels = channels_ * top[0]->count(channel_axis_ + 1); + for (int n = 0; n < num_; ++n) { + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + im2col_gpu(bottom_data + n * bottom_dim_, channels_, + bottom[0]->shape(channel_axis_ + 1), + bottom[0]->shape(channel_axis_ + 2), + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], + top_data + n * top_dim_); + } else { + im2col_nd_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_, + num_kernels, bottom[0]->gpu_shape() + channel_axis_, + top[0]->gpu_shape() + channel_axis_, + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + top_data + n * top_dim_); + } } } @@ -24,10 +37,22 @@ void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); - for (int n = 0; n < top[0]->num(); ++n) { - col2im_gpu(top_diff + top[0]->offset(n), channels_, height_, width_, - kernel_h_, kernel_w_, pad_h_, pad_w_, - stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n)); + for (int n = 0; n < num_; ++n) { + if (!force_nd_im2col_ && num_spatial_axes_ == 2) { + col2im_gpu(top_diff + n * top_dim_, channels_, + bottom[0]->shape(channel_axis_ + 1), + bottom[0]->shape(channel_axis_ + 2), + kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1], + pad_.cpu_data()[0], pad_.cpu_data()[1], + stride_.cpu_data()[0], stride_.cpu_data()[1], + bottom_diff + n * bottom_dim_); + } else { + col2im_nd_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_, + bottom[0]->gpu_shape() + channel_axis_, + top[0]->gpu_shape() + channel_axis_, + kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(), + bottom_diff + n * bottom_dim_); + } } } |