diff options
author | Yangqing Jia <jiayq84@gmail.com> | 2013-10-25 13:55:23 -0700 |
---|---|---|
committer | Yangqing Jia <jiayq84@gmail.com> | 2013-10-25 13:55:23 -0700 |
commit | 98b8515b56eb6aac97db61df06da1cf196e39353 (patch) | |
tree | 0c8cc52fb2cd4a64c889620be9435bb949946a2b /src | |
parent | 74c417068057a742e961b3bb1cc449e3ac7b1c12 (diff) | |
download | caffe-98b8515b56eb6aac97db61df06da1cf196e39353.tar.gz caffe-98b8515b56eb6aac97db61df06da1cf196e39353.tar.bz2 caffe-98b8515b56eb6aac97db61df06da1cf196e39353.zip |
bugfix
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/layers/bnll_layer.cu | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/caffe/layers/bnll_layer.cu b/src/caffe/layers/bnll_layer.cu index c9a33ed5..fd261a35 100644 --- a/src/caffe/layers/bnll_layer.cu +++ b/src/caffe/layers/bnll_layer.cu @@ -30,9 +30,10 @@ Dtype BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); const int count = (*bottom)[0]->count(); + Dtype expval; for (int i = 0; i < count; ++i) { - Dtype expval = exp(min(bottom_data[index], Dtype(kBNLL_THRESHOLD))); - bottom_diff[index] = top_diff[index] * expval / (expval + 1.); + expval = exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD))); + bottom_diff[i] = top_diff[i] * expval / (expval + 1.); } } return Dtype(0); @@ -42,7 +43,7 @@ template <typename Dtype> __global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) { int index = threadIdx.x + blockIdx.x * blockDim.x; if (index < n) { - out[index] = log(1. + exp(min(in[index], Dtype(kBNLL_THRESHOLD))); + out[index] = log(1. + exp(min(in[index], Dtype(kBNLL_THRESHOLD)))); } } |