diff options
author | Gustav Larsson <gustav.m.larsson@gmail.com> | 2015-10-05 21:55:00 -0500 |
---|---|---|
committer | Gustav Larsson <gustav.m.larsson@gmail.com> | 2015-10-05 22:41:01 -0500 |
commit | 19d9927d76d6655a3efc090611e59aa2ea0f25a5 (patch) | |
tree | 446423eed837b080398b4c118ad5f218c6e6cb3a /python | |
parent | b4f9add57fa468ab43aa40f0a95badf3e9ace243 (diff) | |
download | caffeonacl-19d9927d76d6655a3efc090611e59aa2ea0f25a5.tar.gz caffeonacl-19d9927d76d6655a3efc090611e59aa2ea0f25a5.tar.bz2 caffeonacl-19d9927d76d6655a3efc090611e59aa2ea0f25a5.zip |
Add pycaffe test for solver.snapshot()
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/test/test_solver.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/python/caffe/test/test_solver.py b/python/caffe/test/test_solver.py index 9cfc10d2..f618fded 100644 --- a/python/caffe/test/test_solver.py +++ b/python/caffe/test/test_solver.py @@ -16,7 +16,8 @@ class TestSolver(unittest.TestCase): f.write("""net: '""" + net_f + """' test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9 weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75 - display: 100 max_iter: 100 snapshot_after_train: false""") + display: 100 max_iter: 100 snapshot_after_train: false + snapshot_prefix: "model" """) f.close() self.solver = caffe.SGDSolver(f.name) # also make sure get_solver runs @@ -51,3 +52,11 @@ class TestSolver(unittest.TestCase): total += p.data.sum() + p.diff.sum() for bl in six.itervalues(net.blobs): total += bl.data.sum() + bl.diff.sum() + + def test_snapshot(self): + self.solver.snapshot() + # Check that these files exist and then remove them + files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate'] + for fn in files: + assert os.path.isfile(fn) + os.remove(fn) |