diff options
author | Carl Doersch <cdoersch@cs.cmu.edu> | 2015-11-06 14:41:30 -0800 |
---|---|---|
committer | Carl Doersch <cdoersch@cs.cmu.edu> | 2015-11-22 14:47:10 -0800 |
commit | 8b2aa7093cba002a5f286d47658de72a961d1299 (patch) | |
tree | faacccd7e3a9c016daeaffb3e6e1d823d23dda09 /include | |
parent | 0ec116e39c1433feaf9756cd2651c51d810fcbc6 (diff) | |
download | caffeonacl-8b2aa7093cba002a5f286d47658de72a961d1299.tar.gz caffeonacl-8b2aa7093cba002a5f286d47658de72a961d1299.tar.bz2 caffeonacl-8b2aa7093cba002a5f286d47658de72a961d1299.zip |
Better normalization options for SoftmaxWithLoss layer.
Diffstat (limited to 'include')
-rw-r--r-- | include/caffe/loss_layers.hpp | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp index d08ad9b6..d6569c4a 100644 --- a/include/caffe/loss_layers.hpp +++ b/include/caffe/loss_layers.hpp @@ -747,6 +747,12 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> { virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + /// Read the normalization mode parameter and compute the normalizer based + /// on the blob size. If normalization_mode is VALID, the count of valid + /// outputs will be read from valid_count, unless it is -1 in which case + /// all outputs are assumed to be valid. + virtual Dtype get_normalizer( + LossParameter_NormalizationMode normalization_mode, int valid_count); /// The internal SoftmaxLayer used to map predictions to a distribution. shared_ptr<Layer<Dtype> > softmax_layer_; @@ -760,9 +766,8 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> { bool has_ignore_label_; /// The label indicating that an instance should be ignored. int ignore_label_; - /// Whether to normalize the loss by the total number of values present - /// (otherwise just by the batch size). - bool normalize_; + /// How to normalize the output loss. + LossParameter_NormalizationMode normalization_; int softmax_axis_, outer_num_, inner_num_; }; |