"""Download the MNIST dataset archives and unpack them next to each other.

Each entry of ``RESOURCES`` is fetched from the first entry of ``MIRRORS``
that serves it, then gunzipped into a file with the ``.gz`` suffix removed.
Files that already exist on disk are left untouched.
"""

import argparse
import gzip
import os
import shutil
from urllib.error import URLError
from urllib.request import urlretrieve
import sys

# Base URLs, tried in order; each must end with '/' because the resource
# name is appended by plain string concatenation.
MIRRORS = [
    'http://yann.lecun.com/exdb/mnist/',
    'https://ossci-datasets.s3.amazonaws.com/mnist/',
]

# Archive file names appended to a mirror URL and used as local file names.
RESOURCES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]

# Suffix for files that are still being written. Work happens on
# ``<target> + _PARTIAL_SUFFIX`` and is renamed onto the target only once
# complete, so an aborted run never leaves a truncated target behind.
_PARTIAL_SUFFIX = '.part'


def _discard(path: str) -> None:
    """Remove *path* if present; cleanup is best-effort and never raises OSError."""
    try:
        os.remove(path)
    except OSError:
        # Nothing to clean up, or the file cannot be removed; either way the
        # error that triggered the cleanup is the one worth reporting.
        pass


def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a 64-column text progress bar on stdout.

    The signature matches the ``reporthook`` callback of ``urlretrieve``.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total size in bytes. Non-positive values (``-1`` for an
            unknown size, ``0`` for an empty file) produce no output.
    """
    # A total of 0 would divide by zero and -1 means "unknown": no bar.
    if file_size <= 0:
        return
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = '#' * int(64 * percent)
    sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
    # The bar contains no newline, so flush explicitly to make it visible.
    sys.stdout.flush()


def download(destination_path: str, resource: str, quiet: bool) -> None:
    """Fetch *resource* into *destination_path*, trying each mirror in order.

    The transfer is written to ``destination_path + '.part'`` and renamed to
    *destination_path* only after it completes; the partial file is removed
    on any failure, including an interrupt.

    Args:
        destination_path: Local path of the downloaded file.
        resource: File name appended to each mirror URL.
        quiet: If true, suppress the progress bar and the "already exists"
            message.

    Raises:
        RuntimeError: If every mirror fails with ``URLError`` or
            ``ConnectionError``.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print('{} already exists, skipping ...'.format(destination_path))
        return

    hook = None if quiet else report_download_progress
    partial_path = destination_path + _PARTIAL_SUFFIX
    for mirror in MIRRORS:
        url = mirror + resource
        print('Downloading {} ...'.format(url))
        completed = False
        try:
            urlretrieve(url, partial_path, reporthook=hook)
            completed = True
        except (URLError, ConnectionError) as e:
            print('Failed to download (trying next):\n{}'.format(e))
            continue
        finally:
            if not completed:
                _discard(partial_path)
            if not quiet:
                # Just a newline.
                print()
        os.replace(partial_path, destination_path)
        return

    raise RuntimeError('Error downloading resource!')


def unzip(zipped_path: str, quiet: bool) -> None:
    """Gunzip *zipped_path* into a sibling file without its last extension.

    The output is streamed to a ``.part`` file and renamed when complete, so
    a corrupt archive or an interrupt leaves no output file behind.

    Args:
        zipped_path: Path of the gzip archive.
        quiet: If true, suppress status messages.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print('{} already exists, skipping ... '.format(unzipped_path))
        return

    partial_path = unzipped_path + _PARTIAL_SUFFIX
    try:
        with gzip.open(zipped_path, 'rb') as zipped_file, \
                open(partial_path, 'wb') as unzipped_file:
            # Copy in chunks instead of holding the whole payload in memory.
            shutil.copyfileobj(zipped_file, unzipped_file)
    except BaseException:
        _discard(partial_path)
        raise
    os.replace(partial_path, unzipped_path)

    if not quiet:
        print('Unzipped {} ...'.format(zipped_path))


def main() -> None:
    """Parse the command line, then download and unzip every resource."""
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help="Don't report about progress")
    options = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print('Interrupted')


if __name__ == '__main__':
    main()