summaryrefslogtreecommitdiff
path: root/src/caffe/layers/im2col_layer.cu
diff options
context:
space:
mode:
Diffstat (limited to 'src/caffe/layers/im2col_layer.cu')
-rw-r--r--src/caffe/layers/im2col_layer.cu41
1 files changed, 33 insertions, 8 deletions
diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu
index 9c338b14..cd507623 100644
--- a/src/caffe/layers/im2col_layer.cu
+++ b/src/caffe/layers/im2col_layer.cu
@@ -12,10 +12,23 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
- for (int n = 0; n < bottom[0]->num(); ++n) {
- im2col_gpu(bottom_data + bottom[0]->offset(n), channels_, height_,
- width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
- stride_h_, stride_w_, top_data + top[0]->offset(n));
+ const int num_kernels = channels_ * top[0]->count(channel_axis_ + 1);
+ for (int n = 0; n < num_; ++n) {
+ if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+ im2col_gpu(bottom_data + n * bottom_dim_, channels_,
+ bottom[0]->shape(channel_axis_ + 1),
+ bottom[0]->shape(channel_axis_ + 2),
+ kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+ pad_.cpu_data()[0], pad_.cpu_data()[1],
+ stride_.cpu_data()[0], stride_.cpu_data()[1],
+ top_data + n * top_dim_);
+ } else {
+ im2col_nd_gpu(bottom_data + n * bottom_dim_, num_spatial_axes_,
+ num_kernels, bottom[0]->gpu_shape() + channel_axis_,
+ top[0]->gpu_shape() + channel_axis_,
+ kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+ top_data + n * top_dim_);
+ }
}
}
@@ -24,10 +37,22 @@ void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
- for (int n = 0; n < top[0]->num(); ++n) {
- col2im_gpu(top_diff + top[0]->offset(n), channels_, height_, width_,
- kernel_h_, kernel_w_, pad_h_, pad_w_,
- stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n));
+ for (int n = 0; n < num_; ++n) {
+ if (!force_nd_im2col_ && num_spatial_axes_ == 2) {
+ col2im_gpu(top_diff + n * top_dim_, channels_,
+ bottom[0]->shape(channel_axis_ + 1),
+ bottom[0]->shape(channel_axis_ + 2),
+ kernel_shape_.cpu_data()[0], kernel_shape_.cpu_data()[1],
+ pad_.cpu_data()[0], pad_.cpu_data()[1],
+ stride_.cpu_data()[0], stride_.cpu_data()[1],
+ bottom_diff + n * bottom_dim_);
+ } else {
+ col2im_nd_gpu(top_diff + n * top_dim_, num_spatial_axes_, bottom_dim_,
+ bottom[0]->gpu_shape() + channel_axis_,
+ top[0]->gpu_shape() + channel_axis_,
+ kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
+ bottom_diff + n * bottom_dim_);
+ }
}
}