diff options
author | Tim Meinhardt <meinhardt.tim@gmail.com> | 2015-09-15 16:55:26 +0200 |
---|---|---|
committer | Tim Meinhardt <meinhardt.tim@gmail.com> | 2015-09-25 12:05:44 +0200 |
commit | 6c02c8b7daf123f64b944ede407d0022e98d6e0b (patch) | |
tree | 1663f70e57a3eb7140d1289baaabb5c8928c47d5 | |
parent | 674b349522249094b5449f4bc1f90635dc625f4f (diff) | |
download | caffeonacl-6c02c8b7daf123f64b944ede407d0022e98d6e0b.tar.gz caffeonacl-6c02c8b7daf123f64b944ede407d0022e98d6e0b.tar.bz2 caffeonacl-6c02c8b7daf123f64b944ede407d0022e98d6e0b.zip |
Add argmax_param axis
-rw-r--r-- | include/caffe/common_layers.hpp | 2 | ||||
-rw-r--r-- | src/caffe/layers/argmax_layer.cpp | 22 | ||||
-rw-r--r-- | src/caffe/proto/caffe.proto | 5 |
3 files changed, 24 insertions, 5 deletions
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 89bab8d6..491f9edb 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -68,6 +68,8 @@ class ArgMaxLayer : public Layer<Dtype> { } bool out_max_val_; size_t top_k_; + bool has_axis_; + int axis_; }; /** diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp index c4040cdc..dad3d08b 100644 --- a/src/caffe/layers/argmax_layer.cpp +++ b/src/caffe/layers/argmax_layer.cpp @@ -11,11 +11,23 @@ namespace caffe { template <typename Dtype> void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { - out_max_val_ = this->layer_param_.argmax_param().out_max_val(); - top_k_ = this->layer_param_.argmax_param().top_k(); - CHECK_GE(top_k_, 1) << " top k must not be less than 1."; - CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) - << "top_k must be less than or equal to the number of classes."; + const ArgMaxParameter& argmax_param = this->layer_param_.argmax_param(); + out_max_val_ = argmax_param.out_max_val(); + top_k_ = argmax_param.top_k(); + has_axis_ = argmax_param.has_axis(); + CHECK_GE(top_k_, 1) << "top k must not be less than 1."; + if (has_axis_) { + axis_ = bottom[0]->CanonicalAxisIndex(argmax_param.axis()); + CHECK_GE(axis_, 0) << "axis must not be less than 0."; + CHECK_LE(axis_, bottom[0]->num_axes()) << + "axis must be less than or equal to the number of axis."; + CHECK_LE(top_k_, bottom[0]->shape(axis_)) + << "top_k must be less than or equal to the dimension of the axis."; + } else { + CHECK_LE(top_k_, bottom[0]->count(1)) + << "top_k must be less than or equal to" + " the dimension of the flattened bottom blob per instance."; + } } template <typename Dtype> diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index f52c941b..a8747c12 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -443,6 +443,11 @@ message ArgMaxParameter { // If true produce pairs (argmax, maxval) optional bool out_max_val = 1 [default = false]; optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; } message ConcatParameter { |