diff options
author | Jeff Donahue <jeff.donahue@gmail.com> | 2014-11-01 21:22:11 -0700 |
---|---|---|
committer | Jeff Donahue <jeff.donahue@gmail.com> | 2015-01-19 17:03:58 -0800 |
commit | a0ada4fc9e0075d8ba57f128a111fd5e214266c8 (patch) | |
tree | 37404db542bd8b11c61d38d5d1e5b059d3e262ff /src/caffe | |
parent | 350d880ea28c4016a20c5ccec00c909928e7799e (diff) | |
download | caffeonacl-a0ada4fc9e0075d8ba57f128a111fd5e214266c8.tar.gz caffeonacl-a0ada4fc9e0075d8ba57f128a111fd5e214266c8.tar.bz2 caffeonacl-a0ada4fc9e0075d8ba57f128a111fd5e214266c8.zip |
Unroll kernels in SoftmaxLayer...from terrible performance to mediocre
performance.
Diffstat (limited to 'src/caffe')
-rw-r--r-- | src/caffe/layers/softmax_layer.cu | 48 |
1 files changed, 24 insertions, 24 deletions
diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu index 292ad2b3..6b8871a0 100644 --- a/src/caffe/layers/softmax_layer.cu +++ b/src/caffe/layers/softmax_layer.cu @@ -25,14 +25,13 @@ __global__ void kernel_channel_max(const int num, const int channels, } template <typename Dtype> -__global__ void kernel_channel_subtract(const int num, const int channels, - const int spatial_dim, Dtype* data, const Dtype* channel_max) { - CUDA_KERNEL_LOOP(index, num * spatial_dim) { - int n = index / spatial_dim; +__global__ void kernel_channel_subtract(const int count, + const int num, const int channels, + const int spatial_dim, const Dtype* channel_max, Dtype* data) { + CUDA_KERNEL_LOOP(index, count) { + int n = index / channels / spatial_dim; int s = index % spatial_dim; - for (int c = 0; c < channels; ++c) { - data[(n * channels + c) * spatial_dim + s] -= channel_max[index]; - } + data[index] -= channel_max[n * spatial_dim + s]; } } @@ -58,14 +57,13 @@ __global__ void kernel_channel_sum(const int num, const int channels, } template <typename Dtype> -__global__ void kernel_channel_div(const int num, const int channels, - const int spatial_dim, Dtype* data, const Dtype* channel_sum) { - CUDA_KERNEL_LOOP(index, num * spatial_dim) { - int n = index / spatial_dim; +__global__ void kernel_channel_div(const int count, + const int num, const int channels, + const int spatial_dim, const Dtype* channel_sum, Dtype* data) { + CUDA_KERNEL_LOOP(index, count) { + int n = index / channels / spatial_dim; int s = index % spatial_dim; - for (int c = 0; c < channels; ++c) { - data[(n * channels + c) * spatial_dim + s] /= channel_sum[index]; - } + data[index] /= channel_sum[n * spatial_dim + s]; } } @@ -91,10 +89,11 @@ void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, const Dtype* bottom_data = bottom[0]->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); Dtype* scale_data = scale_.mutable_gpu_data(); + int count = bottom[0]->count(); int num = bottom[0]->num(); int channels = bottom[0]->channels(); int spatial_dim = bottom[0]->height() * bottom[0]->width(); - caffe_copy(bottom[0]->count(), bottom_data, top_data); + caffe_copy(count, bottom_data, top_data); // We need to subtract the max to avoid numerical issues, compute the exp, // and then normalize. // compute max @@ -104,9 +103,9 @@ void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, scale_data); // subtract // NOLINT_NEXT_LINE(whitespace/operators) - kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim), - CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data, - scale_data); + kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(count), + CAFFE_CUDA_NUM_THREADS>>>(count, num, channels, spatial_dim, + scale_data, top_data); // exponentiate // NOLINT_NEXT_LINE(whitespace/operators) kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * channels * spatial_dim), @@ -119,9 +118,9 @@ void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, scale_data); // divide // NOLINT_NEXT_LINE(whitespace/operators) - kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim), - CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data, - scale_data); + kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(count), + CAFFE_CUDA_NUM_THREADS>>>(count, num, channels, spatial_dim, + scale_data, top_data); } template <typename Dtype> @@ -131,6 +130,7 @@ void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, const Dtype* top_data = top[0]->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); Dtype* scale_data = scale_.mutable_gpu_data(); + int count = top[0]->count(); int num = top[0]->num(); int channels = top[0]->channels(); int spatial_dim = top[0]->height() * top[0]->width(); @@ -141,9 +141,9 @@ void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_diff, top_data, scale_data); // NOLINT_NEXT_LINE(whitespace/operators) - kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim), - CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, bottom_diff, - scale_data); + kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(count), + CAFFE_CUDA_NUM_THREADS>>>(count, num, channels, spatial_dim, + scale_data, bottom_diff); // elementwise multiplication caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff); } |