summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author xmyqsh <xmyqsh@gmail.com> 2017-01-19 05:19:48 +0800
committer Evan Shelhamer <shelhamer@imaginarynumber.net> 2017-01-18 13:19:48 -0800
commit e744056d8f7ebcf7f0410a52d801d9ca552f69ad (patch)
tree 7338748d6fee039ea87d08148d94e801d2c005e9 /src
parent 381c0220e64958e7bbdb198e99cc7cedc6512de5 (diff)
download caffeonacl-e744056d8f7ebcf7f0410a52d801d9ca552f69ad.tar.gz
caffeonacl-e744056d8f7ebcf7f0410a52d801d9ca552f69ad.tar.bz2
caffeonacl-e744056d8f7ebcf7f0410a52d801d9ca552f69ad.zip
remove redundant operations in Crop layer (#5138)
Diffstat (limited to 'src')
-rw-r--r-- src/caffe/layers/crop_layer.cpp 40
-rw-r--r-- src/caffe/layers/crop_layer.cu 22
2 files changed, 27 insertions, 35 deletions
diff --git a/src/caffe/layers/crop_layer.cpp b/src/caffe/layers/crop_layer.cpp
index d36b61ca..ef8c177c 100644
--- a/src/caffe/layers/crop_layer.cpp
+++ b/src/caffe/layers/crop_layer.cpp
@@ -86,27 +86,25 @@ void CropLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
}
} else {
// We are at the last dimensions, which is stored continuously in memory
- for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
- // prepare index vector reduced(red) and with offsets(off)
- std::vector<int> ind_red(cur_dim, 0);
- std::vector<int> ind_off(cur_dim+1, 0);
- for (int j = 0; j < cur_dim; ++j) {
- ind_red[j] = indices[j];
- ind_off[j] = indices[j] + offsets[j];
- }
- ind_off[cur_dim] = offsets[cur_dim];
- // do the copy
- if (is_forward) {
- caffe_copy(top[0]->shape(cur_dim),
- src_data + bottom[0]->offset(ind_off),
- dest_data + top[0]->offset(ind_red));
- } else {
- // in the backwards pass the src_data is top_diff
- // and the dest_data is bottom_diff
- caffe_copy(top[0]->shape(cur_dim),
- src_data + top[0]->offset(ind_red),
- dest_data + bottom[0]->offset(ind_off));
- }
+ // prepare index vector reduced(red) and with offsets(off)
+ std::vector<int> ind_red(cur_dim, 0);
+ std::vector<int> ind_off(cur_dim+1, 0);
+ for (int j = 0; j < cur_dim; ++j) {
+ ind_red[j] = indices[j];
+ ind_off[j] = indices[j] + offsets[j];
+ }
+ ind_off[cur_dim] = offsets[cur_dim];
+ // do the copy
+ if (is_forward) {
+ caffe_copy(top[0]->shape(cur_dim),
+ src_data + bottom[0]->offset(ind_off),
+ dest_data + top[0]->offset(ind_red));
+ } else {
+ // in the backwards pass the src_data is top_diff
+ // and the dest_data is bottom_diff
+ caffe_copy(top[0]->shape(cur_dim),
+ src_data + top[0]->offset(ind_red),
+ dest_data + bottom[0]->offset(ind_off));
}
}
}
diff --git a/src/caffe/layers/crop_layer.cu b/src/caffe/layers/crop_layer.cu
index 1ea13253..677077cd 100644
--- a/src/caffe/layers/crop_layer.cu
+++ b/src/caffe/layers/crop_layer.cu
@@ -8,14 +8,12 @@ namespace caffe {
// strides in the last two dimensions.
template <typename Dtype>
__global__ void copy_kernel(const int n, const int height, const int width,
- const int src_outer_stride, const int src_inner_stride,
- const int dest_outer_stride, const int dest_inner_stride,
+ const int src_inner_stride,
+ const int dest_inner_stride,
const Dtype* src, Dtype* dest) {
CUDA_KERNEL_LOOP(index, n) {
- int src_start = index / height * src_outer_stride
- + index % height * src_inner_stride;
- int dest_start = index / height * dest_outer_stride
- + index % height * dest_inner_stride;
+ int src_start = index * src_inner_stride;
+ int dest_start = index * dest_inner_stride;
for (int i = 0; i < width; ++i) {
dest[dest_start + i] = src[src_start + i];
}
@@ -53,11 +51,7 @@ void CropLayer<Dtype>::crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
ind_off[cur_dim] = offsets[cur_dim];
ind_off[cur_dim+1] = offsets[cur_dim+1];
// Compute copy strides
- const int src_outer_stride =
- bottom[0]->shape(cur_dim)*bottom[0]->shape(cur_dim+1);
const int src_inner_stride = bottom[0]->shape(cur_dim+1);
- const int dest_outer_stride =
- top[0]->shape(cur_dim)*top[0]->shape(cur_dim+1);
const int dest_inner_stride = top[0]->shape(cur_dim+1);
if (is_forward) {
@@ -68,8 +62,8 @@ void CropLayer<Dtype>::crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
// NOLINT_NEXT_LINE(whitespace/operators)
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
lines, height, width,
- src_outer_stride, src_inner_stride,
- dest_outer_stride, dest_inner_stride,
+ src_inner_stride,
+ dest_inner_stride,
bottom_data, top_data);
} else {
@@ -80,8 +74,8 @@ void CropLayer<Dtype>::crop_copy_gpu(const vector<Blob<Dtype>*>& bottom,
// NOLINT_NEXT_LINE(whitespace/operators)
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
lines, height, width,
- dest_outer_stride, dest_inner_stride,
- src_outer_stride, src_inner_stride,
+ dest_inner_stride,
+ src_inner_stride,
top_diff, bottom_diff);
}
}