diff options
author | Yangqing Jia <jiayq84@gmail.com> | 2015-12-08 17:38:00 -0800 |
---|---|---|
committer | Yangqing Jia <jiayq84@gmail.com> | 2015-12-08 17:38:00 -0800 |
commit | 03a00e8290df1bb40b5033c61992a3da527c51db (patch) | |
tree | a4bcf2173139306c88173b61739a332f7f6508c5 /tools | |
parent | 9c9f94e18a8909580a6b94c44dbb1e46f0ee8eb8 (diff) | |
parent | 84eb44e6cf9623e09c354a863e201971270ba25b (diff) | |
download | caffeonacl-03a00e8290df1bb40b5033c61992a3da527c51db.tar.gz caffeonacl-03a00e8290df1bb40b5033c61992a3da527c51db.tar.bz2 caffeonacl-03a00e8290df1bb40b5033c61992a3da527c51db.zip |
Merge pull request #3090 from longjon/summarize-tool
A Python script for at-a-glance net summary
Diffstat (limited to 'tools')
-rwxr-xr-x | tools/extra/summarize.py | 140 |
1 files changed, 140 insertions, 0 deletions
diff --git a/tools/extra/summarize.py b/tools/extra/summarize.py new file mode 100755 index 00000000..7e2d22fd --- /dev/null +++ b/tools/extra/summarize.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python + +"""Net summarization tool. + +This tool summarizes the structure of a net in a concise but comprehensive +tabular listing, taking a prototxt file as input. + +Use this tool to check at a glance that the computation you've specified is the +computation you expect. +""" + +from caffe.proto import caffe_pb2 +from google import protobuf +import re +import argparse + +# ANSI codes for coloring blobs (used cyclically) +COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100', + '444', '103;30', '107;30'] +DISCONNECTED_COLOR = '41' + +def read_net(filename): + net = caffe_pb2.NetParameter() + with open(filename) as f: + protobuf.text_format.Parse(f.read(), net) + return net + +def format_param(param): + out = [] + if len(param.name) > 0: + out.append(param.name) + if param.lr_mult != 1: + out.append('x{}'.format(param.lr_mult)) + if param.decay_mult != 1: + out.append('Dx{}'.format(param.decay_mult)) + return ' '.join(out) + +def printed_len(s): + return len(re.sub(r'\033\[[\d;]+m', '', s)) + +def print_table(table, max_width): + """Print a simple nicely-aligned table. + + table must be a list of (equal-length) lists. Columns are space-separated, + and as narrow as possible, but no wider than max_width. Text may overflow + columns; note that unlike string.format, this will not affect subsequent + columns, if possible.""" + + max_widths = [max_width] * len(table[0]) + column_widths = [max(printed_len(row[j]) + 1 for row in table) + for j in range(len(table[0]))] + column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)] + + for row in table: + row_str = '' + right_col = 0 + for cell, width in zip(row, column_widths): + right_col += width + row_str += cell + ' ' + row_str += ' ' * max(right_col - printed_len(row_str), 0) + print row_str + +def summarize_net(net): + disconnected_tops = set() + for lr in net.layer: + disconnected_tops |= set(lr.top) + disconnected_tops -= set(lr.bottom) + + table = [] + colors = {} + for lr in net.layer: + tops = [] + for ind, top in enumerate(lr.top): + color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)]) + if top in disconnected_tops: + top = '\033[1;4m' + top + if len(lr.loss_weight) > 0: + top = '{} * {}'.format(lr.loss_weight[ind], top) + tops.append('\033[{}m{}\033[0m'.format(color, top)) + top_str = ', '.join(tops) + + bottoms = [] + for bottom in lr.bottom: + color = colors.get(bottom, DISCONNECTED_COLOR) + bottoms.append('\033[{}m{}\033[0m'.format(color, bottom)) + bottom_str = ', '.join(bottoms) + + if lr.type == 'Python': + type_str = lr.python_param.module + '.' + lr.python_param.layer + else: + type_str = lr.type + + # Summarize conv/pool parameters. + # TODO support rectangular/ND parameters + conv_param = lr.convolution_param + if (lr.type in ['Convolution', 'Deconvolution'] + and len(conv_param.kernel_size) == 1): + arg_str = str(conv_param.kernel_size[0]) + if len(conv_param.stride) > 0 and conv_param.stride[0] != 1: + arg_str += '/' + str(conv_param.stride[0]) + if len(conv_param.pad) > 0 and conv_param.pad[0] != 0: + arg_str += '+' + str(conv_param.pad[0]) + arg_str += ' ' + str(conv_param.num_output) + if conv_param.group != 1: + arg_str += '/' + str(conv_param.group) + elif lr.type == 'Pooling': + arg_str = str(lr.pooling_param.kernel_size) + if lr.pooling_param.stride != 1: + arg_str += '/' + str(lr.pooling_param.stride) + if lr.pooling_param.pad != 0: + arg_str += '+' + str(lr.pooling_param.pad) + else: + arg_str = '' + + if len(lr.param) > 0: + param_strs = map(format_param, lr.param) + if max(map(len, param_strs)) > 0: + param_str = '({})'.format(', '.join(param_strs)) + else: + param_str = '' + else: + param_str = '' + + table.append([lr.name, type_str, param_str, bottom_str, '->', top_str, + arg_str]) + return table + +def main(): + parser = argparse.ArgumentParser(description="Print a concise summary of net computation.") + parser.add_argument('filename', help='net prototxt file to summarize') + parser.add_argument('-w', '--max-width', help='maximum field width', + type=int, default=30) + args = parser.parse_args() + + net = read_net(args.filename) + table = summarize_net(net) + print_table(table, max_width=args.max_width) + +if __name__ == '__main__': + main() |