summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorGustav Larsson <gustav.m.larsson@gmail.com>2015-10-05 21:55:00 -0500
committerGustav Larsson <gustav.m.larsson@gmail.com>2015-10-05 22:41:01 -0500
commit19d9927d76d6655a3efc090611e59aa2ea0f25a5 (patch)
tree446423eed837b080398b4c118ad5f218c6e6cb3a /python
parentb4f9add57fa468ab43aa40f0a95badf3e9ace243 (diff)
downloadcaffeonacl-19d9927d76d6655a3efc090611e59aa2ea0f25a5.tar.gz
caffeonacl-19d9927d76d6655a3efc090611e59aa2ea0f25a5.tar.bz2
caffeonacl-19d9927d76d6655a3efc090611e59aa2ea0f25a5.zip
Add pycaffe test for solver.snapshot()
Diffstat (limited to 'python')
-rw-r--r--python/caffe/test/test_solver.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/python/caffe/test/test_solver.py b/python/caffe/test/test_solver.py
index 9cfc10d2..f618fded 100644
--- a/python/caffe/test/test_solver.py
+++ b/python/caffe/test/test_solver.py
@@ -16,7 +16,8 @@ class TestSolver(unittest.TestCase):
f.write("""net: '""" + net_f + """'
test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
- display: 100 max_iter: 100 snapshot_after_train: false""")
+ display: 100 max_iter: 100 snapshot_after_train: false
+ snapshot_prefix: "model" """)
f.close()
self.solver = caffe.SGDSolver(f.name)
# also make sure get_solver runs
@@ -51,3 +52,11 @@ class TestSolver(unittest.TestCase):
total += p.data.sum() + p.diff.sum()
for bl in six.itervalues(net.blobs):
total += bl.data.sum() + bl.diff.sum()
+
+ def test_snapshot(self):
+ self.solver.snapshot()
+ # Check that these files exist and then remove them
+ files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate']
+ for fn in files:
+ assert os.path.isfile(fn)
+ os.remove(fn)