summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPrzemysław Dolata <snowball91b@gmail.com>2017-10-19 07:53:59 (GMT)
committerGitHub <noreply@github.com>2017-10-19 07:53:59 (GMT)
commitb4ffad848e7a54e2f22081f07c8c7e49f31ca5f4 (patch)
tree1f93cdd7f480f3b22d9f2af47b99f0ed85b8fe3b
parent79ddda7e931f90f1b648e5372ccbaf1a35e88fb5 (diff)
parent243cd8948520e83740be328466352b10e6983aec (diff)
downloadcaffe-b4ffad848e7a54e2f22081f07c8c7e49f31ca5f4.zip
caffe-b4ffad848e7a54e2f22081f07c8c7e49f31ca5f4.tar.gz
caffe-b4ffad848e7a54e2f22081f07c8c7e49f31ca5f4.tar.bz2
Merge pull request #5973 from Noiredd/pytest
Add absolute tolerance to test_net.py to prevent random Travis fails
-rw-r--r--python/caffe/test/test_net.py50
1 files changed, 25 insertions, 25 deletions
diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py
index afd2769..ee1d38c 100644
--- a/python/caffe/test/test_net.py
+++ b/python/caffe/test/test_net.py
@@ -72,41 +72,41 @@ class TestNet(unittest.TestCase):
self.net.backward()
def test_forward_start_end(self):
- conv_blob=self.net.blobs['conv'];
- ip_blob=self.net.blobs['ip_blob'];
- sample_data=np.random.uniform(size=conv_blob.data.shape);
- sample_data=sample_data.astype(np.float32);
- conv_blob.data[:]=sample_data;
- forward_blob=self.net.forward(start='ip',end='ip');
- self.assertIn('ip_blob',forward_blob);
-
- manual_forward=[];
+ conv_blob=self.net.blobs['conv']
+ ip_blob=self.net.blobs['ip_blob']
+ sample_data=np.random.uniform(size=conv_blob.data.shape)
+ sample_data=sample_data.astype(np.float32)
+ conv_blob.data[:]=sample_data
+ forward_blob=self.net.forward(start='ip',end='ip')
+ self.assertIn('ip_blob',forward_blob)
+
+ manual_forward=[]
for i in range(0,conv_blob.data.shape[0]):
dot=np.dot(self.net.params['ip'][0].data,
- conv_blob.data[i].reshape(-1));
- manual_forward.append(dot+self.net.params['ip'][1].data);
- manual_forward=np.array(manual_forward);
+ conv_blob.data[i].reshape(-1))
+ manual_forward.append(dot+self.net.params['ip'][1].data)
+ manual_forward=np.array(manual_forward)
- np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3);
+ np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3,atol=1e-5)
def test_backward_start_end(self):
- conv_blob=self.net.blobs['conv'];
- ip_blob=self.net.blobs['ip_blob'];
+ conv_blob=self.net.blobs['conv']
+ ip_blob=self.net.blobs['ip_blob']
sample_data=np.random.uniform(size=ip_blob.data.shape)
- sample_data=sample_data.astype(np.float32);
- ip_blob.diff[:]=sample_data;
- backward_blob=self.net.backward(start='ip',end='ip');
- self.assertIn('conv',backward_blob);
+ sample_data=sample_data.astype(np.float32)
+ ip_blob.diff[:]=sample_data
+ backward_blob=self.net.backward(start='ip',end='ip')
+ self.assertIn('conv',backward_blob)
- manual_backward=[];
+ manual_backward=[]
for i in range(0,conv_blob.data.shape[0]):
dot=np.dot(self.net.params['ip'][0].data.transpose(),
- sample_data[i].reshape(-1));
- manual_backward.append(dot);
- manual_backward=np.array(manual_backward);
- manual_backward=manual_backward.reshape(conv_blob.data.shape);
+ sample_data[i].reshape(-1))
+ manual_backward.append(dot)
+ manual_backward=np.array(manual_backward)
+ manual_backward=manual_backward.reshape(conv_blob.data.shape)
- np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3);
+ np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3,atol=1e-5)
def test_clear_param_diffs(self):
# Run a forward/backward step to have non-zero diffs