summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJeff Donahue <jeff.donahue@gmail.com>2015-08-26 11:42:38 -0700
committerJeff Donahue <jeff.donahue@gmail.com>2015-08-26 11:42:38 -0700
commit215bea0624a64e5de8f794b60530ba4a00496f31 (patch)
treed7362e1170bfba3cc76e5b855b9028ab9ef54775 /src
parent93b48356dce1185a9f746dff0563009f217c0576 (diff)
parent1f3f9529df6285a5be5f8e72bd1922a6a0cec4d8 (diff)
downloadcaffeonacl-215bea0624a64e5de8f794b60530ba4a00496f31.tar.gz
caffeonacl-215bea0624a64e5de8f794b60530ba4a00496f31.tar.bz2
caffeonacl-215bea0624a64e5de8f794b60530ba4a00496f31.zip
Merge pull request #2964 from jyegerlehner/mvn-layer-fixes
Fix MVNLayer
Diffstat (limited to 'src')
-rw-r--r--src/caffe/layers/mvn_layer.cpp15
-rw-r--r--src/caffe/layers/mvn_layer.cu7
-rw-r--r--src/caffe/test/test_mvn_layer.cpp13
3 files changed, 27 insertions, 8 deletions
diff --git a/src/caffe/layers/mvn_layer.cpp b/src/caffe/layers/mvn_layer.cpp
index 3e79bddc..325691b1 100644
--- a/src/caffe/layers/mvn_layer.cpp
+++ b/src/caffe/layers/mvn_layer.cpp
@@ -18,8 +18,12 @@ void MVNLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
1, 1);
temp_.Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
- sum_multiplier_.Reshape(1, 1,
- bottom[0]->height(), bottom[0]->width());
+ if ( this->layer_param_.mvn_param().across_channels() ) {
+ sum_multiplier_.Reshape(1, bottom[0]->channels(), bottom[0]->height(),
+ bottom[0]->width());
+ } else {
+ sum_multiplier_.Reshape(1, 1, bottom[0]->height(), bottom[0]->width());
+ }
Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
caffe_set(sum_multiplier_.count(), Dtype(1), multiplier_data);
eps_ = this->layer_param_.mvn_param().eps();
@@ -130,7 +134,12 @@ void MVNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff);
} else {
- caffe_copy(temp_.count(), top_diff, bottom_diff);
+ caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1. / dim, top_diff,
+ sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data());
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ mean_.cpu_data(), sum_multiplier_.cpu_data(), 0.,
+ temp_.mutable_cpu_data());
+ caffe_add(temp_.count(), top_diff, temp_.cpu_data(), bottom_diff);
}
}
diff --git a/src/caffe/layers/mvn_layer.cu b/src/caffe/layers/mvn_layer.cu
index 3888a0c7..d86a2e73 100644
--- a/src/caffe/layers/mvn_layer.cu
+++ b/src/caffe/layers/mvn_layer.cu
@@ -113,7 +113,12 @@ void MVNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
caffe_gpu_div(temp_.count(), bottom_diff, temp_.gpu_data(), bottom_diff);
} else {
- caffe_copy(temp_.count(), top_diff, bottom_diff);
+ caffe_gpu_gemv<Dtype>(CblasNoTrans, num, dim, 1. / dim, top_diff,
+ sum_multiplier_.gpu_data(), 0., mean_.mutable_gpu_data());
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ mean_.gpu_data(), sum_multiplier_.gpu_data(), 0.,
+ temp_.mutable_gpu_data());
+ caffe_gpu_add(temp_.count(), top_diff, temp_.gpu_data(), bottom_diff);
}
}
diff --git a/src/caffe/test/test_mvn_layer.cpp b/src/caffe/test/test_mvn_layer.cpp
index 933b4326..be23d86e 100644
--- a/src/caffe/test/test_mvn_layer.cpp
+++ b/src/caffe/test/test_mvn_layer.cpp
@@ -6,6 +6,7 @@
#include "caffe/common.hpp"
#include "caffe/common_layers.hpp"
#include "caffe/filler.hpp"
+#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "caffe/test/test_caffe_main.hpp"
@@ -73,7 +74,8 @@ TYPED_TEST(MVNLayerTest, TestForward) {
TYPED_TEST(MVNLayerTest, TestForwardMeanOnly) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
- layer_param.ParseFromString("mvn_param{normalize_variance: false}");
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "mvn_param{normalize_variance: false}", &layer_param));
MVNLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
@@ -105,7 +107,8 @@ TYPED_TEST(MVNLayerTest, TestForwardMeanOnly) {
TYPED_TEST(MVNLayerTest, TestForwardAcrossChannels) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
- layer_param.ParseFromString("mvn_param{across_channels: true}");
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "mvn_param{across_channels: true}", &layer_param));
MVNLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
@@ -149,7 +152,8 @@ TYPED_TEST(MVNLayerTest, TestGradient) {
TYPED_TEST(MVNLayerTest, TestGradientMeanOnly) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
- layer_param.ParseFromString("mvn_param{normalize_variance: false}");
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "mvn_param{normalize_variance: false}", &layer_param));
MVNLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
@@ -159,7 +163,8 @@ TYPED_TEST(MVNLayerTest, TestGradientMeanOnly) {
TYPED_TEST(MVNLayerTest, TestGradientAcrossChannels) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
- layer_param.ParseFromString("mvn_param{across_channels: true}");
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "mvn_param{across_channels: true}", &layer_param));
MVNLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-3);
checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,