summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNoiredd <snowball91b@gmail.com>2017-10-11 09:04:18 (GMT)
committerNoiredd <snowball91b@gmail.com>2017-10-11 09:04:18 (GMT)
commit243cd8948520e83740be328466352b10e6983aec (patch)
tree36b6e65f0edf8203535c0114881148ac5a1ac88d
parent62e0c8559045cb2b5a12e0d6c41acd25d4122630 (diff)
downloadcaffe-243cd8948520e83740be328466352b10e6983aec.zip
caffe-243cd8948520e83740be328466352b10e6983aec.tar.gz
caffe-243cd8948520e83740be328466352b10e6983aec.tar.bz2
Add absolute tolerance to test_net.py to prevent random Travis fails
-rw-r--r--python/caffe/test/test_net.py50
1 files changed, 25 insertions, 25 deletions
diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py
index afd2769..ee1d38c 100644
--- a/python/caffe/test/test_net.py
+++ b/python/caffe/test/test_net.py
@@ -72,41 +72,41 @@ class TestNet(unittest.TestCase):
self.net.backward()
def test_forward_start_end(self):
- conv_blob=self.net.blobs['conv'];
- ip_blob=self.net.blobs['ip_blob'];
- sample_data=np.random.uniform(size=conv_blob.data.shape);
- sample_data=sample_data.astype(np.float32);
- conv_blob.data[:]=sample_data;
- forward_blob=self.net.forward(start='ip',end='ip');
- self.assertIn('ip_blob',forward_blob);
-
- manual_forward=[];
+ conv_blob=self.net.blobs['conv']
+ ip_blob=self.net.blobs['ip_blob']
+ sample_data=np.random.uniform(size=conv_blob.data.shape)
+ sample_data=sample_data.astype(np.float32)
+ conv_blob.data[:]=sample_data
+ forward_blob=self.net.forward(start='ip',end='ip')
+ self.assertIn('ip_blob',forward_blob)
+
+ manual_forward=[]
for i in range(0,conv_blob.data.shape[0]):
dot=np.dot(self.net.params['ip'][0].data,
- conv_blob.data[i].reshape(-1));
- manual_forward.append(dot+self.net.params['ip'][1].data);
- manual_forward=np.array(manual_forward);
+ conv_blob.data[i].reshape(-1))
+ manual_forward.append(dot+self.net.params['ip'][1].data)
+ manual_forward=np.array(manual_forward)
- np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3);
+ np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3,atol=1e-5)
def test_backward_start_end(self):
- conv_blob=self.net.blobs['conv'];
- ip_blob=self.net.blobs['ip_blob'];
+ conv_blob=self.net.blobs['conv']
+ ip_blob=self.net.blobs['ip_blob']
sample_data=np.random.uniform(size=ip_blob.data.shape)
- sample_data=sample_data.astype(np.float32);
- ip_blob.diff[:]=sample_data;
- backward_blob=self.net.backward(start='ip',end='ip');
- self.assertIn('conv',backward_blob);
+ sample_data=sample_data.astype(np.float32)
+ ip_blob.diff[:]=sample_data
+ backward_blob=self.net.backward(start='ip',end='ip')
+ self.assertIn('conv',backward_blob)
- manual_backward=[];
+ manual_backward=[]
for i in range(0,conv_blob.data.shape[0]):
dot=np.dot(self.net.params['ip'][0].data.transpose(),
- sample_data[i].reshape(-1));
- manual_backward.append(dot);
- manual_backward=np.array(manual_backward);
- manual_backward=manual_backward.reshape(conv_blob.data.shape);
+ sample_data[i].reshape(-1))
+ manual_backward.append(dot)
+ manual_backward=np.array(manual_backward)
+ manual_backward=manual_backward.reshape(conv_blob.data.shape)
- np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3);
+ np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3,atol=1e-5)
def test_clear_param_diffs(self):
# Run a forward/backward step to have non-zero diffs