diff options
Diffstat (limited to 'src/caffe/test/test_im2col_kernel.cu')
-rw-r--r-- | src/caffe/test/test_im2col_kernel.cu | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu index 5671968b..37d1a152 100644 --- a/src/caffe/test/test_im2col_kernel.cu +++ b/src/caffe/test/test_im2col_kernel.cu @@ -17,8 +17,10 @@ namespace caffe { // Forward declare kernel functions template <typename Dtype> __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, - const int height, const int width, const int ksize, const int pad, - const int stride, const int height_col, const int width_col, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int height_col, const int width_col, Dtype* data_col); extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; @@ -87,8 +89,10 @@ TYPED_TEST(Im2colKernelTest, TestGPU) { // CPU Version for (int n = 0; n < this->blob_bottom_->num(); ++n) { im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n), - this->channels_, this->height_, this->width_, this->kernel_size_, - this->pad_, this->stride_, cpu_data + this->blob_top_cpu_->offset(n)); + this->channels_, this->height_, this->width_, + this->kernel_size_, this->kernel_size_, this->pad_, this->pad_, + this->stride_, this->stride_, + cpu_data + this->blob_top_cpu_->offset(n)); } // GPU version @@ -102,8 +106,9 @@ TYPED_TEST(Im2colKernelTest, TestGPU) { // NOLINT_NEXT_LINE(whitespace/operators) im2col_gpu_kernel<TypeParam><<<grid_dim, CAFFE_CUDA_NUM_THREADS>>>( num_kernels, bottom_data + this->blob_bottom_->offset(n), - this->height_, this->width_, this->kernel_size_, this->pad_, - this->stride_, this->height_col_, this->width_col_, + this->height_, this->width_, this->kernel_size_, this->kernel_size_, + this->pad_, this->pad_, this->stride_, this->stride_, + this->height_col_, this->width_col_, top_data + this->blob_top_->offset(n)); CUDA_POST_KERNEL_CHECK; } |