summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/caffe/test/test_argmax_layer.cpp12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/caffe/test/test_argmax_layer.cpp b/src/caffe/test/test_argmax_layer.cpp
index 895c3d37..d3018f90 100644
--- a/src/caffe/test/test_argmax_layer.cpp
+++ b/src/caffe/test/test_argmax_layer.cpp
@@ -16,7 +16,7 @@ template <typename Dtype>
class ArgMaxLayerTest : public CPUDeviceTest<Dtype> {
protected:
ArgMaxLayerTest()
- : blob_bottom_(new Blob<Dtype>(10, 20, 1, 1)),
+ : blob_bottom_(new Blob<Dtype>(10, 10, 20, 20)),
blob_top_(new Blob<Dtype>()),
top_k_(5) {
Caffe::set_random_seed(1701);
@@ -112,6 +112,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// Now, check values
+ const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
int max_ind;
TypeParam max_val;
int num = this->blob_bottom_->num();
@@ -121,10 +122,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) {
EXPECT_LE(this->blob_top_->data_at(i, 0, 0, 0), dim);
for (int j = 0; j < this->top_k_; ++j) {
max_ind = this->blob_top_->data_at(i, 0, j, 0);
- max_val = this->blob_bottom_->data_at(i, max_ind, 0, 0);
+ max_val = bottom_data[i * dim + max_ind];
int count = 0;
for (int k = 0; k < dim; ++k) {
- if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+ if (bottom_data[i * dim + k] > max_val) {
++count;
}
}
@@ -142,6 +143,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
// Now, check values
+ const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
int max_ind;
TypeParam max_val;
int num = this->blob_bottom_->num();
@@ -152,10 +154,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) {
for (int j = 0; j < this->top_k_; ++j) {
max_ind = this->blob_top_->data_at(i, 0, j, 0);
max_val = this->blob_top_->data_at(i, 1, j, 0);
- EXPECT_EQ(this->blob_bottom_->data_at(i, max_ind, 0, 0), max_val);
+ EXPECT_EQ(bottom_data[i * dim + max_ind], max_val);
int count = 0;
for (int k = 0; k < dim; ++k) {
- if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) {
+ if (bottom_data[i * dim + k] > max_val) {
++count;
}
}