summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Meinhardt <meinhardt.tim@gmail.com>2015-09-15 16:55:26 +0200
committerTim Meinhardt <meinhardt.tim@gmail.com>2015-09-25 12:05:44 +0200
commit6c02c8b7daf123f64b944ede407d0022e98d6e0b (patch)
tree1663f70e57a3eb7140d1289baaabb5c8928c47d5
parent674b349522249094b5449f4bc1f90635dc625f4f (diff)
downloadcaffeonacl-6c02c8b7daf123f64b944ede407d0022e98d6e0b.tar.gz
caffeonacl-6c02c8b7daf123f64b944ede407d0022e98d6e0b.tar.bz2
caffeonacl-6c02c8b7daf123f64b944ede407d0022e98d6e0b.zip
Add argmax_param axis
-rw-r--r--include/caffe/common_layers.hpp2
-rw-r--r--src/caffe/layers/argmax_layer.cpp22
-rw-r--r--src/caffe/proto/caffe.proto5
3 files changed, 24 insertions, 5 deletions
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index 89bab8d6..491f9edb 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -68,6 +68,8 @@ class ArgMaxLayer : public Layer<Dtype> {
}
bool out_max_val_;
size_t top_k_;
+ bool has_axis_;
+ int axis_;
};
/**
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp
index c4040cdc..dad3d08b 100644
--- a/src/caffe/layers/argmax_layer.cpp
+++ b/src/caffe/layers/argmax_layer.cpp
@@ -11,11 +11,23 @@ namespace caffe {
template <typename Dtype>
void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- out_max_val_ = this->layer_param_.argmax_param().out_max_val();
- top_k_ = this->layer_param_.argmax_param().top_k();
- CHECK_GE(top_k_, 1) << " top k must not be less than 1.";
- CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
- << "top_k must be less than or equal to the number of classes.";
+ const ArgMaxParameter& argmax_param = this->layer_param_.argmax_param();
+ out_max_val_ = argmax_param.out_max_val();
+ top_k_ = argmax_param.top_k();
+ has_axis_ = argmax_param.has_axis();
+ CHECK_GE(top_k_, 1) << "top k must not be less than 1.";
+ if (has_axis_) {
+ axis_ = bottom[0]->CanonicalAxisIndex(argmax_param.axis());
+ CHECK_GE(axis_, 0) << "axis must not be less than 0.";
+ CHECK_LE(axis_, bottom[0]->num_axes()) <<
+ "axis must be less than or equal to the number of axis.";
+ CHECK_LE(top_k_, bottom[0]->shape(axis_))
+ << "top_k must be less than or equal to the dimension of the axis.";
+ } else {
+ CHECK_LE(top_k_, bottom[0]->count(1))
+ << "top_k must be less than or equal to"
+ " the dimension of the flattened bottom blob per instance.";
+ }
}
template <typename Dtype>
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index f52c941b..a8747c12 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -443,6 +443,11 @@ message ArgMaxParameter {
// If true produce pairs (argmax, maxval)
optional bool out_max_val = 1 [default = false];
optional uint32 top_k = 2 [default = 1];
+ // The axis along which to maximise -- may be negative to index from the
+ // end (e.g., -1 for the last axis).
+ // By default ArgMaxLayer maximizes over the flattened trailing dimensions
+ // for each index of the first / num dimension.
+ optional int32 axis = 3;
}
message ConcatParameter {