summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorYangqing Jia <jiayq84@gmail.com>2015-12-08 17:38:00 -0800
committerYangqing Jia <jiayq84@gmail.com>2015-12-08 17:38:00 -0800
commit03a00e8290df1bb40b5033c61992a3da527c51db (patch)
treea4bcf2173139306c88173b61739a332f7f6508c5 /tools
parent9c9f94e18a8909580a6b94c44dbb1e46f0ee8eb8 (diff)
parent84eb44e6cf9623e09c354a863e201971270ba25b (diff)
downloadcaffeonacl-03a00e8290df1bb40b5033c61992a3da527c51db.tar.gz
caffeonacl-03a00e8290df1bb40b5033c61992a3da527c51db.tar.bz2
caffeonacl-03a00e8290df1bb40b5033c61992a3da527c51db.zip
Merge pull request #3090 from longjon/summarize-tool
A Python script for at-a-glance net summary
Diffstat (limited to 'tools')
-rwxr-xr-xtools/extra/summarize.py140
1 files changed, 140 insertions, 0 deletions
diff --git a/tools/extra/summarize.py b/tools/extra/summarize.py
new file mode 100755
index 00000000..7e2d22fd
--- /dev/null
+++ b/tools/extra/summarize.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+
+"""Net summarization tool.
+
+This tool summarizes the structure of a net in a concise but comprehensive
+tabular listing, taking a prototxt file as input.
+
+Use this tool to check at a glance that the computation you've specified is the
+computation you expect.
+"""
+
+from caffe.proto import caffe_pb2
+from google import protobuf
+import re
+import argparse
+
+# ANSI codes for coloring blobs (used cyclically)
+COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
+ '444', '103;30', '107;30']
+DISCONNECTED_COLOR = '41'
+
+def read_net(filename):
+ net = caffe_pb2.NetParameter()
+ with open(filename) as f:
+ protobuf.text_format.Parse(f.read(), net)
+ return net
+
+def format_param(param):
+ out = []
+ if len(param.name) > 0:
+ out.append(param.name)
+ if param.lr_mult != 1:
+ out.append('x{}'.format(param.lr_mult))
+ if param.decay_mult != 1:
+ out.append('Dx{}'.format(param.decay_mult))
+ return ' '.join(out)
+
+def printed_len(s):
+ return len(re.sub(r'\033\[[\d;]+m', '', s))
+
+def print_table(table, max_width):
+ """Print a simple nicely-aligned table.
+
+ table must be a list of (equal-length) lists. Columns are space-separated,
+ and as narrow as possible, but no wider than max_width. Text may overflow
+ columns; note that unlike string.format, this will not affect subsequent
+ columns, if possible."""
+
+ max_widths = [max_width] * len(table[0])
+ column_widths = [max(printed_len(row[j]) + 1 for row in table)
+ for j in range(len(table[0]))]
+ column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)]
+
+ for row in table:
+ row_str = ''
+ right_col = 0
+ for cell, width in zip(row, column_widths):
+ right_col += width
+ row_str += cell + ' '
+ row_str += ' ' * max(right_col - printed_len(row_str), 0)
+ print row_str
+
+def summarize_net(net):
+ disconnected_tops = set()
+ for lr in net.layer:
+ disconnected_tops |= set(lr.top)
+ disconnected_tops -= set(lr.bottom)
+
+ table = []
+ colors = {}
+ for lr in net.layer:
+ tops = []
+ for ind, top in enumerate(lr.top):
+ color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
+ if top in disconnected_tops:
+ top = '\033[1;4m' + top
+ if len(lr.loss_weight) > 0:
+ top = '{} * {}'.format(lr.loss_weight[ind], top)
+ tops.append('\033[{}m{}\033[0m'.format(color, top))
+ top_str = ', '.join(tops)
+
+ bottoms = []
+ for bottom in lr.bottom:
+ color = colors.get(bottom, DISCONNECTED_COLOR)
+ bottoms.append('\033[{}m{}\033[0m'.format(color, bottom))
+ bottom_str = ', '.join(bottoms)
+
+ if lr.type == 'Python':
+ type_str = lr.python_param.module + '.' + lr.python_param.layer
+ else:
+ type_str = lr.type
+
+ # Summarize conv/pool parameters.
+ # TODO support rectangular/ND parameters
+ conv_param = lr.convolution_param
+ if (lr.type in ['Convolution', 'Deconvolution']
+ and len(conv_param.kernel_size) == 1):
+ arg_str = str(conv_param.kernel_size[0])
+ if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
+ arg_str += '/' + str(conv_param.stride[0])
+ if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
+ arg_str += '+' + str(conv_param.pad[0])
+ arg_str += ' ' + str(conv_param.num_output)
+ if conv_param.group != 1:
+ arg_str += '/' + str(conv_param.group)
+ elif lr.type == 'Pooling':
+ arg_str = str(lr.pooling_param.kernel_size)
+ if lr.pooling_param.stride != 1:
+ arg_str += '/' + str(lr.pooling_param.stride)
+ if lr.pooling_param.pad != 0:
+ arg_str += '+' + str(lr.pooling_param.pad)
+ else:
+ arg_str = ''
+
+ if len(lr.param) > 0:
+ param_strs = map(format_param, lr.param)
+ if max(map(len, param_strs)) > 0:
+ param_str = '({})'.format(', '.join(param_strs))
+ else:
+ param_str = ''
+ else:
+ param_str = ''
+
+ table.append([lr.name, type_str, param_str, bottom_str, '->', top_str,
+ arg_str])
+ return table
+
+def main():
+ parser = argparse.ArgumentParser(description="Print a concise summary of net computation.")
+ parser.add_argument('filename', help='net prototxt file to summarize')
+ parser.add_argument('-w', '--max-width', help='maximum field width',
+ type=int, default=30)
+ args = parser.parse_args()
+
+ net = read_net(args.filename)
+ table = summarize_net(net)
+ print_table(table, max_width=args.max_width)
+
+if __name__ == '__main__':
+ main()