1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
|
#include <cstring>
#include <vector>
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
// Forward declare kernel functions
template <typename Dtype>
__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int height_col, const int width_col,
Dtype* data_col);
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
template <typename Dtype>
class Im2colKernelTest : public GPUDeviceTest<Dtype> {
protected:
Im2colKernelTest()
// big so launches > 1024 threads
: blob_bottom_(new Blob<Dtype>(5, 500, 10, 10)),
blob_top_(new Blob<Dtype>()),
blob_top_cpu_(new Blob<Dtype>()) {
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
filler.Fill(this->blob_bottom_);
height_ = blob_bottom_->height();
width_ = blob_bottom_->width();
channels_ = blob_bottom_->channels();
pad_ = 0;
stride_ = 2;
kernel_size_ = 3;
height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1;
width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1;
}
virtual ~Im2colKernelTest() {
delete blob_bottom_;
delete blob_top_;
delete blob_top_cpu_;
}
Blob<Dtype>* const blob_bottom_;
Blob<Dtype>* const blob_top_;
Blob<Dtype>* const blob_top_cpu_;
int height_;
int width_;
int channels_;
int pad_;
int stride_;
int kernel_size_;
int height_col_;
int width_col_;
};
TYPED_TEST_CASE(Im2colKernelTest, TestDtypes);
TYPED_TEST(Im2colKernelTest, TestGPU) {
// Reshape the blobs to correct size for im2col output
this->blob_top_->Reshape(this->blob_bottom_->num(),
this->channels_ * this->kernel_size_ * this->kernel_size_,
this->height_col_,
this->width_col_);
this->blob_top_cpu_->Reshape(this->blob_bottom_->num(),
this->channels_ * this->kernel_size_ * this->kernel_size_,
this->height_col_,
this->width_col_);
const TypeParam* bottom_data = this->blob_bottom_->gpu_data();
TypeParam* top_data = this->blob_top_->mutable_gpu_data();
TypeParam* cpu_data = this->blob_top_cpu_->mutable_cpu_data();
// CPU Version
for (int n = 0; n < this->blob_bottom_->num(); ++n) {
im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n),
this->channels_, this->height_, this->width_,
this->kernel_size_, this->kernel_size_, this->pad_, this->pad_,
this->stride_, this->stride_,
cpu_data + this->blob_top_cpu_->offset(n));
}
// GPU version
int num_kernels = this->channels_ * this->height_col_ * this->width_col_;
int default_grid_dim = CAFFE_GET_BLOCKS(num_kernels);
// Launch with different grid sizes
for (int grid_div = 2; grid_div <= 8; grid_div++) {
for (int n = 0; n < this->blob_bottom_->num(); ++n) {
int grid_dim = default_grid_dim/grid_div;
// NOLINT_NEXT_LINE(whitespace/operators)
im2col_gpu_kernel<TypeParam><<<grid_dim, CAFFE_CUDA_NUM_THREADS>>>(
num_kernels, bottom_data + this->blob_bottom_->offset(n),
this->height_, this->width_, this->kernel_size_, this->kernel_size_,
this->pad_, this->pad_, this->stride_, this->stride_,
this->height_col_, this->width_col_,
top_data + this->blob_top_->offset(n));
CUDA_POST_KERNEL_CHECK;
}
// Compare results against CPU version
for (int i = 0; i < this->blob_top_->count(); ++i) {
TypeParam cpuval = cpu_data[i];
TypeParam gpuval = this->blob_top_->cpu_data()[i];
EXPECT_EQ(cpuval, gpuval);
if (cpuval != gpuval) {
break;
}
}
}
}
} // namespace caffe
|