author    Przemysław Dolata <snowball91b@gmail.com>  2018-08-20 07:55:36 +0200
committer Wook Song <wook16.song@samsung.com>  2020-01-23 22:50:49 +0900
commit    985206a68a0a859f98790102e07222cc1ceb116f (patch)
tree      3f2b810f46e329b02e23efb770f2a75fab728f8a /src/caffe
parent    fe062ef6983f8ebefd8f2d6c140fd5c548dca59c (diff)
parent    b6ad8b657fb689ebc061d800feacfaf3ab1185c5 (diff)
Merge pull request #6320 from Noiredd/clip
Clip layer - resurrection
Diffstat (limited to 'src/caffe')
-rw-r--r--  src/caffe/layer_factory.cpp           1
-rw-r--r--  src/caffe/layers/clip_layer.cpp      51
-rw-r--r--  src/caffe/layers/clip_layer.cu       67
-rw-r--r--  src/caffe/proto/caffe.proto           9
-rw-r--r--  src/caffe/test/test_neuron_layer.cpp 61
5 files changed, 188 insertions(+), 1 deletion(-)
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 9f9026b1..d9984431 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -7,6 +7,7 @@
#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
+#include "caffe/layers/clip_layer.hpp"
#include "caffe/layers/conv_layer.hpp"
#include "caffe/layers/deconv_layer.hpp"
#include "caffe/layers/lrn_layer.hpp"
diff --git a/src/caffe/layers/clip_layer.cpp b/src/caffe/layers/clip_layer.cpp
new file mode 100644
index 00000000..9d9a5967
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cpp
@@ -0,0 +1,51 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layers/clip_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+
+ Dtype min = this->layer_param_.clip_param().min();
+ Dtype max = this->layer_param_.clip_param().max();
+
+ for (int i = 0; i < count; ++i) {
+ top_data[i] = std::max(min, std::min(bottom_data[i], max));
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ const int count = bottom[0]->count();
+
+ Dtype min = this->layer_param_.clip_param().min();
+ Dtype max = this->layer_param_.clip_param().max();
+
+ for (int i = 0; i < count; ++i) {
+ bottom_diff[i] = top_diff[i] * (
+ bottom_data[i] >= min && bottom_data[i] <= max);
+ }
+ }
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(ClipLayer);
+#endif
+
+INSTANTIATE_CLASS(ClipLayer);
+REGISTER_LAYER_CLASS(Clip);
+
+} // namespace caffe
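
An aside for readers skimming the patch: the layer is an elementwise clamp, y = max(min, min(x, max)), and its gradient passes through exactly where min <= x <= max (the boundary points take subgradient 1, matching the >=/<= tests in Backward_cpu above). A minimal standalone sketch of these semantics, independent of Caffe; the names and sample values are illustrative only:

#include <algorithm>
#include <cstdio>

int main() {
  const float min_v = -1.0f, max_v = 2.0f;      // stand-ins for clip_param min/max
  const float x[] = {-3.0f, 0.5f, 2.0f, 7.0f};  // sample inputs
  for (float xi : x) {
    // Forward: y = max(min_v, min(x, max_v)), as in Forward_cpu.
    float y = std::max(min_v, std::min(xi, max_v));
    // Backward: the gradient is passed through only inside [min_v, max_v];
    // boundary points use subgradient 1, matching the >=/<= comparisons.
    float dydx = (xi >= min_v && xi <= max_v) ? 1.0f : 0.0f;
    std::printf("x=%5.1f  y=%5.1f  dy/dx=%g\n", xi, y, dydx);
  }
  return 0;
}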
diff --git a/src/caffe/layers/clip_layer.cu b/src/caffe/layers/clip_layer.cu
new file mode 100644
index 00000000..56f3be32
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cu
@@ -0,0 +1,67 @@
+#include <vector>
+
+#include "caffe/layers/clip_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+__global__ void ClipForward(const int n, const float* in, float* out,
+ float p_min, float p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out[index] = fmaxf(p_min, fminf(in[index], p_max));
+ }
+}
+
+__global__ void ClipForward(const int n, const double* in, double* out,
+ double p_min, double p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out[index] = fmax(p_min, fmin(in[index], p_max));
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ Dtype p_min = this->layer_param_.clip_param().min();
+ Dtype p_max = this->layer_param_.clip_param().max();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ ClipForward<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, top_data, p_min, p_max);
+ CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void ClipBackward(const int n, const Dtype* in_diff,
+ const Dtype* in_data, Dtype* out_diff, Dtype p_min, Dtype p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out_diff[index] = in_diff[index] * (
+ in_data[index] >= p_min && in_data[index] <= p_max);
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const int count = bottom[0]->count();
+ Dtype p_min = this->layer_param_.clip_param().min();
+ Dtype p_max = this->layer_param_.clip_param().max();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ ClipBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, top_diff, bottom_data, bottom_diff, p_min, p_max);
+ CUDA_POST_KERNEL_CHECK;
+ }
+}
+
+
+INSTANTIATE_LAYER_GPU_FUNCS(ClipLayer);
+
+
+} // namespace caffe
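
Two design notes on the file above: ClipForward is overloaded separately for float and double (rather than templated) so that each precision calls the matching device math functions (fmaxf/fminf vs. fmax/fmin), while ClipBackward can stay templated because it only uses comparisons and a multiply. CUDA_KERNEL_LOOP is Caffe's grid-stride loop macro; assuming its usual definition in caffe/util/device_alternate.hpp, the float forward kernel is roughly equivalent to the following sketch (ClipForwardExpanded is a hypothetical name for illustration):

__global__ void ClipForwardExpanded(const int n, const float* in, float* out,
    float p_min, float p_max) {
  // Grid-stride loop: each thread covers index, index + stride, ... so any
  // blob size is handled even if the grid is smaller than n.
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < n;
       index += blockDim.x * gridDim.x) {
    out[index] = fmaxf(p_min, fminf(in[index], p_max));
  }
}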
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index f784aa96..5c235c6f 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -322,7 +322,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 148 (last added: swish_param)
+// LayerParameter next available layer-specific ID: 149 (last added: clip_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -378,6 +378,7 @@ message LayerParameter {
optional ArgMaxParameter argmax_param = 103;
optional BatchNormParameter batch_norm_param = 139;
optional BiasParameter bias_param = 141;
+ optional ClipParameter clip_param = 148;
optional ConcatParameter concat_param = 104;
optional ContrastiveLossParameter contrastive_loss_param = 105;
optional ConvolutionParameter convolution_param = 106;
@@ -505,6 +506,12 @@ message ArgMaxParameter {
optional int32 axis = 3;
}
+// Message that stores parameters used by ClipLayer
+message ClipParameter {
+ required float min = 1;
+ required float max = 2;
+}
+
message ConcatParameter {
// The axis along which to concatenate -- may be negative to index from the
// end (e.g., -1 for the last axis). Other axes must have the
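
With ClipParameter in place, the new layer can be declared in a network prototxt. A minimal sketch (the layer and blob names here are hypothetical; the type string "Clip" comes from REGISTER_LAYER_CLASS(Clip), and both min and max are required):

layer {
  name: "clip1"                   # hypothetical layer name
  type: "Clip"
  bottom: "data"                  # hypothetical input blob
  top: "data_clipped"
  clip_param { min: -1 max: 2 }   # clamp outputs to [-1, 2]
}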
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 83d80fcd..d1ecc37b 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -10,6 +10,7 @@
#include "caffe/layers/absval_layer.hpp"
#include "caffe/layers/bnll_layer.hpp"
+#include "caffe/layers/clip_layer.hpp"
#include "caffe/layers/dropout_layer.hpp"
#include "caffe/layers/elu_layer.hpp"
#include "caffe/layers/exp_layer.hpp"
@@ -206,6 +207,66 @@ TYPED_TEST(NeuronLayerTest, TestAbsGradient) {
this->blob_top_vec_);
}
+TYPED_TEST(NeuronLayerTest, TestClip) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "clip_param { min: -1, max: 2 }", &layer_param));
+ ClipLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_GE(top_data[i], -1);
+ EXPECT_LE(top_data[i], 2);
+ EXPECT_TRUE(bottom_data[i] > -1 || top_data[i] == -1);
+ EXPECT_TRUE(bottom_data[i] < 2 || top_data[i] == 2);
+ EXPECT_TRUE(!(bottom_data[i] >= -1 && bottom_data[i] <= 2)
+ || top_data[i] == bottom_data[i]);
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestClipGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "clip_param { min: -1, max: 2 }", &layer_param));
+ ClipLayer<Dtype> layer(layer_param);
+ // Unfortunately, it might happen that an input value lands exactly within
+ // the discontinuity region of the Clip function. In this case the numeric
+ // gradient is likely to differ significantly (i.e. by a value larger than
+ // checker tolerance) from the computed gradient. To handle such cases, we
+ // eliminate such values from the input blob before the gradient check.
+ const Dtype epsilon = 1e-2;
+ const Dtype min_range_start = layer_param.clip_param().min() - epsilon;
+ const Dtype min_range_end = layer_param.clip_param().min() + epsilon;
+ const Dtype max_range_start = layer_param.clip_param().max() - epsilon;
+ const Dtype max_range_end = layer_param.clip_param().max() + epsilon;
+ // The input blob is owned by the NeuronLayerTest object, so we begin with
+ // creating a temporary blob and copying the input data there.
+ Blob<Dtype> temp_bottom;
+ temp_bottom.ReshapeLike(*this->blob_bottom_);
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ Dtype* temp_data_mutable = temp_bottom.mutable_cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ if (bottom_data[i] >= min_range_start &&
+ bottom_data[i] <= min_range_end) {
+ temp_data_mutable[i] = bottom_data[i] - epsilon;
+ } else if (bottom_data[i] >= max_range_start &&
+ bottom_data[i] <= max_range_end) {
+ temp_data_mutable[i] = bottom_data[i] + epsilon;
+ } else {
+ temp_data_mutable[i] = bottom_data[i];
+ }
+ }
+ vector<Blob<Dtype>*> temp_bottom_vec;
+ temp_bottom_vec.push_back(&temp_bottom);
+ GradientChecker<Dtype> checker(epsilon, 1e-3);
+ checker.CheckGradientEltwise(&layer, temp_bottom_vec, this->blob_top_vec_);
+}
+
TYPED_TEST(NeuronLayerTest, TestReLU) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;