author    Dmytro Mishkin <ducha.aiki@gmail.com>  2015-02-25 17:00:22 +0200
committer Carl Doersch <cdoersch@cs.cmu.edu>     2015-10-20 21:04:08 -0700
commit    2f05b03371e5936a478c7ad2946d0cd7c013920c (patch)
tree      ef569944b7bebf611e6e68df17146bd91e0bcd96 /include
parent    8c8e832e71985ba89dcb7c8a60697322c54b5f5b (diff)
Added batch normalization layer with test and examples
Diffstat (limited to 'include')
-rw-r--r--  include/caffe/common_layers.hpp | 50
1 file changed, 49 insertions(+), 1 deletion(-)
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index 21a27d75..09605db9 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -79,6 +79,55 @@ class ArgMaxLayer : public Layer<Dtype> {
};
/**
+* @brief Normalizes the input per channel over the minibatch, then applies a
+* learned per-channel scale and shift linear transform.
+*
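+* Concretely, the standard batch-normalization transform (stated here for
+* reference rather than copied from the commit) is
+*
+*   y = gamma * (x - mean(x)) / sqrt(var(x) + var_eps_) + beta
+*
+* with the mean and variance computed per channel over the minibatch.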
+*/
+template <typename Dtype>
+class BatchNormLayer : public Layer<Dtype> {
+ public:
+ explicit BatchNormLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual inline const char* type() const { return "BN"; }
+ virtual inline int ExactNumBottomBlobs() const { return 1; }
+ virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ // spatial mean & variance
+ Blob<Dtype> spatial_mean_, spatial_variance_;
+ // batch mean & variance
+ Blob<Dtype> batch_mean_, batch_variance_;
+ // scratch buffer for intermediate computations
+ Blob<Dtype> buffer_blob_;
+
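+ // normalized input, saved for reuse in the backward pass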
+ Blob<Dtype> x_norm_;
+ // sum multipliers (vectors of ones) used to carry out sums via BLAS
+ Blob<Dtype> spatial_sum_multiplier_, batch_sum_multiplier_;
+
+ // input dimensions: num, channels, height, width
+ int N_;
+ int C_;
+ int H_;
+ int W_;
+ // epsilon added to the variance to avoid division by zero
+ Dtype var_eps_;
+};
+
+/**
* @brief Index into the input blob along its first axis.
*
* This layer can be used to select, reorder, and even replicate examples in a
@@ -146,7 +195,6 @@ class BatchReindexLayer : public Layer<Dtype> {
const Dtype* ridx_data);
};
-
/**
* @brief Takes at least two Blob%s and concatenates them along either the num
* or channel dimension, outputting the result.
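
For reference, below is a minimal CPU sketch of the per-channel transform that the new layer's Forward_cpu presumably computes. The free function, its signature, the NCHW flattening, and the default epsilon are illustrative assumptions rather than code from the commit; the actual layer operates on Blob buffers and expresses its reductions through BLAS rather than explicit loops.

#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper: per-channel batch normalization of an N x C x H x W
// tensor stored in row-major (NCHW) order. gamma and beta are the learned
// per-channel scale and shift.
void batch_norm_forward(std::vector<float>& data,
                        const std::vector<float>& gamma,
                        const std::vector<float>& beta,
                        int N, int C, int H, int W,
                        float var_eps = 1e-5f) {  // eps value is an assumption
  const int spatial = H * W;
  const int count = N * spatial;
  for (int c = 0; c < C; ++c) {
    // Mean and variance of channel c over the batch and spatial dimensions.
    double mean = 0.0, var = 0.0;
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < spatial; ++s)
        mean += data[(n * C + c) * spatial + s];
    mean /= count;
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < spatial; ++s) {
        const double d = data[(n * C + c) * spatial + s] - mean;
        var += d * d;
      }
    var /= count;
    // Normalize, then apply the learned linear transform.
    const double inv_std = 1.0 / std::sqrt(var + var_eps);
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < spatial; ++s) {
        float& x = data[(n * C + c) * spatial + s];
        x = static_cast<float>(gamma[c] * (x - mean) * inv_std + beta[c]);
      }
  }
}

int main() {
  // Toy check: one channel, a batch of two 2x2 maps; gamma = 1 and beta = 0
  // should leave the data zero-mean with unit variance (up to var_eps).
  std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
  batch_norm_forward(data, {1.0f}, {0.0f}, 2, 1, 2, 2);
  for (float v : data) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}

In the layer itself, spatial_sum_multiplier_ and batch_sum_multiplier_ are the usual Caffe device for such reductions: filled with ones, a single matrix-vector product against them sums a blob over the spatial or batch axis in one BLAS call.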