summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2015-09-24 03:38:44 (GMT)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2015-09-24 03:38:44 (GMT)
commitff6b288dc3451a889aba9d1cbf0716f548729e9c (patch)
treed4aeee9aac6f92cc6a6fd86d263d63c9e5fb61a5
parentd425e59d7494e84be63ad570a149194edce73935 (diff)
parentb8c81bd2bfbc5bc2e394395bf2c1f435cb32b2a1 (diff)
downloadcaffeonacl-ff6b288dc3451a889aba9d1cbf0716f548729e9c.zip
caffeonacl-ff6b288dc3451a889aba9d1cbf0716f548729e9c.tar.gz
caffeonacl-ff6b288dc3451a889aba9d1cbf0716f548729e9c.tar.bz2
Merge pull request #3112 from shelhamer/test-reshape-harder
[test] Test Reshape more rigorously
-rw-r--r--src/caffe/test/test_net.cpp26
1 files changed, 20 insertions, 6 deletions
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index 12998d8..ab4afba 100644
--- a/src/caffe/test/test_net.cpp
+++ b/src/caffe/test/test_net.cpp
@@ -2262,15 +2262,17 @@ TEST_F(FilterNetTest, TestFilterInOutByExcludeMultiRule) {
TYPED_TEST(NetTest, TestReshape) {
typedef typename TypeParam::Dtype Dtype;
// We set up bottom blobs of two different sizes, switch between
- // them, and check that forward and backward both run and the results
- // are the same.
+ // them, check that forward and backward both run and the results
+ // are the same, and check that the output shapes change.
Caffe::set_random_seed(this->seed_);
Caffe::set_mode(Caffe::CPU);
FillerParameter filler_param;
filler_param.set_std(1);
GaussianFiller<Dtype> filler(filler_param);
- Blob<Dtype> blob1(4, 3, 9, 11);
- Blob<Dtype> blob2(2, 3, 12, 10);
+ // Check smaller shape first as larger first could hide realloc failures.
+ Blob<Dtype> blob1(2, 3, 12, 10);
+ Blob<Dtype> blob2(4, 3, 9, 11);
+ ASSERT_LT(blob1.count(), blob2.count());
filler.Fill(&blob1);
filler.Fill(&blob2);
@@ -2304,7 +2306,7 @@ TYPED_TEST(NetTest, TestReshape) {
this->net_->ForwardPrefilled();
this->net_->Backward();
for (int i = 0; i < output1.count(); ++i) {
- CHECK_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i));
+ EXPECT_FLOAT_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i));
}
input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(),
@@ -2313,8 +2315,20 @@ TYPED_TEST(NetTest, TestReshape) {
this->net_->ForwardPrefilled();
this->net_->Backward();
for (int i = 0; i < output2.count(); ++i) {
- CHECK_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i));
+ EXPECT_FLOAT_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i));
}
+
+ EXPECT_EQ(output1.num(), blob1.num());
+ EXPECT_EQ(output2.num(), blob2.num());
+ bool same_spatial_shape = true;
+ const int kFirstSpatialAxis = 2;
+ for (int i = kFirstSpatialAxis; i < output1.num_axes(); ++i) {
+ if (output1.shape(i) != output2.shape(i)) {
+ same_spatial_shape = false;
+ break;
+ }
+ }
+ EXPECT_FALSE(same_spatial_shape);
}
TYPED_TEST(NetTest, TestSkipPropagateDown) {