From 6c02c8b7daf123f64b944ede407d0022e98d6e0b Mon Sep 17 00:00:00 2001 From: Tim Meinhardt Date: Tue, 15 Sep 2015 16:55:26 +0200 Subject: Add argmax_param axis --- include/caffe/common_layers.hpp | 2 ++ src/caffe/layers/argmax_layer.cpp | 22 +++++++++++++++++----- src/caffe/proto/caffe.proto | 5 +++++ 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 89bab8d6..491f9edb 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -68,6 +68,8 @@ class ArgMaxLayer : public Layer { } bool out_max_val_; size_t top_k_; + bool has_axis_; + int axis_; }; /** diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp index c4040cdc..dad3d08b 100644 --- a/src/caffe/layers/argmax_layer.cpp +++ b/src/caffe/layers/argmax_layer.cpp @@ -11,11 +11,23 @@ namespace caffe { template void ArgMaxLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { - out_max_val_ = this->layer_param_.argmax_param().out_max_val(); - top_k_ = this->layer_param_.argmax_param().top_k(); - CHECK_GE(top_k_, 1) << " top k must not be less than 1."; - CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) - << "top_k must be less than or equal to the number of classes."; + const ArgMaxParameter& argmax_param = this->layer_param_.argmax_param(); + out_max_val_ = argmax_param.out_max_val(); + top_k_ = argmax_param.top_k(); + has_axis_ = argmax_param.has_axis(); + CHECK_GE(top_k_, 1) << "top k must not be less than 1."; + if (has_axis_) { + axis_ = bottom[0]->CanonicalAxisIndex(argmax_param.axis()); + CHECK_GE(axis_, 0) << "axis must not be less than 0."; + CHECK_LE(axis_, bottom[0]->num_axes()) << + "axis must be less than or equal to the number of axis."; + CHECK_LE(top_k_, bottom[0]->shape(axis_)) + << "top_k must be less than or equal to the dimension of the axis."; + } else { + CHECK_LE(top_k_, bottom[0]->count(1)) + << "top_k must be less than or equal to" + " the dimension of the flattened bottom blob per instance."; + } } template diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index f52c941b..a8747c12 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -443,6 +443,11 @@ message ArgMaxParameter { // If true produce pairs (argmax, maxval) optional bool out_max_val = 1 [default = false]; optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; } message ConcatParameter { -- cgit v1.2.3