From cf301543fb2a62ba26dbb9307356e78634519123 Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Sun, 27 Oct 2013 10:11:23 -0700 Subject: bugfix and made the C++ interface for creating leveldb --- src/caffe/layers/bnll_layer.cu | 8 ++++++-- src/caffe/util/io.cpp | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/caffe/layers/bnll_layer.cu b/src/caffe/layers/bnll_layer.cu index fd261a35..2c06a63d 100644 --- a/src/caffe/layers/bnll_layer.cu +++ b/src/caffe/layers/bnll_layer.cu @@ -17,7 +17,9 @@ void BNLLLayer::Forward_cpu(const vector*>& bottom, Dtype* top_data = (*top)[0]->mutable_cpu_data(); const int count = bottom[0]->count(); for (int i = 0; i < count; ++i) { - top_data[i] = log(1. + exp(min(bottom_data[i], Dtype(kBNLL_THRESHOLD)))); + top_data[i] = bottom_data[i] > 0 ? + bottom_data[i] + log(1. + exp(-bottom_data[i])) : + log(1. + exp(bottom_data[i])); } } @@ -43,7 +45,9 @@ template __global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) { int index = threadIdx.x + blockIdx.x * blockDim.x; if (index < n) { - out[index] = log(1. + exp(min(in[index], Dtype(kBNLL_THRESHOLD)))); + out[index] = in[index] > 0 ? + in[index] + log(1. + exp(-in[index])) : + log(1. + exp(in[index])); } } diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index 5e5510f5..a3c520f0 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -67,7 +67,7 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) { void ReadImageToDatum(const string& filename, const int label, Datum* datum) { - Mat cv_img; + cv::Mat cv_img; cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR); CHECK(cv_img.data) << "Could not open or find the image."; datum->set_channels(3); @@ -80,7 +80,7 @@ void ReadImageToDatum(const string& filename, const int label, Datum* datum) { for (int c = 0; c < 3; ++c) { for (int h = 0; h < cv_img.rows; ++h) { for (int w = 0; w < cv_img.cols; ++w) { - datum_string->push_back(static_cast(cv_img.at(h, w)[c])); + datum_string->push_back(static_cast(cv_img.at(h, w)[c])); } } } -- cgit v1.2.3