summaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorRan <ran.manor@gmail.com>2015-08-15 20:09:43 +0300
committerRan <ran.manor@gmail.com>2015-08-23 00:30:45 +0300
commit374fb8c79c3f23ee36c46d0bcaeb2176037aa4b8 (patch)
tree6b2e997cba5f688dba7334a188783849b1b19193 /include
parent25b4920cd280d0647cf556cf4a06efd4fa5c4280 (diff)
downloadcaffeonacl-374fb8c79c3f23ee36c46d0bcaeb2176037aa4b8.tar.gz
caffeonacl-374fb8c79c3f23ee36c46d0bcaeb2176037aa4b8.tar.bz2
caffeonacl-374fb8c79c3f23ee36c46d0bcaeb2176037aa4b8.zip
Output accuracies per class.
Fixed case where number of samples in class can be zero. - Fixed ignore_label case, also added a test. - Two other fixes. Fixed lint errors. Small fix.
Diffstat (limited to 'include')
-rw-r--r--include/caffe/loss_layers.hpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
index 52826639..02687a94 100644
--- a/include/caffe/loss_layers.hpp
+++ b/include/caffe/loss_layers.hpp
@@ -39,7 +39,11 @@ class AccuracyLayer : public Layer<Dtype> {
virtual inline const char* type() const { return "Accuracy"; }
virtual inline int ExactNumBottomBlobs() const { return 2; }
- virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ // If there are two top blobs, then the second blob will contain
+ // accuracies per class.
+ virtual inline int MinTopBlobs() const { return 1; }
+ virtual inline int MaxTopBlos() const { return 2; }
protected:
/**
@@ -86,6 +90,8 @@ class AccuracyLayer : public Layer<Dtype> {
bool has_ignore_label_;
/// The label indicating that an instance should be ignored.
int ignore_label_;
+ /// Keeps counts of the number of samples per class.
+ Blob<Dtype> nums_buffer_;
};
/**