summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorValentin Tolmer <valentin.tolmer@gmail.com>2016-06-21 17:12:57 -0700
committerWook Song <wook16.song@samsung.com>2020-01-23 22:50:45 +0900
commit1c614b3304a5c38aa6d2b67bf11c4ca696741951 (patch)
tree456d7a21a6c8531b6361d95be307706ca9a3c55e
parent3c4e2c548a9db09381607da85a4bbde990ca7f2b (diff)
downloadcaffe-1c614b3304a5c38aa6d2b67bf11c4ca696741951.tar.gz
caffe-1c614b3304a5c38aa6d2b67bf11c4ca696741951.tar.bz2
caffe-1c614b3304a5c38aa6d2b67bf11c4ca696741951.zip
[pycaffe] test solver update
-rw-r--r--python/caffe/test/test_solver.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/python/caffe/test/test_solver.py b/python/caffe/test/test_solver.py
index f618fded..50c9d541 100644
--- a/python/caffe/test/test_solver.py
+++ b/python/caffe/test/test_solver.py
@@ -38,6 +38,17 @@ class TestSolver(unittest.TestCase):
self.solver.solve()
self.assertEqual(self.solver.iter, 100)
+ def test_apply_update(self):
+ net = self.solver.net
+ data = net.layers[1].blobs[0].data[...]
+ # Reset the weights of that layer to 0
+ data[...] = 0
+ net.layers[1].blobs[0].diff[...] = 1
+ # Apply the update, the initial learning rate should be 0.01
+ self.solver.apply_update()
+ # Check that the new weights are -0.01, with a precision of 1e-7
+ self.assertTrue((data - -0.01 * np.ones(data.shape)).max() < 1e-7)
+
def test_net_memory(self):
"""Check that nets survive after the solver is destroyed."""