diff options
author | Tim Meinhardt <meinhardt.tim@gmail.com> | 2015-09-15 16:56:16 +0200 |
---|---|---|
committer | Tim Meinhardt <meinhardt.tim@gmail.com> | 2015-09-25 12:05:54 +0200 |
commit | c77d5e5156f94720c1decd13f7f87fe78df9d4eb (patch) | |
tree | 336c049da64830427e267248b651da70d57cb387 | |
parent | 6c02c8b7daf123f64b944ede407d0022e98d6e0b (diff) | |
download | caffeonacl-c77d5e5156f94720c1decd13f7f87fe78df9d4eb.tar.gz caffeonacl-c77d5e5156f94720c1decd13f7f87fe78df9d4eb.tar.bz2 caffeonacl-c77d5e5156f94720c1decd13f7f87fe78df9d4eb.zip |
Implement ArgMaxLayer forward_cpu and reshape for axis param
-rw-r--r-- | src/caffe/layers/argmax_layer.cpp | 53 |
1 files changed, 38 insertions, 15 deletions
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp index dad3d08b..18ff5f5a 100644 --- a/src/caffe/layers/argmax_layer.cpp +++ b/src/caffe/layers/argmax_layer.cpp @@ -33,13 +33,19 @@ void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, template <typename Dtype> void ArgMaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { - if (out_max_val_) { + std::vector<int> shape(4, 1); + shape[0] = bottom[0]->shape(0); + // Produces max_ind + shape[2] = top_k_; + if (has_axis_) { + // Produces max_ind or max_val per axis + shape = bottom[0]->shape(); + shape[axis_] = top_k_; + } else if (out_max_val_) { // Produces max_ind and max_val - top[0]->Reshape(bottom[0]->num(), 2, top_k_, 1); - } else { - // Produces only max_ind - top[0]->Reshape(bottom[0]->num(), 1, top_k_, 1); + shape[1] = 2; } + top[0]->Reshape(shape); } template <typename Dtype> @@ -47,23 +53,40 @@ void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { const Dtype* bottom_data = bottom[0]->cpu_data(); Dtype* top_data = top[0]->mutable_cpu_data(); - int num = bottom[0]->num(); - int dim = bottom[0]->count() / bottom[0]->num(); + int dim, axis_dist; + if (has_axis_) { + dim = bottom[0]->shape(axis_); + // Distance between values of axis in blob + axis_dist = bottom[0]->count(axis_) / dim; + } else { + dim = bottom[0]->count(1); + axis_dist = 1; + } + int num = bottom[0]->count() / dim; + std::vector<std::pair<Dtype, int> > bottom_data_vector(dim); for (int i = 0; i < num; ++i) { - std::vector<std::pair<Dtype, int> > bottom_data_vector; for (int j = 0; j < dim; ++j) { - bottom_data_vector.push_back( - std::make_pair(bottom_data[i * dim + j], j)); + bottom_data_vector[j] = std::make_pair( + bottom_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j); } std::partial_sort( bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, bottom_data_vector.end(), std::greater<std::pair<Dtype, int> >()); for (int j = 0; j < top_k_; ++j) { - top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second; - } - if (out_max_val_) { - for (int j = 0; j < top_k_; ++j) { - top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first; + if (out_max_val_) { + if (has_axis_) { + // Produces max_val per axis + top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = + bottom_data_vector[j].first; + } else { + // Produces max_ind and max_val + top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second; + top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first; + } + } else { + // Produces max_ind per axis + top_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = + bottom_data_vector[j].second; } } } |