author     Yangqing Jia <jiayq84@gmail.com>    2013-11-11 17:08:27 -0800
committer  Yangqing Jia <jiayq84@gmail.com>    2013-11-11 17:08:27 -0800
commit     652d744360fb317a41e0456c0d1226262c5a70d4 (patch)
tree       d74f78a0ec4917d51a8674804f0535d625f57901 /src
parent     76bf486e3de96fe6f5619693742e5c1bb19074b2 (diff)
parent     c8e7cce7316c29b32d0b2dcdc98bff75afa0ab40 (diff)
Merge branch 'master' of github.com:Yangqing/caffe
Conflicts: Makefile
Diffstat (limited to 'src')
-rw-r--r--  src/caffe/layer_factory.cpp        |  2
-rw-r--r--  src/caffe/layers/loss_layer.cu     | 47
-rw-r--r--  src/caffe/layers/pooling_layer.cu  |  3
-rw-r--r--  src/caffe/net.cpp                  |  2
-rw-r--r--  src/caffe/util/io.cpp              |  1
5 files changed, 52 insertions(+), 3 deletions(-)
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 178607f4..b663cb2c 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -33,6 +33,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new EuclideanLossLayer<Dtype>(param);
   } else if (type == "im2col") {
     return new Im2colLayer<Dtype>(param);
+  } else if (type == "infogain_loss") {
+    return new InfogainLossLayer<Dtype>(param);
   } else if (type == "innerproduct") {
     return new InnerProductLayer<Dtype>(param);
   } else if (type == "lrn") {
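With this registration, a LayerParameter whose type field is "infogain_loss" is routed to the new InfogainLossLayer by GetLayer. A minimal instantiation sketch (not part of the commit), assuming the standard protobuf setters for the type, name, and source fields; the file path below is purely illustrative:

LayerParameter param;
param.set_type("infogain_loss");   // dispatched to InfogainLossLayer by GetLayer
param.set_name("loss");
param.set_source("/path/to/infogain_matrix.binaryproto");  // hypothetical infogain matrix file
Layer<float>* layer = GetLayer<float>(param);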
diff --git a/src/caffe/layers/loss_layer.cu b/src/caffe/layers/loss_layer.cu
index 0a6f5ee9..ac05ba41 100644
--- a/src/caffe/layers/loss_layer.cu
+++ b/src/caffe/layers/loss_layer.cu
@@ -6,6 +6,7 @@
 #include "caffe/layer.hpp"
 #include "caffe/vision_layers.hpp"
 #include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
 
 using std::max;
@@ -17,7 +18,7 @@ template <typename Dtype>
 void MultinomialLogisticLossLayer<Dtype>::SetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
-  CHECK_EQ(top->size(), 0) << "Loss Layer takes no as output.";
+  CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
   CHECK_EQ(bottom[0]->num(), bottom[1]->num())
       << "The data and label should have the same number.";
   CHECK_EQ(bottom[1]->channels(), 1);
@@ -50,6 +51,49 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>
 template <typename Dtype>
+void InfogainLossLayer<Dtype>::SetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
+  CHECK_EQ(top->size(), 0) << "Loss Layer takes no output.";
+  CHECK_EQ(bottom[0]->num(), bottom[1]->num())
+      << "The data and label should have the same number.";
+  CHECK_EQ(bottom[1]->channels(), 1);
+  CHECK_EQ(bottom[1]->height(), 1);
+  CHECK_EQ(bottom[1]->width(), 1);
+  BlobProto blob_proto;
+  ReadProtoFromBinaryFile(this->layer_param_.source(), &blob_proto);
+  infogain_.FromProto(blob_proto);
+  CHECK_EQ(infogain_.num(), 1);
+  CHECK_EQ(infogain_.channels(), 1);
+  CHECK_EQ(infogain_.height(), infogain_.width());
+};
+
+
+template <typename Dtype>
+Dtype InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
+  const Dtype* infogain_mat = infogain_.cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  int num = (*bottom)[0]->num();
+  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
+  CHECK_EQ(infogain_.height(), dim);
+  Dtype loss = 0;
+  for (int i = 0; i < num; ++i) {
+    int label = static_cast<int>(bottom_label[i]);
+    for (int j = 0; j < dim; ++j) {
+      Dtype prob = max(bottom_data[i * dim + j], kLOG_THRESHOLD);
+      loss -= infogain_mat[label * dim + j] * log(prob);
+      bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
+    }
+  }
+  return loss / num;
+}
+
+
+template <typename Dtype>
 void EuclideanLossLayer<Dtype>::SetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
@@ -122,6 +166,7 @@ void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
+INSTANTIATE_CLASS(InfogainLossLayer);
 INSTANTIATE_CLASS(EuclideanLossLayer);
 INSTANTIATE_CLASS(AccuracyLayer);
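For reference (not part of the commit), the loss and the stored gradient computed by InfogainLossLayer::Backward_cpu above can be written as follows, with H the infogain matrix loaded in SetUp, p_{ij} the predicted probability of class j for sample i, l_i the label of sample i, N the batch size, D the per-sample dimension, and \varepsilon = kLOG_THRESHOLD:

L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{D} H_{l_i j} \log\bigl(\max(p_{ij}, \varepsilon)\bigr),
\qquad
\frac{\partial L}{\partial p_{ij}} = -\frac{H_{l_i j}}{N \, \max(p_{ij}, \varepsilon)}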
diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu
index 1cbb4abe..4fd326cb 100644
--- a/src/caffe/layers/pooling_layer.cu
+++ b/src/caffe/layers/pooling_layer.cu
@@ -120,7 +120,8 @@ __global__ void StoPoolForwardTest(const int nthreads,
     int hend = min(hstart + ksize, height);
     int wstart = pw * stride;
     int wend = min(wstart + ksize, width);
-    Dtype cumsum = 0.;
+    // We seed cumsum slightly above zero (FLT_MIN) to avoid divide-by-zero problems
+    Dtype cumsum = FLT_MIN;
     Dtype cumvalues = 0.;
     bottom_data += (n * channels + c) * height * width;
     // First pass: get sum
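A small standalone illustration (not from the commit) of why the guard matters, assuming the kernel finishes by dividing cumvalues by cumsum as the surrounding code suggests: with an all-zero pooling window the division would be 0 / 0 and produce NaN, while seeding the accumulator with FLT_MIN yields 0.

#include <cfloat>
#include <cstdio>

int main() {
  float window[4] = {0.f, 0.f, 0.f, 0.f};  // an all-zero pooling window
  float cumsum = FLT_MIN;                  // seeded as in the patch; 0.f would give 0/0
  float cumvalues = 0.f;
  for (int i = 0; i < 4; ++i) {
    cumsum += window[i];
    cumvalues += window[i] * window[i];    // probability-weighted value, as in the kernel
  }
  printf("%f\n", cumvalues / cumsum);      // prints 0.000000 rather than nan
  return 0;
}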
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index dec42036..3266064d 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -287,7 +287,7 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
       DLOG(INFO) << "Ignoring source layer " << source_layer_name;
       continue;
     }
-    DLOG(INFO) << "Loading source layer " << source_layer_name;
+    LOG(INFO) << "Copying source layer " << source_layer_name;
     vector<shared_ptr<Blob<Dtype> > >& target_blobs =
         layers_[target_layer_id]->blobs();
     CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
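The promoted LOG(INFO) makes it visible which layers are actually copied by name when loading a pretrained model. A hypothetical finetuning helper (illustrative only; the helper name and snapshot path are not from this commit):

#include "caffe/net.hpp"
#include "caffe/util/io.hpp"

using namespace caffe;

// Reads a serialized NetParameter snapshot and copies its trained blobs into
// the layers of 'net' whose names match; the LOG(INFO) above reports each copy.
void LoadPretrained(Net<float>* net, const char* snapshot_path) {
  NetParameter pretrained;
  ReadProtoFromBinaryFile(snapshot_path, &pretrained);
  net->CopyTrainedLayersFrom(pretrained);
}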
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index c4682f52..ced807ae 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -7,6 +7,7 @@
 #include <google/protobuf/io/coded_stream.h>
 #include <opencv2/core/core.hpp>
 #include <opencv2/highgui/highgui.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
 #include <algorithm>
 #include <string>
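The new include is presumably needed for image-processing routines such as cv::resize, which live in imgproc rather than highgui. A standalone sketch (file name and target size are illustrative) of the kind of call that compiles only once imgproc is pulled in:

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>  // cv::resize is declared here, not in highgui

int main() {
  cv::Mat img = cv::imread("example.jpg");      // highgui; file name is illustrative
  if (img.empty()) return 1;                    // bail out if the read failed
  cv::Mat resized;
  cv::resize(img, resized, cv::Size(256, 256)); // requires imgproc/imgproc.hpp
  return 0;
}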