diff options
author | Dmytro Mishkin <ducha.aiki@gmail.com> | 2015-02-25 17:00:22 +0200 |
---|---|---|
committer | Carl Doersch <cdoersch@cs.cmu.edu> | 2015-10-20 21:04:08 -0700 |
commit | 2f05b03371e5936a478c7ad2946d0cd7c013920c (patch) | |
tree | ef569944b7bebf611e6e68df17146bd91e0bcd96 /include | |
parent | 8c8e832e71985ba89dcb7c8a60697322c54b5f5b (diff) | |
download | caffeonacl-2f05b03371e5936a478c7ad2946d0cd7c013920c.tar.gz caffeonacl-2f05b03371e5936a478c7ad2946d0cd7c013920c.tar.bz2 caffeonacl-2f05b03371e5936a478c7ad2946d0cd7c013920c.zip |
Added batch normalization layer with test and examples
Diffstat (limited to 'include')
-rw-r--r-- | include/caffe/common_layers.hpp | 50 |
1 files changed, 49 insertions, 1 deletions
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 21a27d75..09605db9 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -79,6 +79,55 @@ class ArgMaxLayer : public Layer<Dtype> { }; /** +* @brief Batch Normalization per-channel with scale & shift linear transform. +* +*/ +template <typename Dtype> +class BatchNormLayer : public Layer<Dtype> { + public: + explicit BatchNormLayer(const LayerParameter& param) + : Layer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + virtual inline const char* type() const { return "BN"; } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + + // spatial mean & variance + Blob<Dtype> spatial_mean_, spatial_variance_; + // batch mean & variance + Blob<Dtype> batch_mean_, batch_variance_; + // buffer blob + Blob<Dtype> buffer_blob_; + + Blob<Dtype> x_norm_; + // x_sum_multiplier is used to carry out sum using BLAS + Blob<Dtype> spatial_sum_multiplier_, batch_sum_multiplier_; + + // dimension + int N_; + int C_; + int H_; + int W_; + // eps + Dtype var_eps_; +}; + +/** * @brief Index into the input blob along its first axis. * * This layer can be used to select, reorder, and even replicate examples in a @@ -146,7 +195,6 @@ class BatchReindexLayer : public Layer<Dtype> { const Dtype* ridx_data); }; - /** * @brief Takes at least two Blob%s and concatenates them along either the num * or channel dimension, outputting the result. |