summaryrefslogtreecommitdiff
path: root/src/caffe/layers/tile_layer.cu
diff options
context:
space:
mode:
authorJeff Donahue <jeff.donahue@gmail.com>2015-08-18 18:15:20 -0700
committerJeff Donahue <jeff.donahue@gmail.com>2015-08-25 17:58:45 -0700
commitcbff2255bc8470299e15cc155ae7957a3acdd688 (patch)
tree18ad56ef9f997e9361b66707053af4be0342d971 /src/caffe/layers/tile_layer.cu
parent251e67ab3141bc8ac2adf97ea4e961e5664ae008 (diff)
downloadcaffeonacl-cbff2255bc8470299e15cc155ae7957a3acdd688.tar.gz
caffeonacl-cbff2255bc8470299e15cc155ae7957a3acdd688.tar.bz2
caffeonacl-cbff2255bc8470299e15cc155ae7957a3acdd688.zip
TileLayer: add CUDA kernels
Diffstat (limited to 'src/caffe/layers/tile_layer.cu')
-rw-r--r--src/caffe/layers/tile_layer.cu | 53
1 file changed, 39 insertions, 14 deletions
diff --git a/src/caffe/layers/tile_layer.cu b/src/caffe/layers/tile_layer.cu
index 3af8e2eb..7fd3bc47 100644
--- a/src/caffe/layers/tile_layer.cu
+++ b/src/caffe/layers/tile_layer.cu
@@ -7,16 +7,44 @@
namespace caffe {
template <typename Dtype>
+__global__ void Tile(const int nthreads, const Dtype* bottom_data,
+ const int tile_size, const int num_tiles, const int bottom_tile_axis,
+ Dtype* top_data) {
+ CUDA_KERNEL_LOOP(index, nthreads) {
+ const int d = index % tile_size;
+ const int b = (index / tile_size / num_tiles) % bottom_tile_axis;
+ const int n = index / tile_size / num_tiles / bottom_tile_axis;
+ const int bottom_index = (n * bottom_tile_axis + b) * tile_size + d;
+ top_data[index] = bottom_data[bottom_index];
+ }
+}
+
+template <typename Dtype>
void TileLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
- for (int i = 0; i < outer_dim_; ++i) {
- for (int t = 0; t < tiles_; ++t) {
- caffe_copy(inner_dim_, bottom_data, top_data);
- top_data += inner_dim_;
+ const int bottom_tile_axis = bottom[0]->shape(axis_);
+ const int nthreads = top[0]->count();
+ Tile<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
+ nthreads, bottom_data, inner_dim_, tiles_, bottom_tile_axis, top_data);
+}
+
+template <typename Dtype>
+__global__ void TileBackward(const int nthreads, const Dtype* top_diff,
+ const int tile_size, const int num_tiles, const int bottom_tile_axis,
+ Dtype* bottom_diff) {
+ CUDA_KERNEL_LOOP(index, nthreads) {
+ const int d = index % tile_size;
+ const int b = (index / tile_size) % bottom_tile_axis;
+ const int n = index / tile_size / bottom_tile_axis;
+ bottom_diff[index] = 0;
+ int top_index = (n * num_tiles * bottom_tile_axis + b) * tile_size + d;
+ for (int t = 0; t < num_tiles; ++t) {
+ bottom_diff[index] += top_diff[top_index];
+ top_index += bottom_tile_axis * tile_size;
}
- bottom_data += inner_dim_;
}
}
@@ -26,15 +54,12 @@ void TileLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
if (!propagate_down[0]) { return; }
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
- for (int i = 0; i < outer_dim_; ++i) {
- caffe_copy(inner_dim_, top_diff, bottom_diff);
- top_diff += inner_dim_;
- for (int t = 1; t < tiles_; ++t) {
- caffe_gpu_axpy(inner_dim_, Dtype(1), top_diff, bottom_diff);
- top_diff += inner_dim_;
- }
- bottom_diff += inner_dim_;
- }
+ const int bottom_tile_axis = bottom[0]->shape(axis_);
+ const int tile_size = inner_dim_ / bottom_tile_axis;
+ const int nthreads = bottom[0]->count();
+ TileBackward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
+ nthreads, top_diff, tile_size, tiles_, bottom_tile_axis, bottom_diff);
}
INSTANTIATE_LAYER_GPU_FUNCS(TileLayer);