summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- src/caffe/layers/argmax_layer.cpp 53
1 files changed, 38 insertions, 15 deletions
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp
index dad3d08b..18ff5f5a 100644
--- a/src/caffe/layers/argmax_layer.cpp
+++ b/src/caffe/layers/argmax_layer.cpp
@@ -33,13 +33,19 @@ void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- if (out_max_val_) {
+ std::vector<int> shape(4, 1);
+ shape[0] = bottom[0]->shape(0);
+ // Produces max_ind
+ shape[2] = top_k_;
+ if (has_axis_) {
+ // Produces max_ind or max_val per axis
+ shape = bottom[0]->shape();
+ shape[axis_] = top_k_;
+ } else if (out_max_val_) {
// Produces max_ind and max_val
- top[0]->Reshape(bottom[0]->num(), 2, top_k_, 1);
- } else {
- // Produces only max_ind
- top[0]->Reshape(bottom[0]->num(), 1, top_k_, 1);
+ shape[1] = 2;
}
+ top[0]->Reshape(shape);
}
template <typename Dtype>
@@ -47,23 +53,40 @@ void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
- int num = bottom[0]->num();
- int dim = bottom[0]->count() / bottom[0]->num();
+ int dim, axis_dist;
+ if (has_axis_) {
+ dim = bottom[0]->shape(axis_);
+ // Distance between values of axis in blob
+ axis_dist = bottom[0]->count(axis_) / dim;
+ } else {
+ dim = bottom[0]->count(1);
+ axis_dist = 1;
+ }
+ int num = bottom[0]->count() / dim;
+ std::vector<std::pair<Dtype, int> > bottom_data_vector(dim);
for (int i = 0; i < num; ++i) {
- std::vector<std::pair<Dtype, int> > bottom_data_vector;
for (int j = 0; j < dim; ++j) {
- bottom_data_vector.push_back(
- std::make_pair(bottom_data[i * dim + j], j));
+ bottom_data_vector[j] = std::make_pair(
+ bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
}
std::partial_sort(
bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
for (int j = 0; j < top_k_; ++j) {
- top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
- }
- if (out_max_val_) {
- for (int j = 0; j < top_k_; ++j) {
- top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+ if (out_max_val_) {
+ if (has_axis_) {
+ // Produces max_val per axis
+ top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
+ bottom_data_vector[j].first;
+ } else {
+ // Produces max_ind and max_val
+ top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
+ top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+ }
+ } else {
+ // Produces max_ind per axis
+ top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] =
+ bottom_data_vector[j].second;
}
}
}