Diffstat (limited to 'src/caffe/layers')
-rw-r--r--  src/caffe/layers/clip_layer.cpp          |  51
-rw-r--r--  src/caffe/layers/clip_layer.cu           |  67
-rw-r--r--  src/caffe/layers/embed_layer.cu          |   5
-rw-r--r--  src/caffe/layers/hdf5_data_layer.cpp     |   2
-rw-r--r--  src/caffe/layers/hdf5_data_layer.cu      |   2
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cpp   |   2
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cu    |   2
-rw-r--r--  src/caffe/layers/inner_product_layer.cpp |   2
-rw-r--r--  src/caffe/layers/pooling_layer.cpp       |  23
-rw-r--r--  src/caffe/layers/swish_layer.cpp         |  68
-rw-r--r--  src/caffe/layers/swish_layer.cu          |  54
11 files changed, 272 insertions(+), 6 deletions(-)
diff --git a/src/caffe/layers/clip_layer.cpp b/src/caffe/layers/clip_layer.cpp
new file mode 100644
index 00000000..9d9a5967
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cpp
@@ -0,0 +1,51 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layers/clip_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+
+ Dtype min = this->layer_param_.clip_param().min();
+ Dtype max = this->layer_param_.clip_param().max();
+
+ for (int i = 0; i < count; ++i) {
+ top_data[i] = std::max(min, std::min(bottom_data[i], max));
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ const int count = bottom[0]->count();
+
+ Dtype min = this->layer_param_.clip_param().min();
+ Dtype max = this->layer_param_.clip_param().max();
+
+ for (int i = 0; i < count; ++i) {
+ bottom_diff[i] = top_diff[i] * (
+ bottom_data[i] >= min && bottom_data[i] <= max);
+ }
+ }
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(ClipLayer);
+#endif
+
+INSTANTIATE_CLASS(ClipLayer);
+REGISTER_LAYER_CLASS(Clip);
+
+} // namespace caffe
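
Note: Backward_cpu passes the gradient through only where the input lies inside [min, max]; values exactly at the bounds still receive gradient. A minimal standalone sketch of the same element-wise rule, using plain arrays and made-up values in place of Blobs:

#include <algorithm>
#include <cstdio>

int main() {
  const int count = 5;
  const float min = -1.f, max = 1.f;
  const float bottom_data[count] = {-2.f, -1.f, 0.f, 1.f, 2.f};
  const float top_diff[count]    = {1.f, 1.f, 1.f, 1.f, 1.f};
  float top_data[count], bottom_diff[count];
  for (int i = 0; i < count; ++i) {
    // Forward: clamp into [min, max].
    top_data[i] = std::max(min, std::min(bottom_data[i], max));
    // Backward: pass the gradient through iff min <= x <= max.
    bottom_diff[i] = top_diff[i] *
        (bottom_data[i] >= min && bottom_data[i] <= max);
  }
  for (int i = 0; i < count; ++i)
    std::printf("x=% .0f  y=% .0f  dx=% .0f\n",
                bottom_data[i], top_data[i], bottom_diff[i]);
  return 0;  // expected dx: 0 1 1 1 0
}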
diff --git a/src/caffe/layers/clip_layer.cu b/src/caffe/layers/clip_layer.cu
new file mode 100644
index 00000000..56f3be32
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cu
@@ -0,0 +1,67 @@
+#include <vector>
+
+#include "caffe/layers/clip_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+__global__ void ClipForward(const int n, const float* in, float* out,
+ float p_min, float p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out[index] = fmaxf(p_min, fminf(in[index], p_max));
+ }
+}
+
+__global__ void ClipForward(const int n, const double* in, double* out,
+ double p_min, double p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out[index] = fmax(p_min, fmin(in[index], p_max));
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ Dtype p_min = this->layer_param_.clip_param().min();
+ Dtype p_max = this->layer_param_.clip_param().max();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ ClipForward<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, top_data, p_min, p_max);
+ CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void ClipBackward(const int n, const Dtype* in_diff,
+ const Dtype* in_data, Dtype* out_diff, Dtype p_min, Dtype p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out_diff[index] = in_diff[index] * (
+ in_data[index] >= p_min && in_data[index] <= p_max);
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const int count = bottom[0]->count();
+ Dtype p_min = this->layer_param_.clip_param().min();
+ Dtype p_max = this->layer_param_.clip_param().max();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ ClipBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, top_diff, bottom_data, bottom_diff, p_min, p_max);
+ CUDA_POST_KERNEL_CHECK;
+ }
+}
+
+
+INSTANTIATE_LAYER_GPU_FUNCS(ClipLayer);
+
+
+} // namespace caffe
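
Note: ClipForward is overloaded per precision rather than templated, so the float kernel can use CUDA's single-precision fmaxf/fminf and the double kernel fmax/fmin. A templated alternative is possible if the clamping is written out by hand; a hypothetical sketch, not part of this commit:

// Hypothetical templated variant: ternaries instead of the fmin/fmax
// intrinsics, so one kernel body covers both float and double.
template <typename Dtype>
__global__ void ClipForwardT(const int n, const Dtype* in, Dtype* out,
    Dtype p_min, Dtype p_max) {
  CUDA_KERNEL_LOOP(index, n) {
    const Dtype v = in[index];
    out[index] = v < p_min ? p_min : (v > p_max ? p_max : v);
  }
}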
diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu
index 6324a3a8..3cf39fd9 100644
--- a/src/caffe/layers/embed_layer.cu
+++ b/src/caffe/layers/embed_layer.cu
@@ -15,6 +15,11 @@ __global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
const int n = top_index / N;
const int d = top_index % N;
const int index = static_cast<int>(bottom_data[n]);
+ #ifdef DEBUG
+ assert(index >= 0);
+ assert(index < K);
+ assert(static_cast<Dtype>(index) == bottom_data[n]);
+ #endif
const int weight_index = index * N + d;
top_data[top_index] = weight[weight_index];
}
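
Note: the new asserts fire only in DEBUG builds; they verify that each lookup value is a valid row index into the K-row weight matrix and is exactly integral. The same validation can be done up front on the host; a sketch with a hypothetical helper name, using the glog CHECK macros Caffe already depends on:

#include <glog/logging.h>

// Hypothetical host-side check mirroring the device asserts: every
// bottom value must be an integer in [0, K) before EmbedForward runs.
template <typename Dtype>
void CheckEmbedIndices(const Dtype* bottom_data, int n, int K) {
  for (int i = 0; i < n; ++i) {
    const int index = static_cast<int>(bottom_data[i]);
    CHECK_GE(index, 0) << "negative embedding index at position " << i;
    CHECK_LT(index, K) << "embedding index out of range at position " << i;
    CHECK_EQ(static_cast<Dtype>(index), bottom_data[i])
        << "non-integral embedding index at position " << i;
  }
}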
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp
index 00716a92..7668854c 100644
--- a/src/caffe/layers/hdf5_data_layer.cpp
+++ b/src/caffe/layers/hdf5_data_layer.cpp
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
/*
TODO:
- load file in a separate thread ("prefetch")
@@ -184,3 +185,4 @@ INSTANTIATE_CLASS(HDF5DataLayer);
REGISTER_LAYER_CLASS(HDF5Data);
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu
index 33eebd41..70cd9f32 100644
--- a/src/caffe/layers/hdf5_data_layer.cu
+++ b/src/caffe/layers/hdf5_data_layer.cu
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
/*
TODO:
- only load parts of the file, in accordance with a prototxt param "max_mem"
@@ -34,3 +35,4 @@ void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
INSTANTIATE_LAYER_GPU_FUNCS(HDF5DataLayer);
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/layers/hdf5_output_layer.cpp b/src/caffe/layers/hdf5_output_layer.cpp
index f8f1edcd..28c453a2 100644
--- a/src/caffe/layers/hdf5_output_layer.cpp
+++ b/src/caffe/layers/hdf5_output_layer.cpp
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
#include <vector>
#include "hdf5.h"
@@ -72,3 +73,4 @@ INSTANTIATE_CLASS(HDF5OutputLayer);
REGISTER_LAYER_CLASS(HDF5Output);
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/layers/hdf5_output_layer.cu b/src/caffe/layers/hdf5_output_layer.cu
index c1685cd3..891aea03 100644
--- a/src/caffe/layers/hdf5_output_layer.cu
+++ b/src/caffe/layers/hdf5_output_layer.cu
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
#include <vector>
#include "hdf5.h"
@@ -37,3 +38,4 @@ void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
INSTANTIATE_LAYER_GPU_FUNCS(HDF5OutputLayer);
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index e65349f0..57fdbe1f 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
@@ -42,7 +42,7 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
this->layer_param_.inner_product_param().weight_filler()));
weight_filler->Fill(this->blobs_[0].get());
- // If necessary, intiialize and fill the bias term
+ // If necessary, initialize and fill the bias term
if (bias_term_) {
vector<int> bias_shape(1, N_);
this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp
index 90897db0..32dc0482 100644
--- a/src/caffe/layers/pooling_layer.cpp
+++ b/src/caffe/layers/pooling_layer.cpp
@@ -35,6 +35,7 @@ void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
|| (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
<< "Stride is stride OR stride_h and stride_w are required.";
global_pooling_ = pool_param.global_pooling();
+ round_mode_ = pool_param.round_mode();
if (global_pooling_) {
kernel_h_ = bottom[0]->height();
kernel_w_ = bottom[0]->width();
@@ -87,10 +88,22 @@ void PoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
kernel_h_ = bottom[0]->height();
kernel_w_ = bottom[0]->width();
}
- pooled_height_ = static_cast<int>(ceil(static_cast<float>(
- height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
- pooled_width_ = static_cast<int>(ceil(static_cast<float>(
- width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+ switch (round_mode_) {
+ case PoolingParameter_RoundMode_CEIL:
+ pooled_height_ = static_cast<int>(ceil(static_cast<float>(
+ height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
+ pooled_width_ = static_cast<int>(ceil(static_cast<float>(
+ width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+ break;
+ case PoolingParameter_RoundMode_FLOOR:
+ pooled_height_ = static_cast<int>(floor(static_cast<float>(
+ height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
+ pooled_width_ = static_cast<int>(floor(static_cast<float>(
+ width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+ break;
+ default:
+ LOG(FATAL) << "Unknown rounding mode.";
+ }
if (pad_h_ || pad_w_) {
// If we have padding, ensure that the last pooling starts strictly
// inside the image (instead of at the padding); otherwise clip the last.
@@ -132,7 +145,7 @@ void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const int top_count = top[0]->count();
// We'll output the mask to top[1] if it's of size >1.
const bool use_top_mask = top.size() > 1;
- int* mask = NULL; // suppress warnings about uninitalized variables
+ int* mask = NULL; // suppress warnings about uninitialized variables
Dtype* top_mask = NULL;
// Different pooling methods. We explicitly do the switch outside the for
// loop to save time, although this results in more code.
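
Note: the two rounding modes differ only when the kernel and stride do not tile the padded input exactly. With height 6, kernel 3, stride 2, no padding, (6 - 3) / 2 = 1.5, so CEIL yields 3 output rows (the last window is clipped at the edge) while FLOOR yields 2. A standalone sketch of the Reshape arithmetic, with a hypothetical helper name:

#include <cmath>
#include <cstdio>

// Output extent for one spatial dimension, matching the Reshape logic.
int pooled_size(int size, int kernel, int stride, int pad, bool ceil_mode) {
  const float span = static_cast<float>(size + 2 * pad - kernel) / stride;
  return static_cast<int>(ceil_mode ? std::ceil(span) : std::floor(span)) + 1;
}

int main() {
  // height 6, kernel 3, stride 2, no padding: the modes disagree.
  std::printf("CEIL:  %d\n", pooled_size(6, 3, 2, 0, true));   // 3
  std::printf("FLOOR: %d\n", pooled_size(6, 3, 2, 0, false));  // 2
  return 0;
}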
diff --git a/src/caffe/layers/swish_layer.cpp b/src/caffe/layers/swish_layer.cpp
new file mode 100644
index 00000000..28935679
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cpp
@@ -0,0 +1,68 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+ sigmoid_bottom_vec_.clear();
+ sigmoid_bottom_vec_.push_back(sigmoid_input_.get());
+ sigmoid_top_vec_.clear();
+ sigmoid_top_vec_.push_back(sigmoid_output_.get());
+ sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ NeuronLayer<Dtype>::Reshape(bottom, top);
+ sigmoid_input_->ReshapeLike(*bottom[0]);
+ sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* sigmoid_input_data = sigmoid_input_->mutable_cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ caffe_copy(count, bottom_data, sigmoid_input_data);
+ caffe_scal(count, beta, sigmoid_input_data);
+ sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+ caffe_mul(count, bottom_data, sigmoid_output_->cpu_data(), top_data);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* top_data = top[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ for (int i = 0; i < count; ++i) {
+ const Dtype swish_x = top_data[i];
+ bottom_diff[i] = top_diff[i] * (beta * swish_x + sigmoid_output_data[i]
+ * (1. - beta * swish_x));
+ }
+ }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SwishLayer);
+#endif
+
+INSTANTIATE_CLASS(SwishLayer);
+REGISTER_LAYER_CLASS(Swish);
+
+} // namespace caffe
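
Note: the loop in Backward_cpu implements f'(x) = beta*f(x) + sigmoid(beta*x) * (1 - beta*f(x)) for f(x) = x * sigmoid(beta*x), which follows from sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)); since f(x) is cached in top_data and sigmoid(beta*x) in sigmoid_output_, nothing is recomputed. A standalone finite-difference check of that identity, with made-up values:

#include <cmath>
#include <cstdio>

double sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }
double swish(double x, double beta) { return x * sigmoid(beta * x); }

int main() {
  const double beta = 1.5, x = 0.7, eps = 1e-6;
  // Analytic gradient, as computed in Backward_cpu from cached outputs.
  const double fx = swish(x, beta);
  const double analytic = beta * fx + sigmoid(beta * x) * (1.0 - beta * fx);
  // Central finite difference for comparison.
  const double numeric =
      (swish(x + eps, beta) - swish(x - eps, beta)) / (2.0 * eps);
  std::printf("analytic=%.9f numeric=%.9f\n", analytic, numeric);
  return 0;  // the two values should agree to ~1e-9
}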
diff --git a/src/caffe/layers/swish_layer.cu b/src/caffe/layers/swish_layer.cu
new file mode 100644
index 00000000..c4fef53b
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cu
@@ -0,0 +1,54 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* sigmoid_input_data = sigmoid_input_->mutable_gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ caffe_copy(count, bottom_data, sigmoid_input_data);
+ caffe_gpu_scal(count, beta, sigmoid_input_data);
+ sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+ caffe_gpu_mul(count, bottom_data, sigmoid_output_->gpu_data(), top_data);
+}
+
+template <typename Dtype>
+__global__ void SwishBackward(const int n, const Dtype* in_diff,
+ const Dtype* out_data, const Dtype* sigmoid_output_data, Dtype* out_diff,
+ const Dtype beta) {
+ CUDA_KERNEL_LOOP(index, n) {
+ const Dtype swish_x = out_data[index];
+ out_diff[index] = in_diff[index] * (beta * swish_x
+ + sigmoid_output_data[index] * (1 - beta * swish_x));
+ }
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* top_data = top[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ SwishBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, top_diff, top_data, sigmoid_output_data, bottom_diff, beta);
+ CUDA_POST_KERNEL_CHECK;
+ }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SwishLayer);
+
+} // namespace caffe