path: root/src/caffe/layers/loss_layer.cpp
author    Sergey Karayev <sergeykarayev@gmail.com>    2014-04-28 19:39:36 -0700
committer Sergey Karayev <sergeykarayev@gmail.com>    2014-05-16 10:58:38 -0700
commit    48a8a64c9b1596f60d6eabff4c0df887a3ea53bf (patch)
tree      0fe7e4e72fdd9f81906a55dd8026bd21d9f72346 /src/caffe/layers/loss_layer.cpp
parent    2779ce9e0e121a7bec6de90ca304a316cc150cf2 (diff)
Split all loss layers into own .cpp files
Diffstat (limited to 'src/caffe/layers/loss_layer.cpp')
-rw-r--r--  src/caffe/layers/loss_layer.cpp  230
1 file changed, 1 insertion(+), 229 deletions(-)
diff --git a/src/caffe/layers/loss_layer.cpp b/src/caffe/layers/loss_layer.cpp
index 3fc34a6d..1efb6235 100644
--- a/src/caffe/layers/loss_layer.cpp
+++ b/src/caffe/layers/loss_layer.cpp
@@ -14,8 +14,6 @@ using std::max;
namespace caffe {
-const float kLOG_THRESHOLD = 1e-20;
-
template <typename Dtype>
void LossLayer<Dtype>::SetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
@@ -26,232 +24,6 @@ void LossLayer<Dtype>::SetUp(
FurtherSetUp(bottom, top);
}
-template <typename Dtype>
-void LossLayer<Dtype>::FurtherSetUp(
- const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- // Nothing to do
-}
-
-template <typename Dtype>
-void MultinomialLogisticLossLayer<Dtype>::FurtherSetUp(
- const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- CHECK_EQ(bottom[1]->channels(), 1);
- CHECK_EQ(bottom[1]->height(), 1);
- CHECK_EQ(bottom[1]->width(), 1);
-}
-
-template <typename Dtype>
-Dtype MultinomialLogisticLossLayer<Dtype>::Forward_cpu(
- const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- const Dtype* bottom_data = bottom[0]->cpu_data();
- const Dtype* bottom_label = bottom[1]->cpu_data();
- int num = bottom[0]->num();
- int dim = bottom[0]->count() / bottom[0]->num();
- Dtype loss = 0;
- for (int i = 0; i < num; ++i) {
- int label = static_cast<int>(bottom_label[i]);
- Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
- loss -= log(prob);
- }
- return loss / num;
-}
-
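For reference, the Forward_cpu removed here computes the multinomial logistic loss. Writing N for num, p_{i,j} for the predicted probability bottom_data[i * dim + j], and y_i for the integer label of sample i, it returns

    L = -\frac{1}{N} \sum_{i=1}^{N} \log \max(p_{i,y_i}, \varepsilon), \qquad \varepsilon = \text{kLOG\_THRESHOLD} = 10^{-20}

where the clamp at ε guards against taking the log of zero.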
-template <typename Dtype>
-void MultinomialLogisticLossLayer<Dtype>::Backward_cpu(
- const vector<Blob<Dtype>*>& top, const bool propagate_down,
- vector<Blob<Dtype>*>* bottom) {
- const Dtype* bottom_data = (*bottom)[0]->cpu_data();
- const Dtype* bottom_label = (*bottom)[1]->cpu_data();
- Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
- int num = (*bottom)[0]->num();
- int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
- memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
- for (int i = 0; i < num; ++i) {
- int label = static_cast<int>(bottom_label[i]);
- Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
- bottom_diff[i * dim + label] = -1. / prob / num;
- }
-}
-
-
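The matching Backward_cpu zeroes the diff blob and then writes the analytic gradient only at each sample's label entry:

    \frac{\partial L}{\partial p_{i,j}} = \begin{cases} -\dfrac{1}{N \max(p_{i,y_i}, \varepsilon)} & \text{if } j = y_i, \\ 0 & \text{otherwise.} \end{cases}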
-template <typename Dtype>
-void InfogainLossLayer<Dtype>::FurtherSetUp(
- const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- CHECK_EQ(bottom[1]->channels(), 1);
- CHECK_EQ(bottom[1]->height(), 1);
- CHECK_EQ(bottom[1]->width(), 1);
-
- BlobProto blob_proto;
- ReadProtoFromBinaryFile(
- this->layer_param_.infogain_loss_param().source(), &blob_proto);
- infogain_.FromProto(blob_proto);
- CHECK_EQ(infogain_.num(), 1);
- CHECK_EQ(infogain_.channels(), 1);
- CHECK_EQ(infogain_.height(), infogain_.width());
-}
-
-
-template <typename Dtype>
-Dtype InfogainLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- const Dtype* bottom_data = bottom[0]->cpu_data();
- const Dtype* bottom_label = bottom[1]->cpu_data();
- const Dtype* infogain_mat = infogain_.cpu_data();
- int num = bottom[0]->num();
- int dim = bottom[0]->count() / bottom[0]->num();
- CHECK_EQ(infogain_.height(), dim);
- Dtype loss = 0;
- for (int i = 0; i < num; ++i) {
- int label = static_cast<int>(bottom_label[i]);
- for (int j = 0; j < dim; ++j) {
- Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
- loss -= infogain_mat[label * dim + j] * log(prob);
- }
- }
- return loss / num;
-}
-
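This Forward_cpu generalizes the logistic loss with the infogain matrix H (square, dim × dim, loaded from the layer's source file in FurtherSetUp):

    L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{D} H_{y_i,j} \log \max(p_{i,j}, \varepsilon)

Taking H to be the identity recovers MultinomialLogisticLossLayer.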
-template <typename Dtype>
-void InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
- const bool propagate_down,
- vector<Blob<Dtype>*>* bottom) {
- const Dtype* bottom_data = (*bottom)[0]->cpu_data();
- const Dtype* bottom_label = (*bottom)[1]->cpu_data();
- const Dtype* infogain_mat = infogain_.cpu_data();
- Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
- int num = (*bottom)[0]->num();
- int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
- CHECK_EQ(infogain_.height(), dim);
- for (int i = 0; i < num; ++i) {
- int label = static_cast<int>(bottom_label[i]);
- for (int j = 0; j < dim; ++j) {
- Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
- bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
- }
- }
-}
-
-
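Unlike the logistic case, the infogain gradient is dense over all D classes:

    \frac{\partial L}{\partial p_{i,j}} = -\frac{H_{y_i,j}}{N \max(p_{i,j}, \varepsilon)}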
-template <typename Dtype>
-void EuclideanLossLayer<Dtype>::FurtherSetUp(
- const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
- CHECK_EQ(bottom[0]->height(), bottom[1]->height());
- CHECK_EQ(bottom[0]->width(), bottom[1]->width());
- difference_.Reshape(bottom[0]->num(), bottom[0]->channels(),
- bottom[0]->height(), bottom[0]->width());
-}
-
-template <typename Dtype>
-Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- int count = bottom[0]->count();
- int num = bottom[0]->num();
- caffe_sub(count, bottom[0]->cpu_data(), bottom[1]->cpu_data(),
- difference_.mutable_cpu_data());
- Dtype loss = caffe_cpu_dot(
- count, difference_.cpu_data(), difference_.cpu_data()) / num / Dtype(2);
- return loss;
-}
-
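Forward_cpu caches the residual in difference_ (it is reused by the backward pass) and returns the mean squared error with the conventional factor of 1/2:

    L = \frac{1}{2N} \sum_{i=1}^{N} \lVert x^{(1)}_i - x^{(2)}_i \rVert_2^2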
-template <typename Dtype>
-void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
- const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
- int count = (*bottom)[0]->count();
- int num = (*bottom)[0]->num();
- // Compute the gradient
- caffe_cpu_axpby(count, Dtype(1) / num, difference_.cpu_data(), Dtype(0),
- (*bottom)[0]->mutable_cpu_diff());
-}
-
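Backward_cpu just rescales the cached residual; note the gradient is written only into bottom[0], so the second input never receives a diff:

    \frac{\partial L}{\partial x^{(1)}} = \frac{1}{N} \left( x^{(1)} - x^{(2)} \right)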
-template <typename Dtype>
-Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- const Dtype* bottom_data = bottom[0]->cpu_data();
- Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
- const Dtype* label = bottom[1]->cpu_data();
- int num = bottom[0]->num();
- int count = bottom[0]->count();
- int dim = count / num;
-
- caffe_copy(count, bottom_data, bottom_diff);
- for (int i = 0; i < num; ++i) {
- bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
- }
- for (int i = 0; i < num; ++i) {
- for (int j = 0; j < dim; ++j) {
- bottom_diff[i * dim + j] = max(Dtype(0), 1 + bottom_diff[i * dim + j]);
- }
- }
- return caffe_cpu_asum(count, bottom_diff) / num;
-}
-
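This is the L1 hinge loss. Writing s_{ij} = -1 when j = y_i and +1 otherwise, Forward_cpu leaves the per-element margins max(0, 1 + s_{ij} x_{ij}) in bottom_diff and returns their mean over samples:

    L = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{D} \max(0,\, 1 + s_{ij} x_{ij})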
-template <typename Dtype>
-void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
- const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
- Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
- const Dtype* label = (*bottom)[1]->cpu_data();
- int num = (*bottom)[0]->num();
- int count = (*bottom)[0]->count();
- int dim = count / num;
-
- caffe_cpu_sign(count, bottom_diff, bottom_diff);
- for (int i = 0; i < num; ++i) {
- bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
- }
- caffe_scal(count, Dtype(1. / num), bottom_diff);
-}
-
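Backward_cpu depends on the margins that Forward_cpu left in bottom_diff: caffe_cpu_sign reduces each nonnegative margin to an indicator, the label column's sign is flipped back, and everything is scaled by 1/N, giving the subgradient

    \frac{\partial L}{\partial x_{ij}} = \frac{s_{ij}}{N}\, \mathbf{1}\!\left[ 1 + s_{ij} x_{ij} > 0 \right]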
-template <typename Dtype>
-void AccuracyLayer<Dtype>::SetUp(
- const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- CHECK_EQ(bottom.size(), 2) << "Accuracy Layer takes two blobs as input.";
- CHECK_EQ(top->size(), 1) << "Accuracy Layer takes 1 output.";
- CHECK_EQ(bottom[0]->num(), bottom[1]->num())
- << "The data and label should have the same number.";
- CHECK_EQ(bottom[1]->channels(), 1);
- CHECK_EQ(bottom[1]->height(), 1);
- CHECK_EQ(bottom[1]->width(), 1);
- (*top)[0]->Reshape(1, 2, 1, 1);
-}
-
-template <typename Dtype>
-Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- Dtype accuracy = 0;
- Dtype logprob = 0;
- const Dtype* bottom_data = bottom[0]->cpu_data();
- const Dtype* bottom_label = bottom[1]->cpu_data();
- int num = bottom[0]->num();
- int dim = bottom[0]->count() / bottom[0]->num();
- for (int i = 0; i < num; ++i) {
- // Accuracy
- Dtype maxval = -FLT_MAX;
- int max_id = 0;
- for (int j = 0; j < dim; ++j) {
- if (bottom_data[i * dim + j] > maxval) {
- maxval = bottom_data[i * dim + j];
- max_id = j;
- }
- }
- if (max_id == static_cast<int>(bottom_label[i])) {
- ++accuracy;
- }
- Dtype prob = max(bottom_data[i * dim + static_cast<int>(bottom_label[i])],
- Dtype(kLOG_THRESHOLD));
- logprob -= log(prob);
- }
- // LOG(INFO) << "Accuracy: " << accuracy;
- (*top)[0]->mutable_cpu_data()[0] = accuracy / num;
- (*top)[0]->mutable_cpu_data()[1] = logprob / num;
- // Accuracy layer should not be used as a loss function.
- return Dtype(0);
-}
-
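AccuracyLayer's 1×2 output holds the top-1 accuracy and the mean negative log-probability of the true class:

    \text{top}_0 = \frac{1}{N} \sum_{i=1}^{N} \mathbf{1}\!\left[ \arg\max_j p_{i,j} = y_i \right], \qquad \text{top}_1 = -\frac{1}{N} \sum_{i=1}^{N} \log \max(p_{i,y_i}, \varepsilon)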
-INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
-INSTANTIATE_CLASS(InfogainLossLayer);
-INSTANTIATE_CLASS(EuclideanLossLayer);
-INSTANTIATE_CLASS(HingeLossLayer);
-INSTANTIATE_CLASS(AccuracyLayer);
+INSTANTIATE_CLASS(LossLayer);
} // namespace caffe