diff options
-rw-r--r-- | include/caffe/layers/batch_norm_layer.hpp | 6 ++----
-rw-r--r-- | src/caffe/layers/batch_norm_layer.cpp     | 8 ++++++++
2 files changed, 10 insertions, 4 deletions
diff --git a/include/caffe/layers/batch_norm_layer.hpp b/include/caffe/layers/batch_norm_layer.hpp
index a26ad1a4..43f7b28b 100644
--- a/include/caffe/layers/batch_norm_layer.hpp
+++ b/include/caffe/layers/batch_norm_layer.hpp
@@ -22,10 +22,8 @@ namespace caffe {
  * mean/variance statistics via a running average, which is then used at test
  * time to allow deterministic outputs for each input. You can manually toggle
  * whether the network is accumulating or using the statistics via the
- * use_global_stats option. IMPORTANT: for this feature to work, you MUST set
- * the learning rate to zero for all three blobs, i.e., param {lr_mult: 0} three
- * times in the layer definition. For reference, these three blobs are (0)
- * mean, (1) variance, and (2) the moving average factor.
+ * use_global_stats option. For reference, these statistics are kept in the
+ * layer's three blobs: (0) mean, (1) variance, and (2) moving average factor.
  *
  * Note that the original paper also included a per-channel learned bias and
  * scaling factor. To implement this in Caffe, define a `ScaleLayer` configured
diff --git a/src/caffe/layers/batch_norm_layer.cpp b/src/caffe/layers/batch_norm_layer.cpp
index a69d8f99..0b1037ed 100644
--- a/src/caffe/layers/batch_norm_layer.cpp
+++ b/src/caffe/layers/batch_norm_layer.cpp
@@ -34,6 +34,14 @@ void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
                   this->blobs_[i]->mutable_cpu_data());
       }
     }
+  // Mask statistics from optimization by setting local learning rates
+  // for mean, variance, and the bias correction to zero.
+  CHECK_EQ(this->layer_param_.param_size(), 0)
+      << "Cannot configure batch normalization statistics as layer parameters.";
+  for (int i = 0; i < this->blobs_.size(); ++i) {
+    ParamSpec* fixed_param_spec = this->layer_param_.add_param();
+    fixed_param_spec->set_lr_mult(0.);
+  }
 }
 
 template <typename Dtype>