summaryrefslogtreecommitdiff
path: root/src/caffe/layers/euclidean_loss_layer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/caffe/layers/euclidean_loss_layer.cpp')
-rw-r--r--src/caffe/layers/euclidean_loss_layer.cpp13
1 files changed, 6 insertions, 7 deletions
diff --git a/src/caffe/layers/euclidean_loss_layer.cpp b/src/caffe/layers/euclidean_loss_layer.cpp
index 17180d40..be83601f 100644
--- a/src/caffe/layers/euclidean_loss_layer.cpp
+++ b/src/caffe/layers/euclidean_loss_layer.cpp
@@ -8,8 +8,9 @@
namespace caffe {
template <typename Dtype>
-void EuclideanLossLayer<Dtype>::FurtherSetUp(
+void EuclideanLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+ LossLayer<Dtype>::LayerSetUp(bottom, top);
CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
CHECK_EQ(bottom[0]->height(), bottom[1]->height());
CHECK_EQ(bottom[0]->width(), bottom[1]->width());
@@ -18,7 +19,7 @@ void EuclideanLossLayer<Dtype>::FurtherSetUp(
}
template <typename Dtype>
-Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
int count = bottom[0]->count();
caffe_sub(
@@ -28,10 +29,7 @@ Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
diff_.mutable_cpu_data());
Dtype dot = caffe_cpu_dot(count, diff_.cpu_data(), diff_.cpu_data());
Dtype loss = dot / bottom[0]->num() / Dtype(2);
- if (top->size() == 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss;
- }
- return loss;
+ (*top)[0]->mutable_cpu_data()[0] = loss;
}
template <typename Dtype>
@@ -40,9 +38,10 @@ void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
+ const Dtype alpha = sign * top[0]->cpu_diff()[0] / (*bottom)[i]->num();
caffe_cpu_axpby(
(*bottom)[i]->count(), // count
- sign / (*bottom)[i]->num(), // alpha
+ alpha, // alpha
diff_.cpu_data(), // a
Dtype(0), // beta
(*bottom)[i]->mutable_cpu_diff()); // b