summaryrefslogtreecommitdiff
path: root/src/caffe/test/test_im2col_kernel.cu
diff options
context:
space:
mode:
Diffstat (limited to 'src/caffe/test/test_im2col_kernel.cu')
-rw-r--r--src/caffe/test/test_im2col_kernel.cu17
1 files changed, 11 insertions, 6 deletions
diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu
index 5671968b..37d1a152 100644
--- a/src/caffe/test/test_im2col_kernel.cu
+++ b/src/caffe/test/test_im2col_kernel.cu
@@ -17,8 +17,10 @@ namespace caffe {
// Forward declare kernel functions
template <typename Dtype>
__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
- const int height, const int width, const int ksize, const int pad,
- const int stride, const int height_col, const int width_col,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int height_col, const int width_col,
Dtype* data_col);
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
@@ -87,8 +89,10 @@ TYPED_TEST(Im2colKernelTest, TestGPU) {
// CPU Version
for (int n = 0; n < this->blob_bottom_->num(); ++n) {
im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n),
- this->channels_, this->height_, this->width_, this->kernel_size_,
- this->pad_, this->stride_, cpu_data + this->blob_top_cpu_->offset(n));
+ this->channels_, this->height_, this->width_,
+ this->kernel_size_, this->kernel_size_, this->pad_, this->pad_,
+ this->stride_, this->stride_,
+ cpu_data + this->blob_top_cpu_->offset(n));
}
// GPU version
@@ -102,8 +106,9 @@ TYPED_TEST(Im2colKernelTest, TestGPU) {
// NOLINT_NEXT_LINE(whitespace/operators)
im2col_gpu_kernel<TypeParam><<<grid_dim, CAFFE_CUDA_NUM_THREADS>>>(
num_kernels, bottom_data + this->blob_bottom_->offset(n),
- this->height_, this->width_, this->kernel_size_, this->pad_,
- this->stride_, this->height_col_, this->width_col_,
+ this->height_, this->width_, this->kernel_size_, this->kernel_size_,
+ this->pad_, this->pad_, this->stride_, this->stride_,
+ this->height_col_, this->width_col_,
top_data + this->blob_top_->offset(n));
CUDA_POST_KERNEL_CHECK;
}