diff options
-rw-r--r-- | src/caffe/test/test_argmax_layer.cpp | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/src/caffe/test/test_argmax_layer.cpp b/src/caffe/test/test_argmax_layer.cpp index 895c3d37..d3018f90 100644 --- a/src/caffe/test/test_argmax_layer.cpp +++ b/src/caffe/test/test_argmax_layer.cpp @@ -16,7 +16,7 @@ template <typename Dtype> class ArgMaxLayerTest : public CPUDeviceTest<Dtype> { protected: ArgMaxLayerTest() - : blob_bottom_(new Blob<Dtype>(10, 20, 1, 1)), + : blob_bottom_(new Blob<Dtype>(10, 10, 20, 20)), blob_top_(new Blob<Dtype>()), top_k_(5) { Caffe::set_random_seed(1701); @@ -112,6 +112,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) { layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); // Now, check values + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); int max_ind; TypeParam max_val; int num = this->blob_bottom_->num(); @@ -121,10 +122,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) { EXPECT_LE(this->blob_top_->data_at(i, 0, 0, 0), dim); for (int j = 0; j < this->top_k_; ++j) { max_ind = this->blob_top_->data_at(i, 0, j, 0); - max_val = this->blob_bottom_->data_at(i, max_ind, 0, 0); + max_val = bottom_data[i * dim + max_ind]; int count = 0; for (int k = 0; k < dim; ++k) { - if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) { + if (bottom_data[i * dim + k] > max_val) { ++count; } } @@ -142,6 +143,7 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) { layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); // Now, check values + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); int max_ind; TypeParam max_val; int num = this->blob_bottom_->num(); @@ -152,10 +154,10 @@ TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) { for (int j = 0; j < this->top_k_; ++j) { max_ind = this->blob_top_->data_at(i, 0, j, 0); max_val = this->blob_top_->data_at(i, 1, j, 0); - EXPECT_EQ(this->blob_bottom_->data_at(i, max_ind, 0, 0), max_val); + EXPECT_EQ(bottom_data[i * dim + max_ind], max_val); int count = 0; for (int k = 0; k < dim; ++k) { - if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) { + if (bottom_data[i * dim + k] > max_val) { ++count; } } |