summaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorCarl Doersch <cdoersch@cs.cmu.edu>2015-11-06 14:41:30 -0800
committerCarl Doersch <cdoersch@cs.cmu.edu>2015-11-22 14:47:10 -0800
commit8b2aa7093cba002a5f286d47658de72a961d1299 (patch)
treefaacccd7e3a9c016daeaffb3e6e1d823d23dda09 /include
parent0ec116e39c1433feaf9756cd2651c51d810fcbc6 (diff)
downloadcaffeonacl-8b2aa7093cba002a5f286d47658de72a961d1299.tar.gz
caffeonacl-8b2aa7093cba002a5f286d47658de72a961d1299.tar.bz2
caffeonacl-8b2aa7093cba002a5f286d47658de72a961d1299.zip
Better normalization options for SoftmaxWithLoss layer.
Diffstat (limited to 'include')
-rw-r--r--include/caffe/loss_layers.hpp11
1 files changed, 8 insertions, 3 deletions
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
index d08ad9b6..d6569c4a 100644
--- a/include/caffe/loss_layers.hpp
+++ b/include/caffe/loss_layers.hpp
@@ -747,6 +747,12 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ /// Read the normalization mode parameter and compute the normalizer based
+ /// on the blob size. If normalization_mode is VALID, the count of valid
+ /// outputs will be read from valid_count, unless it is -1 in which case
+ /// all outputs are assumed to be valid.
+ virtual Dtype get_normalizer(
+ LossParameter_NormalizationMode normalization_mode, int valid_count);
/// The internal SoftmaxLayer used to map predictions to a distribution.
shared_ptr<Layer<Dtype> > softmax_layer_;
@@ -760,9 +766,8 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
bool has_ignore_label_;
/// The label indicating that an instance should be ignored.
int ignore_label_;
- /// Whether to normalize the loss by the total number of values present
- /// (otherwise just by the batch size).
- bool normalize_;
+ /// How to normalize the output loss.
+ LossParameter_NormalizationMode normalization_;
int softmax_axis_, outer_num_, inner_num_;
};