diff options
author | Kai Li <kaili_kloud@163.com> | 2014-07-04 17:19:24 +0800 |
---|---|---|
committer | Kai Li <kaili_kloud@163.com> | 2014-07-20 00:26:40 +0800 |
commit | 7722514cddc127ba782ad97732532e6afea0db6e (patch) | |
tree | dc87011ad5497e7ed5a856a2d7981518ae4df5a0 /src/caffe/layers/argmax_layer.cpp | |
parent | 61bd7ea32470684201b0a8a9d3edea14edb302f2 (diff) | |
download | caffeonacl-7722514cddc127ba782ad97732532e6afea0db6e.tar.gz caffeonacl-7722514cddc127ba782ad97732532e6afea0db6e.tar.bz2 caffeonacl-7722514cddc127ba782ad97732532e6afea0db6e.zip |
Extend the ArgMaxLayer to output top k results
Diffstat (limited to 'src/caffe/layers/argmax_layer.cpp')
-rw-r--r-- | src/caffe/layers/argmax_layer.cpp | 54 |
1 files changed, 41 insertions, 13 deletions
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp index cc31c0f5..1fe2a89e 100644 --- a/src/caffe/layers/argmax_layer.cpp +++ b/src/caffe/layers/argmax_layer.cpp @@ -1,12 +1,12 @@ // Copyright 2014 BVLC and contributors. -#include <vector> #include <cfloat> +#include <queue> +#include <vector> #include "caffe/layer.hpp" #include "caffe/vision_layers.hpp" - namespace caffe { template <typename Dtype> @@ -14,36 +14,65 @@ void ArgMaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) { Layer<Dtype>::SetUp(bottom, top); out_max_val_ = this->layer_param_.argmax_param().out_max_val(); + top_k_ = this->layer_param_.argmax_param().top_k(); + CHECK_GE(top_k_, 1) << " top k must not be less than 1."; + CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) + << "top_k must be less than or equal to the number of classes."; if (out_max_val_) { // Produces max_ind and max_val - (*top)[0]->Reshape(bottom[0]->num(), 2, 1, 1); + (*top)[0]->Reshape(bottom[0]->num(), 2, top_k_, 1); } else { // Produces only max_ind - (*top)[0]->Reshape(bottom[0]->num(), 1, 1, 1); + (*top)[0]->Reshape(bottom[0]->num(), 1, top_k_, 1); } } template <typename Dtype> +class IDAndValueComparator { + public: + bool operator() (const std::pair<size_t, Dtype>& lhs, + const std::pair<size_t, Dtype>& rhs) const { + return lhs.second < rhs.second || (lhs.second == rhs.second && + lhs.first < rhs.first); + } +}; + +template <typename Dtype> Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) { const Dtype* bottom_data = bottom[0]->cpu_data(); Dtype* top_data = (*top)[0]->mutable_cpu_data(); int num = bottom[0]->num(); int dim = bottom[0]->count() / bottom[0]->num(); + Dtype value; for (int i = 0; i < num; ++i) { - Dtype max_val = -FLT_MAX; - int max_ind = 0; + std::priority_queue<std::pair<size_t, Dtype>, + std::vector<std::pair<size_t, Dtype> >, IDAndValueComparator<Dtype> > + top_k_results; for (int j = 0; j < dim; ++j) { - if (bottom_data[i * dim + j] > max_val) { - max_val = bottom_data[i * dim + j]; - max_ind = j; + value = -(bottom_data[i * dim + j]); + if (top_k_results.size() >= top_k_) { + if (value < top_k_results.top().second) { + top_k_results.pop(); + top_k_results.push(std::make_pair(j, value)); + } + } else { + top_k_results.push(std::make_pair(j, value)); } } if (out_max_val_) { - top_data[i * 2] = max_ind; - top_data[i * 2 + 1] = max_val; + for (int j = 0; j < top_k_; ++j) { + top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2] = + top_k_results.top().first; + top_data[i * 2 * top_k_ + (top_k_ - 1 - j) * 2 + 1] = + -(top_k_results.top().second); + top_k_results.pop(); + } } else { - top_data[i] = max_ind; + for (int j = 0; j < top_k_; ++j) { + top_data[i * top_k_ + (top_k_ - 1 - j)] = top_k_results.top().first; + top_k_results.pop(); + } } } return Dtype(0); @@ -51,5 +80,4 @@ Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, INSTANTIATE_CLASS(ArgMaxLayer); - } // namespace caffe |