summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorKai Li <kaili_kloud@163.com>2014-02-11 10:54:15 +0800
committerKai Li <kaili_kloud@163.com>2014-02-11 11:01:35 +0800
commit0e41d0871a5fc90790b9fc13e90a7db7194ce7d5 (patch)
tree3b8adaa3ac32c208f81c800ad82a1ab31b519392 /scripts
parent8112ed98c73d614fcb0b760feab4e5d3fc3ecb85 (diff)
downloadcaffeonacl-0e41d0871a5fc90790b9fc13e90a7db7194ce7d5.tar.gz
caffeonacl-0e41d0871a5fc90790b9fc13e90a7db7194ce7d5.tar.bz2
caffeonacl-0e41d0871a5fc90790b9fc13e90a7db7194ce7d5.zip
Add python matplotlib example to plot the training log
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/plot_training_log.py.example185
1 files changed, 185 insertions, 0 deletions
diff --git a/scripts/plot_training_log.py.example b/scripts/plot_training_log.py.example
new file mode 100755
index 00000000..b40cbc77
--- /dev/null
+++ b/scripts/plot_training_log.py.example
@@ -0,0 +1,185 @@
+#!/usr/bin/env python
+import os
+import random
+import sys
+import matplotlib.cm as cmx
+import matplotlib.colors as colors
+import matplotlib.pyplot as plt
+import matplotlib.legend as lgd
+import matplotlib.markers as mks
+
+def get_log_parsing_script():
+ return './parselog.sh'
+
+def get_log_file_suffix():
+ return '.log'
+
+def get_chart_type_description_separator():
+ return ' vs. '
+
+def is_x_axis_field(field):
+ x_axis_fields = ['Iters', 'Seconds']
+ return field in x_axis_fields
+
+def create_field_index():
+ train_key = 'Training'
+ test_key = 'Test'
+ field_index = {train_key:{'Iters':0, 'Seconds':1, train_key + ' loss':2,
+ train_key + ' learning rate':3},
+ test_key:{'Iters':0, 'Seconds':1, test_key + ' accuracy':2,
+ test_key + ' loss':3}}
+ fields = set()
+ for data_file_type in field_index.keys():
+ fields = fields.union(set(field_index[data_file_type].keys()))
+ fields = list(fields)
+ fields.sort()
+ return field_index, fields
+
+def get_supported_chart_types():
+ field_index, fields = create_field_index()
+ num_fields = len(fields)
+ supported_chart_types = []
+ for i in xrange(num_fields):
+ if not is_x_axis_field(fields[i]):
+ for j in xrange(num_fields):
+ if i != j and is_x_axis_field(fields[j]):
+ supported_chart_types.append('%s%s%s' % (
+ fields[i], get_chart_type_description_separator(),
+ fields[j]))
+ return supported_chart_types
+
+def get_chart_type_description(chart_type):
+ supported_chart_types = get_supported_chart_types()
+ chart_type_description = supported_chart_types[chart_type]
+ return chart_type_description
+
+def get_data_file_type(chart_type):
+ description = get_chart_type_description(chart_type)
+ data_file_type = description.split()[0]
+ return data_file_type
+
+def get_data_file(chart_type, path_to_log):
+ return path_to_log + '.' + get_data_file_type(chart_type).lower()
+
+def get_field_descriptions(chart_type):
+ description = get_chart_type_description(chart_type).split(
+ get_chart_type_description_separator())
+ y_axis_field = description[0]
+ x_axis_field = description[1]
+ return x_axis_field, y_axis_field
+
+def get_field_indecies(x_axis_field, y_axis_field):
+ data_file_type = get_data_file_type(chart_type)
+ fields = create_field_index()[0][data_file_type]
+ return fields[x_axis_field], fields[y_axis_field]
+
+def load_data(data_file, field_idx0, field_idx1):
+ data = [[], []]
+ with open(data_file, 'r') as f:
+ for line in f:
+ line = line.strip()
+ if line[0] != '#':
+ fields = line.split()
+ data[0].append(float(fields[field_idx0].strip()))
+ data[1].append(float(fields[field_idx1].strip()))
+ return data
+
+def random_marker():
+ markers = mks.MarkerStyle.markers
+ num = len(markers.values())
+ idx = random.randint(0, num - 1)
+ return markers.values()[idx]
+
+def get_data_label(path_to_log):
+ label = path_to_log[path_to_log.rfind('/')+1 : path_to_log.rfind(
+ get_log_file_suffix())]
+ return label
+
+def get_legend_loc(chart_type):
+ x_axis, y_axis = get_field_descriptions(chart_type)
+ loc = 'lower right'
+ if y_axis.find('accuracy') != -1:
+ pass
+ if y_axis.find('loss') != -1 or y_axis.find('learning rate') != -1:
+ loc = 'upper right'
+ return loc
+
+def plot_chart(chart_type, path_to_png, path_to_log_list):
+ for path_to_log in path_to_log_list:
+ os.system('%s %s' % (get_log_parsing_script(), path_to_log))
+ data_file = get_data_file(chart_type, path_to_log)
+ x_axis_field, y_axis_field = get_field_descriptions(chart_type)
+ x, y = get_field_indecies(x_axis_field, y_axis_field)
+ data = load_data(data_file, x, y)
+ ## TODO: more systematic color cycle for lines
+ color = [random.random(), random.random(), random.random()]
+ label = get_data_label(path_to_log)
+ linewidth = 0.75
+ ## If there too many datapoints, do not use marker.
+## use_marker = False
+ use_marker = True
+ if not use_marker:
+ plt.plot(data[0], data[1], label = label, color = color,
+ linewidth = linewidth)
+ else:
+ ok = False
+ ## Some markers throw ValueError: Unrecognized marker style
+ while not ok:
+ try:
+ marker = random_marker()
+ plt.plot(data[0], data[1], label = label, color = color,
+ marker = marker, linewidth = linewidth)
+ ok = True
+ except:
+ pass
+ legend_loc = get_legend_loc(chart_type)
+ plt.legend(loc = legend_loc, ncol = 1) # ajust ncol to fit the space
+ plt.title(get_chart_type_description(chart_type))
+ plt.xlabel(x_axis_field)
+ plt.ylabel(y_axis_field)
+ plt.savefig(path_to_png)
+ plt.show()
+
+def print_help():
+ print """This script mainly serves as the basis of your customizations.
+Customization is a must.
+You can copy, paste, edit them in whatever way you want.
+Be warned that the fields in the training log may change in the future.
+You had better check the data files and change the mapping from field name to
+ field index in create_field_index before designing your own plots.
+Usage:
+ ./plot_log.sh chart_type[0-%s] /where/to/save.png /path/to/first.log ...
+Notes:
+ 1. Supporting multiple logs.
+ 2. Log file name must end with the lower-cased "%s".
+Supported chart types:""" % (len(get_supported_chart_types()) - 1,
+ get_log_file_suffix())
+ supported_chart_types = get_supported_chart_types()
+ num = len(supported_chart_types)
+ for i in xrange(num):
+ print ' %d: %s' % (i, supported_chart_types[i])
+ exit
+
+def is_valid_chart_type(chart_type):
+ return chart_type >= 0 and chart_type < len(get_supported_chart_types())
+
+if __name__ == '__main__':
+ if len(sys.argv) < 4:
+ print_help()
+ else:
+ chart_type = int(sys.argv[1])
+ if not is_valid_chart_type(chart_type):
+ print_help()
+ path_to_png = sys.argv[2]
+ if not path_to_png.endswith('.png'):
+ print 'Path must ends with png' % path_to_png
+ exit
+ path_to_logs = sys.argv[3:]
+ for path_to_log in path_to_logs:
+ if not os.path.exists(path_to_log):
+ print 'Path does not exist: %s' % path_to_log
+ exit
+ if not path_to_log.endswith(get_log_file_suffix()):
+ print_help()
+ ## plot_chart accpets multiple path_to_logs
+ plot_chart(chart_type, path_to_png, path_to_logs)