summaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorJeff Donahue <jeff.donahue@gmail.com>2015-11-22 19:33:36 -0800
committerJeff Donahue <jeff.donahue@gmail.com>2015-11-22 19:33:36 -0800
commit8e8d97d6206cac99eae3c16baaa2275a14e64ca7 (patch)
tree1c7f6afb5050670155f46d14b013a56aaaf79946 /include
parentdf21b6a1c1d200d54d7917f3d78951473dbf02fb (diff)
parent8b2aa7093cba002a5f286d47658de72a961d1299 (diff)
downloadcaffeonacl-8e8d97d6206cac99eae3c16baaa2275a14e64ca7.tar.gz
caffeonacl-8e8d97d6206cac99eae3c16baaa2275a14e64ca7.tar.bz2
caffeonacl-8e8d97d6206cac99eae3c16baaa2275a14e64ca7.zip
Merge pull request #3296 from cdoersch/normalize_batch
Better normalization options for SoftmaxWithLoss layer
Diffstat (limited to 'include')
-rw-r--r--  include/caffe/loss_layers.hpp | 11
1 file changed, 8 insertions, 3 deletions
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
index e2e3e48c..53d07025 100644
--- a/include/caffe/loss_layers.hpp
+++ b/include/caffe/loss_layers.hpp
@@ -747,6 +747,12 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ /// Read the normalization mode parameter and compute the normalizer based
+ /// on the blob size. If normalization_mode is VALID, the count of valid
+ /// outputs will be read from valid_count, unless it is -1 in which case
+ /// all outputs are assumed to be valid.
+ virtual Dtype get_normalizer(
+ LossParameter_NormalizationMode normalization_mode, int valid_count);
/// The internal SoftmaxLayer used to map predictions to a distribution.
shared_ptr<Layer<Dtype> > softmax_layer_;
@@ -760,9 +766,8 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
bool has_ignore_label_;
/// The label indicating that an instance should be ignored.
int ignore_label_;
- /// Whether to normalize the loss by the total number of values present
- /// (otherwise just by the batch size).
- bool normalize_;
+ /// How to normalize the output loss.
+ LossParameter_NormalizationMode normalization_;
int softmax_axis_, outer_num_, inner_num_;
};