summaryrefslogtreecommitdiff
path: root/include/caffe/common_layers.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/caffe/common_layers.hpp')
-rw-r--r--include/caffe/common_layers.hpp14
1 files changed, 11 insertions, 3 deletions
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index 89bab8d6..d1ddaee4 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -21,7 +21,8 @@ namespace caffe {
*
* Intended for use after a classification layer to produce a prediction.
* If parameter out_max_val is set to true, output is a vector of pairs
- * (max_ind, max_val) for each image.
+ * (max_ind, max_val) for each image. The axis parameter specifies an axis
+ * along which to maximise.
*
* NOTE: does not implement Backwards operation.
*/
@@ -34,7 +35,11 @@ class ArgMaxLayer : public Layer<Dtype> {
* - top_k (\b optional uint, default 1).
* the number @f$ K @f$ of maximal items to output.
* - out_max_val (\b optional bool, default false).
- * if set, output a vector of pairs (max_ind, max_val) for each image.
+ * if set, output a vector of pairs (max_ind, max_val); if axis is also set,
+ * output max_val along the specified axis instead.
+ * - axis (\b optional int).
+ * if set, maximise along the specified axis; otherwise maximise over the
+ * flattened trailing dimensions for each index of the first / num dimension.
*/
explicit ArgMaxLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
@@ -54,7 +59,8 @@ class ArgMaxLayer : public Layer<Dtype> {
* the inputs @f$ x @f$
* @param top output Blob vector (length 1)
* -# @f$ (N \times 1 \times K \times 1) @f$ or, if out_max_val
- * @f$ (N \times 2 \times K \times 1) @f$
+ * @f$ (N \times 2 \times K \times 1) @f$ unless axis is set, then e.g.
+ * @f$ (N \times K \times H \times W) @f$ if axis == 1
* the computed outputs @f$
* y_n = \arg\max\limits_i x_{ni}
* @f$ (for @f$ K = 1 @f$).
@@ -68,6 +74,8 @@ class ArgMaxLayer : public Layer<Dtype> {
}
bool out_max_val_;
size_t top_k_;
+ bool has_axis_;
+ int axis_;
};
/**