summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorZhiHeng NIU <niuzhiheng@gmail.com>2014-04-25 17:43:56 +0800
committerZhiHeng NIU <niuzhiheng@gmail.com>2014-04-25 17:43:56 +0800
commit65015e37128feb3ba561b96361ee9a5c509467a6 (patch)
treeb9759d33ba64105f482c91ad42481d44724817a6 /python
parentc55ebd09811aea2d9c36295ced6aa280079d308a (diff)
downloadcaffe-65015e37128feb3ba561b96361ee9a5c509467a6.tar.gz
caffe-65015e37128feb3ba561b96361ee9a5c509467a6.tar.bz2
caffe-65015e37128feb3ba561b96361ee9a5c509467a6.zip
Update the drawnet.py to reflect the recent revised net definition.
Diffstat (limited to 'python')
-rw-r--r--python/caffe/drawnet.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/python/caffe/drawnet.py b/python/caffe/drawnet.py
index 8ff0d83f..de5a8760 100644
--- a/python/caffe/drawnet.py
+++ b/python/caffe/drawnet.py
@@ -15,14 +15,21 @@ NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90',
'style': 'filled'}
BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C',
'style': 'filled'}
+def get_enum_name_by_value():
+ desc = caffe_pb2.LayerParameter.LayerType.DESCRIPTOR
+ d = {}
+ for k,v in desc.values_by_name.items():
+ d[v.number] = k
+ return d
def get_pydot_graph(caffe_net):
- pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph')
+ pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph', rankdir="BT")
pydot_nodes = {}
pydot_edges = []
+ d = get_enum_name_by_value()
for layer in caffe_net.layers:
- name = layer.layer.name
- layertype = layer.layer.type
+ name = layer.name
+ layertype = d[layer.type]
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
layer.bottom[0] == layer.top[0]):
# We have an in-place neuron layer.
@@ -63,7 +70,7 @@ def draw_net_to_file(caffe_net, filename):
to graphviz to draw graphs.
"""
ext = filename[filename.rfind('.')+1:]
- with open(filename, 'w') as fid:
+ with open(filename, 'wb') as fid:
fid.write(draw_net(caffe_net, ext))
if __name__ == '__main__':