 include/caffe/layers/batch_norm_layer.hpp | 6 ++----
 src/caffe/layers/batch_norm_layer.cpp     | 8 ++++++++
 2 files changed, 10 insertions(+), 4 deletions(-)
diff --git a/include/caffe/layers/batch_norm_layer.hpp b/include/caffe/layers/batch_norm_layer.hpp
index a26ad1a4..43f7b28b 100644
--- a/include/caffe/layers/batch_norm_layer.hpp
+++ b/include/caffe/layers/batch_norm_layer.hpp
@@ -22,10 +22,8 @@ namespace caffe {
* mean/variance statistics via a running average, which is then used at test
* time to allow deterministic outputs for each input. You can manually toggle
* whether the network is accumulating or using the statistics via the
- * use_global_stats option. IMPORTANT: for this feature to work, you MUST set
- * the learning rate to zero for all three blobs, i.e., param {lr_mult: 0} three
- * times in the layer definition. For reference, these three blobs are (0)
- * mean, (1) variance, and (2) the moving average factor.
+ * use_global_stats option. For reference, these statistics are kept in the
+ * layer's three blobs: (0) mean, (1) variance, and (2) moving average factor.
*
* Note that the original paper also included a per-channel learned bias and
* scaling factor. To implement this in Caffe, define a `ScaleLayer` configured
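
For reference, the BatchNorm + Scale pairing this comment describes looks like the following in a layer definition. This is a minimal sketch, not part of the patch: the names bn1, scale1, and conv1 are hypothetical, and batch_norm_param { use_global_stats: ... } is optional (when unset, it defaults based on the phase: accumulate during TRAIN, use the stored statistics during TEST).

    layer {
      name: "bn1"
      type: "BatchNorm"
      bottom: "conv1"
      top: "conv1"
      batch_norm_param { use_global_stats: false }  # accumulate statistics (training)
    }
    layer {
      name: "scale1"
      type: "Scale"
      bottom: "conv1"
      top: "conv1"
      scale_param { bias_term: true }  # learned per-channel scale and bias from the paper
    }

Note that with this patch, no param { lr_mult: 0 } entries appear on the BatchNorm layer; the statistics blobs are excluded from optimization automatically, as the src change below shows.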
diff --git a/src/caffe/layers/batch_norm_layer.cpp b/src/caffe/layers/batch_norm_layer.cpp
index a69d8f99..0b1037ed 100644
--- a/src/caffe/layers/batch_norm_layer.cpp
+++ b/src/caffe/layers/batch_norm_layer.cpp
@@ -34,6 +34,14 @@ void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
this->blobs_[i]->mutable_cpu_data());
}
}
+ // Mask statistics from optimization by setting local learning rates
+ // for mean, variance, and the bias correction to zero.
+ CHECK_EQ(this->layer_param_.param_size(), 0)
+ << "Cannot configure batch normalization statistics as layer parameters.";
+ for (int i = 0; i < this->blobs_.size(); ++i) {
+ ParamSpec* fixed_param_spec = this->layer_param_.add_param();
+ fixed_param_spec->set_lr_mult(0.);
+ }
}

template <typename Dtype>
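
A note on the behavioral consequence of the CHECK_EQ above: since it requires param_size() == 0, a layer definition written against the old documentation, with param { lr_mult: 0 } repeated three times, now fails LayerSetUp with the "Cannot configure batch normalization statistics as layer parameters." message. A sketch of the migration, again with hypothetical names:

    # Old style, now rejected by the CHECK:
    layer {
      name: "bn1"
      type: "BatchNorm"
      bottom: "conv1"
      top: "conv1"
      param { lr_mult: 0 }
      param { lr_mult: 0 }
      param { lr_mult: 0 }
    }

    # New style: omit the param entries; lr_mult is zeroed internally.
    layer {
      name: "bn1"
      type: "BatchNorm"
      bottom: "conv1"
      top: "conv1"
    }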