summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-18 17:14:53 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-19 23:55:22 -0700
commit6b85fd006d87c3af538ce679eaaf0ba6b866765e (patch)
tree58429c3ef97d83bbb1ceb16d6755bec0b3bec9d3 /python
parent50d0b6d9c67d9ca6b062ddea6f5ec30189e61518 (diff)
downloadcaffe-6b85fd006d87c3af538ce679eaaf0ba6b866765e.tar.gz
caffe-6b85fd006d87c3af538ce679eaaf0ba6b866765e.tar.bz2
caffe-6b85fd006d87c3af538ce679eaaf0ba6b866765e.zip
split drawnet into module code and script
Don't run scripts in the module dir to avoid import collisions between io and caffe.io.
Diffstat (limited to 'python')
-rw-r--r--python/caffe/draw.py (renamed from python/caffe/drawnet.py)15
-rwxr-xr-xpython/draw_net.py25
2 files changed, 25 insertions, 15 deletions
diff --git a/python/caffe/drawnet.py b/python/caffe/draw.py
index ff18ecf4..f8631cfa 100644
--- a/python/caffe/drawnet.py
+++ b/python/caffe/draw.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
"""
Caffe network visualization: draw the NetParameter protobuffer.
@@ -10,8 +9,6 @@ Caffe.
from caffe.proto import caffe_pb2
from google.protobuf import text_format
import pydot
-import os
-import sys
# Internal layer and blob styles.
LAYER_STYLE = {'shape': 'record', 'fillcolor': '#6495ED',
@@ -77,15 +74,3 @@ def draw_net_to_file(caffe_net, filename):
ext = filename[filename.rfind('.')+1:]
with open(filename, 'wb') as fid:
fid.write(draw_net(caffe_net, ext))
-
-if __name__ == '__main__':
- if len(sys.argv) != 3:
- print 'Usage: %s input_net_proto_file output_image_file' % \
- os.path.basename(sys.argv[0])
- else:
- net = caffe_pb2.NetParameter()
- text_format.Merge(open(sys.argv[1]).read(), net)
- print 'Drawing net to %s' % sys.argv[2]
- draw_net_to_file(net, sys.argv[2])
-
-
diff --git a/python/draw_net.py b/python/draw_net.py
new file mode 100755
index 00000000..cbea5d9f
--- /dev/null
+++ b/python/draw_net.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+"""
+Draw a graph of the net architecture.
+"""
+import os
+from google.protobuf import text_format
+
+import caffe
+from caffe.proto import caffe_pb2
+
+
+def main(argv):
+ if len(argv) != 3:
+ print 'Usage: %s input_net_proto_file output_image_file' % \
+ os.path.basename(sys.argv[0])
+ else:
+ net = caffe_pb2.NetParameter()
+ text_format.Merge(open(sys.argv[1]).read(), net)
+ print 'Drawing net to %s' % sys.argv[2]
+ draw_net_to_file(net, sys.argv[2])
+
+
+if __name__ == '__main__':
+ import sys
+ main(sys.argv)