summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2015-09-30 17:43:05 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2015-09-30 17:43:05 -0700
commit01e15d0d6d2c9cc5b03739a258aab774336056a2 (patch)
treea127a152aedb36094fc50c7688b55c2c3a3715c2 /src
parent942df002368bfe95e285ff29ccdfe4a8b616b413 (diff)
parentdef3d3cc49b908e54f787be377c299e6e6cbf16c (diff)
downloadcaffeonacl-01e15d0d6d2c9cc5b03739a258aab774336056a2.tar.gz
caffeonacl-01e15d0d6d2c9cc5b03739a258aab774336056a2.tar.bz2
caffeonacl-01e15d0d6d2c9cc5b03739a258aab774336056a2.zip
Merge pull request #3069 from timmeinhardt/argmax
Add argmax_param "axis" to maximise output along the specified axis
Diffstat (limited to 'src')
-rw-r--r--src/caffe/layers/argmax_layer.cpp77
-rw-r--r--src/caffe/proto/caffe.proto5
-rw-r--r--src/caffe/test/test_argmax_layer.cpp137
3 files changed, 194 insertions, 25 deletions
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp
index c4040cdc..0c0a932d 100644
--- a/src/caffe/layers/argmax_layer.cpp
+++ b/src/caffe/layers/argmax_layer.cpp
@@ -11,23 +11,43 @@ namespace caffe {
template <typename Dtype>
void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- out_max_val_ = this->layer_param_.argmax_param().out_max_val();
- top_k_ = this->layer_param_.argmax_param().top_k();
- CHECK_GE(top_k_, 1) << " top k must not be less than 1.";
- CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num())
- << "top_k must be less than or equal to the number of classes.";
+ const ArgMaxParameter& argmax_param = this->layer_param_.argmax_param();
+ out_max_val_ = argmax_param.out_max_val();
+ top_k_ = argmax_param.top_k();
+ has_axis_ = argmax_param.has_axis();
+ CHECK_GE(top_k_, 1) << "top k must not be less than 1.";
+ if (has_axis_) {
+ axis_ = bottom[0]->CanonicalAxisIndex(argmax_param.axis());
+ CHECK_GE(axis_, 0) << "axis must not be less than 0.";
+ CHECK_LE(axis_, bottom[0]->num_axes()) <<
+ "axis must be less than or equal to the number of axis.";
+ CHECK_LE(top_k_, bottom[0]->shape(axis_))
+ << "top_k must be less than or equal to the dimension of the axis.";
+ } else {
+ CHECK_LE(top_k_, bottom[0]->count(1))
+ << "top_k must be less than or equal to"
+ " the dimension of the flattened bottom blob per instance.";
+ }
}
template <typename Dtype>
void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- if (out_max_val_) {
- // Produces max_ind and max_val
- top[0]->Reshape(bottom[0]->num(), 2, top_k_, 1);
+ std::vector<int> shape(bottom[0]->num_axes(), 1);
+ if (has_axis_) {
+ // Produces max_ind or max_val per axis
+ shape = bottom[0]->shape();
+ shape[axis_] = top_k_;
} else {
- // Produces only max_ind
- top[0]->Reshape(bottom[0]->num(), 1, top_k_, 1);
+ shape[0] = bottom[0]->shape(0);
+ // Produces max_ind
+ shape[2] = top_k_;
+ if (out_max_val_) {
+ // Produces max_ind and max_val
+ shape[1] = 2;
+ }
}
+ top[0]->Reshape(shape);
}
template <typename Dtype>
@@ -35,23 +55,40 @@ void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
- int num = bottom[0]->num();
- int dim = bottom[0]->count() / bottom[0]->num();
+ int dim, axis_dist;
+ if (has_axis_) {
+ dim = bottom[0]->shape(axis_);
+ // Distance between values of axis in blob
+ axis_dist = bottom[0]->count(axis_) / dim;
+ } else {
+ dim = bottom[0]->count(1);
+ axis_dist = 1;
+ }
+ int num = bottom[0]->count() / dim;
+ std::vector<std::pair<Dtype, int> > bottom_data_vector(dim);
for (int i = 0; i < num; ++i) {
- std::vector<std::pair<Dtype, int> > bottom_data_vector;
for (int j = 0; j < dim; ++j) {
- bottom_data_vector.push_back(
- std::make_pair(bottom_data[i * dim + j], j));
+ bottom_data_vector[j] = std::make_pair(
+ bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
}
std::partial_sort(
bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_,
bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >());
for (int j = 0; j < top_k_; ++j) {
- top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second;
- }
- if (out_max_val_) {
- for (int j = 0; j < top_k_; ++j) {
- top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first;
+ if (out_max_val_) {
+ if (has_axis_) {
+ // Produces max_val per axis
+ top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
+ = bottom_data_vector[j].first;
+ } else {
+ // Produces max_ind and max_val
+ top_data[2 * i * top_k_ + j] = bottom_data_vector[j].second;
+ top_data[2 * i * top_k_ + top_k_ + j] = bottom_data_vector[j].first;
+ }
+ } else {
+ // Produces max_ind per axis
+ top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist]
+ = bottom_data_vector[j].second;
}
}
}
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index f52c941b..a8747c12 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -443,6 +443,11 @@ message ArgMaxParameter {
// If true produce pairs (argmax, maxval)
optional bool out_max_val = 1 [default = false];
optional uint32 top_k = 2 [default = 1];
+ // The axis along which to maximise -- may be negative to index from the
+ // end (e.g., -1 for the last axis).
+ // By default ArgMaxLayer maximizes over the flattened trailing dimensions
+ // for each index of the first / num dimension.
+ optional int32 axis = 3;
}
message ConcatParameter {
diff --git a/src/caffe/test/test_argmax_layer.cpp b/src/caffe/test/test_argmax_layer.cpp
index 895c3d37..bbf19099 100644
--- a/src/caffe/test/test_argmax_layer.cpp
+++ b/src/caffe/test/test_argmax_layer.cpp
@@ -16,7 +16,7 @@ template <typename Dtype>
class ArgMaxLayerTest : public CPUDeviceTest<Dtype> {
protected:
ArgMaxLayerTest()
- : blob_bottom_(new Blob<Dtype>(10, 20, 1, 1)),
+ : blob_bottom_(new Blob<Dtype>(10, 10, 20, 20)),
blob_top_(new Blob<Dtype>()),
top_k_(5) {
Caffe::set_random_seed(1701);
@@ -55,6 +55,43 @@ TYPED_TEST(ArgMaxLayerTest, TestSetupMaxVal) {
EXPECT_EQ(this->blob_top_->channels(), 2);
}
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxis) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(0);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->shape(0), argmax_param->top_k());
+ EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(0));
+ EXPECT_EQ(this->blob_top_->shape(2), this->blob_bottom_->shape(2));
+ EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxisNegativeIndexing) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(-2);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->shape(0), this->blob_bottom_->shape(0));
+ EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(1));
+ EXPECT_EQ(this->blob_top_->shape(2), argmax_param->top_k());
+ EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestSetupAxisMaxVal) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(2);
+ argmax_param->set_out_max_val(true);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->shape(0), this->blob_bottom_->shape(0));
+ EXPECT_EQ(this->blob_top_->shape(1), this->blob_bottom_->shape(1));
+ EXPECT_EQ(this->blob_top_->shape(2), argmax_param->top_k());
+ EXPECT_EQ(this->blob_top_->shape(3), this->blob_bottom_->shape(3));
+}
+
TYPED_TEST(ArgMaxLayerTest, TestCPU) {
LayerParameter layer_param;
ArgMaxLayer<TypeParam> layer(layer_param);
@@ -112,6 +149,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// Now, check values
+ const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
int max_ind;
TypeParam max_val;
int num = this->blob_bottom_->num();
@@ -121,10 +159,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
EXPECT_LE(this->blob_top_->data_at(i, 0, 0, 0), dim);
for (int j = 0; j < this->top_k_; ++j) {
max_ind = this->blob_top_->data_at(i, 0, j, 0);
- max_val = this->blob_bottom_->data_at(i, max_ind, 0, 0);
+ max_val = bottom_data[i * dim + max_ind];
int count = 0;
for (int k = 0; k < dim; ++k) {
- if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+ if (bottom_data[i * dim + k] > max_val) {
++count;
}
}
@@ -142,6 +180,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// Now, check values
+ const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
int max_ind;
TypeParam max_val;
int num = this->blob_bottom_->num();
@@ -152,10 +191,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
for (int j = 0; j < this->top_k_; ++j) {
max_ind = this->blob_top_->data_at(i, 0, j, 0);
max_val = this->blob_top_->data_at(i, 1, j, 0);
- EXPECT_EQ(this->blob_bottom_->data_at(i, max_ind, 0, 0), max_val);
+ EXPECT_EQ(bottom_data[i * dim + max_ind], max_val);
int count = 0;
for (int k = 0; k < dim; ++k) {
- if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+ if (bottom_data[i * dim + k] > max_val) {
++count;
}
}
@@ -164,5 +203,93 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
}
}
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxis) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(0);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ int max_ind;
+ TypeParam max_val;
+ std::vector<int> shape = this->blob_bottom_->shape();
+ for (int i = 0; i < shape[1]; ++i) {
+ for (int j = 0; j < shape[2]; ++j) {
+ for (int k = 0; k < shape[3]; ++k) {
+ max_ind = this->blob_top_->data_at(0, i, j, k);
+ max_val = this->blob_bottom_->data_at(max_ind, i, j, k);
+ EXPECT_GE(max_ind, 0);
+ EXPECT_LE(max_ind, shape[0]);
+ for (int l = 0; l < shape[0]; ++l) {
+ EXPECT_LE(this->blob_bottom_->data_at(l, i, j, k), max_val);
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxisTopK) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(2);
+ argmax_param->set_top_k(this->top_k_);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ int max_ind;
+ TypeParam max_val;
+ std::vector<int> shape = this->blob_bottom_->shape();
+ for (int i = 0; i < shape[0]; ++i) {
+ for (int j = 0; j < shape[1]; ++j) {
+ for (int k = 0; k < shape[3]; ++k) {
+ for (int m = 0; m < this->top_k_; ++m) {
+ max_ind = this->blob_top_->data_at(i, j, m, k);
+ max_val = this->blob_bottom_->data_at(i, j, max_ind, k);
+ EXPECT_GE(max_ind, 0);
+ EXPECT_LE(max_ind, shape[2]);
+ int count = 0;
+ for (int l = 0; l < shape[2]; ++l) {
+ if (this->blob_bottom_->data_at(i, j, l, k) > max_val) {
+ ++count;
+ }
+ }
+ EXPECT_EQ(m, count);
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(ArgMaxLayerTest, TestCPUAxisMaxValTopK) {
+ LayerParameter layer_param;
+ ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param();
+ argmax_param->set_axis(-1);
+ argmax_param->set_top_k(this->top_k_);
+ argmax_param->set_out_max_val(true);
+ ArgMaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ TypeParam max_val;
+ std::vector<int> shape = this->blob_bottom_->shape();
+ for (int i = 0; i < shape[0]; ++i) {
+ for (int j = 0; j < shape[1]; ++j) {
+ for (int k = 0; k < shape[2]; ++k) {
+ for (int m = 0; m < this->top_k_; ++m) {
+ max_val = this->blob_top_->data_at(i, j, k, m);
+ int count = 0;
+ for (int l = 0; l < shape[3]; ++l) {
+ if (this->blob_bottom_->data_at(i, j, k, l) > max_val) {
+ ++count;
+ }
+ }
+ EXPECT_EQ(m, count);
+ }
+ }
+ }
+ }
+}
} // namespace caffe