diff options
author | Jeff Donahue <jeff.donahue@gmail.com> | 2015-08-26 11:42:38 -0700 |
---|---|---|
committer | Jeff Donahue <jeff.donahue@gmail.com> | 2015-08-26 11:42:38 -0700 |
commit | 215bea0624a64e5de8f794b60530ba4a00496f31 (patch) | |
tree | d7362e1170bfba3cc76e5b855b9028ab9ef54775 /src | |
parent | 93b48356dce1185a9f746dff0563009f217c0576 (diff) | |
parent | 1f3f9529df6285a5be5f8e72bd1922a6a0cec4d8 (diff) | |
download | caffeonacl-215bea0624a64e5de8f794b60530ba4a00496f31.tar.gz caffeonacl-215bea0624a64e5de8f794b60530ba4a00496f31.tar.bz2 caffeonacl-215bea0624a64e5de8f794b60530ba4a00496f31.zip |
Merge pull request #2964 from jyegerlehner/mvn-layer-fixes
Fix MVNLayer
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/layers/mvn_layer.cpp | 15 | ||||
-rw-r--r-- | src/caffe/layers/mvn_layer.cu | 7 | ||||
-rw-r--r-- | src/caffe/test/test_mvn_layer.cpp | 13 |
3 files changed, 27 insertions, 8 deletions
diff --git a/src/caffe/layers/mvn_layer.cpp b/src/caffe/layers/mvn_layer.cpp index 3e79bddc..325691b1 100644 --- a/src/caffe/layers/mvn_layer.cpp +++ b/src/caffe/layers/mvn_layer.cpp @@ -18,8 +18,12 @@ void MVNLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, 1, 1); temp_.Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[0]->height(), bottom[0]->width()); - sum_multiplier_.Reshape(1, 1, - bottom[0]->height(), bottom[0]->width()); + if ( this->layer_param_.mvn_param().across_channels() ) { + sum_multiplier_.Reshape(1, bottom[0]->channels(), bottom[0]->height(), + bottom[0]->width()); + } else { + sum_multiplier_.Reshape(1, 1, bottom[0]->height(), bottom[0]->width()); + } Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data(); caffe_set(sum_multiplier_.count(), Dtype(1), multiplier_data); eps_ = this->layer_param_.mvn_param().eps(); @@ -130,7 +134,12 @@ void MVNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff); } else { - caffe_copy(temp_.count(), top_diff, bottom_diff); + caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1. / dim, top_diff, + sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data()); + caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1., + mean_.cpu_data(), sum_multiplier_.cpu_data(), 0., + temp_.mutable_cpu_data()); + caffe_add(temp_.count(), top_diff, temp_.cpu_data(), bottom_diff); } }
diff --git a/src/caffe/layers/mvn_layer.cu b/src/caffe/layers/mvn_layer.cu index 3888a0c7..d86a2e73 100644 --- a/src/caffe/layers/mvn_layer.cu +++ b/src/caffe/layers/mvn_layer.cu @@ -113,7 +113,12 @@ void MVNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, caffe_gpu_div(temp_.count(), bottom_diff, temp_.gpu_data(), bottom_diff); } else { - caffe_copy(temp_.count(), top_diff, bottom_diff); + caffe_gpu_gemv<Dtype>(CblasNoTrans, num, dim, 1. / dim, top_diff, + sum_multiplier_.gpu_data(), 0., mean_.mutable_gpu_data()); + caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1., + mean_.gpu_data(), sum_multiplier_.gpu_data(), 0., + temp_.mutable_gpu_data()); + caffe_gpu_add(temp_.count(), top_diff, temp_.gpu_data(), bottom_diff); } }
diff --git a/src/caffe/test/test_mvn_layer.cpp b/src/caffe/test/test_mvn_layer.cpp index 933b4326..be23d86e 100644 --- a/src/caffe/test/test_mvn_layer.cpp +++ b/src/caffe/test/test_mvn_layer.cpp @@ -6,6 +6,7 @@ #include "caffe/common.hpp" #include "caffe/common_layers.hpp" #include "caffe/filler.hpp" +#include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "caffe/test/test_caffe_main.hpp" @@ -73,7 +74,8 @@ TYPED_TEST(MVNLayerTest, TestForward) { TYPED_TEST(MVNLayerTest, TestForwardMeanOnly) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; - layer_param.ParseFromString("mvn_param{normalize_variance: false}"); + CHECK(google::protobuf::TextFormat::ParseFromString( + "mvn_param{normalize_variance: false}", &layer_param)); MVNLayer<Dtype> layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); @@ -105,7 +107,8 @@ TYPED_TEST(MVNLayerTest, TestForwardMeanOnly) { TYPED_TEST(MVNLayerTest, TestForwardAcrossChannels) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; - layer_param.ParseFromString("mvn_param{across_channels: true}"); + CHECK(google::protobuf::TextFormat::ParseFromString( + "mvn_param{across_channels: true}", &layer_param)); MVNLayer<Dtype> layer(layer_param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); @@ -149,7 +152,8 @@ TYPED_TEST(MVNLayerTest, TestGradient) { TYPED_TEST(MVNLayerTest, TestGradientMeanOnly) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; - layer_param.ParseFromString("mvn_param{normalize_variance: false}"); + CHECK(google::protobuf::TextFormat::ParseFromString( + "mvn_param{normalize_variance: false}", &layer_param)); MVNLayer<Dtype> layer(layer_param); GradientChecker<Dtype> checker(1e-2, 1e-3); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, @@ -159,7 +163,8 @@ TYPED_TEST(MVNLayerTest, TestGradientMeanOnly) { TYPED_TEST(MVNLayerTest, TestGradientAcrossChannels) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; - layer_param.ParseFromString("mvn_param{across_channels: true}"); + CHECK(google::protobuf::TextFormat::ParseFromString( + "mvn_param{across_channels: true}", &layer_param)); MVNLayer<Dtype> layer(layer_param); GradientChecker<Dtype> checker(1e-2, 1e-3); checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, |