author     Sergey Karayev <sergeykarayev@gmail.com>  2014-08-12 17:29:25 -0700
committer  Sergey Karayev <sergeykarayev@gmail.com>  2014-09-04 01:53:18 +0100
commit     e553573e2c4800e11050d6b83f0579766ebf4648 (patch)
tree       feacf4a98bb26deafd06ff644c2d1ddeacd6b9eb /scripts/download_model_binary.py
parent     41751046f18499b84dbaf529f64c0e664e2a09fe (diff)
download   caffeonacl-e553573e2c4800e11050d6b83f0579766ebf4648.tar.gz
           caffeonacl-e553573e2c4800e11050d6b83f0579766ebf4648.tar.bz2
           caffeonacl-e553573e2c4800e11050d6b83f0579766ebf4648.zip
[models] adding zoo readme; caffenet, alexnet, and rcnn models in zoo format
Diffstat (limited to 'scripts/download_model_binary.py')
-rwxr-xr-x  scripts/download_model_binary.py  76
1 file changed, 76 insertions, 0 deletions
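
Note on the readme format this script consumes: each model directory's readme.md is expected to open with a '---'-delimited YAML frontmatter block carrying at least the keys listed in required_keys (caffemodel, caffemodel_url, sha1). A minimal, self-contained sketch of that parsing, with placeholder values rather than a real model entry, assuming PyYAML is installed:

import yaml

# Hypothetical readme.md content; the URL and SHA-1 are illustrative placeholders.
sample_readme = """\
---
caffemodel: bvlc_reference_caffenet.caffemodel
caffemodel_url: http://example.com/bvlc_reference_caffenet.caffemodel
sha1: 0123456789abcdef0123456789abcdef01234567
---
Model description follows the frontmatter.
"""

lines = [line.strip() for line in sample_readme.splitlines()]
top = lines.index('---')
bottom = lines.index('---', top + 1)
frontmatter = yaml.safe_load('\n'.join(lines[top + 1:bottom]))
print(frontmatter)  # {'caffemodel': ..., 'caffemodel_url': ..., 'sha1': ...}
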
diff --git a/scripts/download_model_binary.py b/scripts/download_model_binary.py
new file mode 100755
index 00000000..48e9015f
--- /dev/null
+++ b/scripts/download_model_binary.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python
+import os
+import sys
+import time
+import yaml
+import urllib
+import hashlib
+import argparse
+
+required_keys = ['caffemodel', 'caffemodel_url', 'sha1']
+
+
+def reporthook(count, block_size, total_size):
+ """
+ From http://blog.moleculea.com/2012/10/04/urlretrieve-progres-indicator/
+ """
+ global start_time
+ if count == 0:
+ start_time = time.time()
+ return
+ duration = time.time() - start_time
+ progress_size = int(count * block_size)
+ speed = int(progress_size / (1024 * duration))
+ percent = int(count * block_size * 100 / total_size)
+ sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
+ (percent, progress_size / (1024 * 1024), speed, duration))
+ sys.stdout.flush()
+
+
+def parse_readme_frontmatter(dirname):
+    readme_filename = os.path.join(dirname, 'readme.md')
+    with open(readme_filename) as f:
+        lines = [line.strip() for line in f.readlines()]
+    top = lines.index('---')
+    bottom = lines.index('---', top + 1)
+    frontmatter = yaml.load('\n'.join(lines[top + 1:bottom]))
+    assert all(key in frontmatter for key in required_keys)
+    return dirname, frontmatter
+
+
+def valid_dirname(dirname):
+    try:
+        return parse_readme_frontmatter(dirname)
+    except Exception as e:
+        print('ERROR: {}'.format(e))
+        raise argparse.ArgumentTypeError(
+            'Must be valid Caffe model directory with a correct readme.md')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Download trained model binary.')
+    parser.add_argument('dirname', type=valid_dirname)
+    args = parser.parse_args()
+
+    # A tiny hack: the dirname validator also returns readme YAML frontmatter.
+    dirname = args.dirname[0]
+    frontmatter = args.dirname[1]
+    model_filename = os.path.join(dirname, frontmatter['caffemodel'])
+
+    # Closure-d function for checking SHA1.
+    def model_checks_out(filename=model_filename, sha1=frontmatter['sha1']):
+        with open(filename, 'rb') as f:
+            return hashlib.sha1(f.read()).hexdigest() == sha1
+
+    # Check if model exists.
+    if os.path.exists(model_filename) and model_checks_out():
+        print("Model already exists.")
+        sys.exit(0)
+
+    # Download and verify model.
+    urllib.urlretrieve(
+        frontmatter['caffemodel_url'], model_filename, reporthook)
+    if not model_checks_out():
+        print('ERROR: model did not download correctly! Run this again.')
+        sys.exit(1)
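
A typical invocation, assuming a model directory laid out in the zoo format described by the commit message, would be along the lines of:

    python scripts/download_model_binary.py models/bvlc_reference_caffenet

The script as committed targets Python 2 (urllib.urlretrieve, bare yaml.load, a whole-file read for hashing). A hedged sketch of the Python 3 counterparts, not part of this commit (the helper names sha1_of and download are hypothetical):

import hashlib
import urllib.request

def sha1_of(filename):
    # Hash the .caffemodel in binary mode, streaming 1 MB chunks so the whole
    # file never has to sit in memory.
    h = hashlib.sha1()
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            h.update(chunk)
    return h.hexdigest()

def download(url, filename, reporthook=None):
    # urllib.request.urlretrieve takes the same (url, filename, reporthook)
    # arguments as the Python 2 urllib.urlretrieve call used above.
    urllib.request.urlretrieve(url, filename, reporthook)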