1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
// Copyright 2013 Yangqing Jia
#include <algorithm>
#include <cmath>
#include <vector>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"

using std::max;
namespace caffe {
// Upper clamp applied to the input before exp() in the backward passes
// (CPU loop and GPU kernel), presumably to keep exp() from overflowing when
// forming exp(x) / (exp(x) + 1).
const float kBNLL_THRESHOLD = 50.0f;
// CPU forward pass: writes log(1 + exp(x)) element-wise, reading x from
// bottom[0]'s data and storing the result in (*top)[0]'s data.
//
// For positive inputs the identity log(1 + exp(x)) = x + log(1 + exp(-x)) is
// used so that exp() is only evaluated on a non-positive argument and cannot
// overflow. log1p() replaces log(1. + ...) because adding a tiny exp() result
// to 1 rounds it away, which made the output exactly 0 for strongly negative
// inputs even though the true value is representable.
template <typename Dtype>
void BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Dtype* top_data = (*top)[0]->mutable_cpu_data();
  const int count = bottom[0]->count();
  for (int i = 0; i < count; ++i) {
    // Read once before writing so the expression only sees the input value.
    const Dtype x = bottom_data[i];
    top_data[i] = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x));
  }
}
// CPU backward pass. When propagate_down is true, every element of
// (*bottom)[0]'s diff is set to
//   top_diff * e / (e + 1),  with e = exp(min(bottom_data, kBNLL_THRESHOLD)).
// When propagate_down is false nothing is written. Always returns Dtype(0).
template <typename Dtype>
Dtype BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  if (!propagate_down) {
    return Dtype(0);
  }
  Blob<Dtype>* const input = (*bottom)[0];
  const Dtype* x = input->cpu_data();
  const Dtype* dy = top[0]->cpu_diff();
  Dtype* dx = input->mutable_cpu_diff();
  const int n = input->count();
  for (int k = 0; k < n; ++k) {
    // Clamp the input before exponentiating; see kBNLL_THRESHOLD.
    const Dtype e = exp(min(x[k], Dtype(kBNLL_THRESHOLD)));
    dx[k] = dy[k] * e / (e + 1.);
  }
  return Dtype(0);
}
// Device kernel: out[i] = log(1 + exp(in[i])) for i in [0, n).
//
// Launch layout: indexing uses only the x dimension of the grid and block.
// The loop below strides by the total number of launched threads, so the
// kernel covers all n elements for any x-dimension launch size, not only
// when the launch supplies at least n threads.
//
// For positive inputs the equivalent form x + log(1 + exp(-x)) is used so
// exp() never sees a positive argument. log1p() keeps tiny exp() results from
// being rounded away when added to 1 and avoids the double literal that
// forced double-precision math in the float instantiation.
template <typename Dtype>
__global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) {
  const int stride = blockDim.x * gridDim.x;
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < n;
       index += stride) {
    const Dtype x = in[index];
    out[index] = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x));
  }
}
// GPU forward pass: launches BNLLForward over all elements of bottom[0],
// reading bottom[0]'s GPU data and writing (*top)[0]'s GPU data, using
// CAFFE_GET_BLOCKS(count) blocks of CAFFE_CUDA_NUM_THREADS threads, then runs
// CUDA_POST_KERNEL_CHECK.
template <typename Dtype>
void BNLLLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  Blob<Dtype>* const src = bottom[0];
  const Dtype* in = src->gpu_data();
  Dtype* out = (*top)[0]->mutable_gpu_data();
  const int n = src->count();
  // NOLINT_NEXT_LINE(whitespace/operators)
  BNLLForward<Dtype><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
      n, in, out);
  CUDA_POST_KERNEL_CHECK;
}
// Device kernel: out_diff[i] = in_diff[i] * e / (e + 1) for i in [0, n),
// where e = exp(min(in_data[i], kBNLL_THRESHOLD)).
//
// Launch layout: indexing uses only the x dimension of the grid and block.
// The loop strides by the total number of launched threads, so all n elements
// are processed for any x-dimension launch size.
//
// The constant 1 is written as Dtype(1) so the float instantiation is not
// silently promoted to double-precision arithmetic by a double literal.
template <typename Dtype>
__global__ void BNLLBackward(const int n, const Dtype* in_diff,
    const Dtype* in_data, Dtype* out_diff) {
  const int stride = blockDim.x * gridDim.x;
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < n;
       index += stride) {
    // Clamp before exponentiating; see kBNLL_THRESHOLD.
    const Dtype expval = exp(min(in_data[index], Dtype(kBNLL_THRESHOLD)));
    out_diff[index] = in_diff[index] * expval / (expval + Dtype(1));
  }
}
// GPU backward pass. When propagate_down is true, launches BNLLBackward over
// all elements of (*bottom)[0], feeding it top[0]'s GPU diff and
// (*bottom)[0]'s GPU data and writing (*bottom)[0]'s GPU diff, then runs
// CUDA_POST_KERNEL_CHECK. When propagate_down is false no kernel is launched.
// Always returns Dtype(0).
template <typename Dtype>
Dtype BNLLLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  if (!propagate_down) {
    return Dtype(0);
  }
  Blob<Dtype>* const input = (*bottom)[0];
  const Dtype* x = input->gpu_data();
  const Dtype* dy = top[0]->gpu_diff();
  Dtype* dx = input->mutable_gpu_diff();
  const int n = input->count();
  // NOLINT_NEXT_LINE(whitespace/operators)
  BNLLBackward<Dtype><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(
      n, dy, x, dx);
  CUDA_POST_KERNEL_CHECK;
  return Dtype(0);
}
// Instantiates the BNLLLayer template through the INSTANTIATE_CLASS macro.
// NOTE(review): the macro is defined outside this file; presumably it emits
// explicit instantiations for the supported Dtype types - confirm there.
INSTANTIATE_CLASS(BNLLLayer);
} // namespace caffe
|