diff options
author | Sergey Karayev <sergeykarayev@gmail.com> | 2014-08-12 17:29:25 -0700 |
---|---|---|
committer | Sergey Karayev <sergeykarayev@gmail.com> | 2014-09-04 01:53:18 +0100 |
commit | e553573e2c4800e11050d6b83f0579766ebf4648 (patch) | |
tree | feacf4a98bb26deafd06ff644c2d1ddeacd6b9eb /scripts | |
parent | 41751046f18499b84dbaf529f64c0e664e2a09fe (diff) | |
download | caffeonacl-e553573e2c4800e11050d6b83f0579766ebf4648.tar.gz caffeonacl-e553573e2c4800e11050d6b83f0579766ebf4648.tar.bz2 caffeonacl-e553573e2c4800e11050d6b83f0579766ebf4648.zip |
[models] adding zoo readme; caffenet, alexnet, and rcnn models in zoo format
Diffstat (limited to 'scripts')
-rwxr-xr-x | scripts/download_model_binary.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/scripts/download_model_binary.py b/scripts/download_model_binary.py new file mode 100755 index 00000000..48e9015f --- /dev/null +++ b/scripts/download_model_binary.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +import os +import sys +import time +import yaml +import urllib +import hashlib +import argparse + +required_keys = ['caffemodel', 'caffemodel_url', 'sha1'] + + +def reporthook(count, block_size, total_size): + """ + From http://blog.moleculea.com/2012/10/04/urlretrieve-progres-indicator/ + """ + global start_time + if count == 0: + start_time = time.time() + return + duration = time.time() - start_time + progress_size = int(count * block_size) + speed = int(progress_size / (1024 * duration)) + percent = int(count * block_size * 100 / total_size) + sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % + (percent, progress_size / (1024 * 1024), speed, duration)) + sys.stdout.flush() + + +def parse_readme_frontmatter(dirname): + readme_filename = os.path.join(dirname, 'readme.md') + with open(readme_filename) as f: + lines = [line.strip() for line in f.readlines()] + top = lines.index('---') + bottom = lines[top + 1:].index('---') + frontmatter = yaml.load('\n'.join(lines[top + 1:bottom])) + assert all(key in frontmatter for key in required_keys) + return dirname, frontmatter + + +def valid_dirname(dirname): + try: + return parse_readme_frontmatter(dirname) + except Exception as e: + print('ERROR: {}'.format(e)) + raise argparse.ArgumentTypeError( + 'Must be valid Caffe model directory with a correct readme.md') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Download trained model binary.') + parser.add_argument('dirname', type=valid_dirname) + args = parser.parse_args() + + # A tiny hack: the dirname validator also returns readme YAML frontmatter. + dirname = args.dirname[0] + frontmatter = args.dirname[1] + model_filename = os.path.join(dirname, frontmatter['caffemodel']) + + # Closure-d function for checking SHA1. + def model_checks_out(filename=model_filename, sha1=frontmatter['sha1']): + with open(filename, 'r') as f: + return hashlib.sha1(f.read()).hexdigest() == sha1 + + # Check if model exists. + if os.path.exists(model_filename) and model_checks_out(): + print("Model already exists.") + sys.exit(0) + + # Download and verify model. + urllib.urlretrieve( + frontmatter['caffemodel_url'], model_filename, reporthook) + if not model_checks_out(): + print('ERROR: model did not download correctly! Run this again.') + sys.exit(1) |