diff options
author | Yinghai Lu <yinghai@fb.com> | 2019-02-12 14:43:44 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-12 15:02:50 -0800 |
commit | f435fb8290aeb00fbfa0191d5c42c59c5a772623 (patch) | |
tree | c6fd7dafb2a2d4d3cf598ae74596223ef93ec5ea /caffe2 | |
parent | 65b49b46966abc1fd9b7ad6668a994e9a669be96 (diff) | |
download | pytorch-f435fb8290aeb00fbfa0191d5c42c59c5a772623.tar.gz pytorch-f435fb8290aeb00fbfa0191d5c42c59c5a772623.tar.bz2 pytorch-f435fb8290aeb00fbfa0191d5c42c59c5a772623.zip |
Allow customization of blob node in net_drawer (#16915)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16915
TSIA
Reviewed By: ipiszy
Differential Revision: D14018010
fbshipit-source-id: df5ccc06fa37f08e7a02a8acc466c4ad47afe04e
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h | 8 | ||||
-rw-r--r-- | caffe2/opt/backend_cutting.cc | 7 | ||||
-rw-r--r-- | caffe2/opt/backend_cutting.h | 3 | ||||
-rw-r--r-- | caffe2/opt/onnxifi_transformer.cc | 9 | ||||
-rw-r--r-- | caffe2/python/net_drawer.py | 30 |
5 files changed, 34 insertions, 23 deletions
diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h b/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h index e279378304..2a2800401c 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h @@ -38,8 +38,7 @@ class DotGenerator { const typename GraphT::SubgraphType& sg, const std::vector<typename GraphT::SubgraphType*>& subgraphs) const { std::ostringstream output; - output << "digraph G {\n\ - "; + output << "digraph G {\nrankdir=LR\n"; for (const auto& node : sg.getNodes()) { generateNode(node, sg, output); } @@ -60,8 +59,7 @@ class DotGenerator { // Convert a subgraph to dot. std::string convert(const typename GraphT::SubgraphType& sg) const { std::ostringstream output; - output << "digraph G {\n\ - "; + output << "digraph G {\nrankdir=LR\n"; for (const auto& node : sg.getNodes()) { generateNode(node, sg, output); } @@ -82,7 +80,7 @@ class DotGenerator { */ std::string convertStruct(const typename GraphT::SubgraphType& sg) const { std::ostringstream output; - output << "digraph G {\n"; + output << "digraph G {\nrankdir=LR\n"; // Get input nodes (nodes w/o parents) std::unordered_map<typename GraphT::NodeRef, int> diff --git a/caffe2/opt/backend_cutting.cc b/caffe2/opt/backend_cutting.cc index 5715e5a922..c4dd792470 100644 --- a/caffe2/opt/backend_cutting.cc +++ b/caffe2/opt/backend_cutting.cc @@ -346,7 +346,8 @@ void PruneUnrefereredNodes(NNModule* nn) { caffe2::NetDef OptimizeForBackend( caffe2::NetDef& net, std::function<bool(const caffe2::OperatorDef&)> supports, - std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func) { + std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func, + bool debug) { auto nn = convertToNNModule(net); auto& dfg = nn.dataFlow; @@ -413,6 +414,10 @@ caffe2::NetDef OptimizeForBackend( // absorbed PruneUnrefereredNodes(&nn); + if (debug) { + DumpGraph(&dfg); + } + auto new_net = convertToCaffe2Proto(nn); new_net.set_name(net.name() + "_opt"); return new_net; diff --git a/caffe2/opt/backend_cutting.h b/caffe2/opt/backend_cutting.h index 8ea1413dc1..cf98c11ddc 100644 --- a/caffe2/opt/backend_cutting.h +++ b/caffe2/opt/backend_cutting.h @@ -12,6 +12,7 @@ namespace opt { CAFFE2_API caffe2::NetDef OptimizeForBackend( caffe2::NetDef& net, std::function<bool(const caffe2::OperatorDef&)> supports, - std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func); + std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func, + bool debug = false); } } // namespace caffe2 diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 64e801cf9c..a1e089666a 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -285,10 +285,10 @@ int64_t GetBlob1stDimSize( return shape_info.shape.dims(0); } -// Generates AdjustBatchOps for external inputs / outputs with type BATCH or +// Generates AdjustBatchOps for external inputs/outputs with type BATCH or // SEQ and adds them to input_ops and output_ops. -// Meanwhile, modifies inputs / outputs of corresponding operators in the -// onnxifi_net to use the new inputs / outputs of AdjustBatchOps. +// Meanwhile, modifies inputs/outputs of corresponding operators in the +// onnxifi_net to use the new inputs/outputs of AdjustBatchOps. std::unordered_map<std::string, std::string> AddAdjustBatchOps( const ShapeInfoMap& shape_hints, NetDef* onnxifi_net, @@ -979,7 +979,8 @@ NetDef OnnxifiTransformer::TransformViaOnnx( net, weights, ws, &exporter2, shape_hints, &shape_hints_onnx); }; - return opt::OptimizeForBackend(*pred_net, onnx_supports, onnx_converter); + return opt::OptimizeForBackend( + *pred_net, onnx_supports, onnx_converter, opts_.debug); } // Cutting off the runnable part and replace with ONNXIFI ops. Asssume the nets diff --git a/caffe2/python/net_drawer.py b/caffe2/python/net_drawer.py index ee124bc6c8..17f6c4b000 100644 --- a/caffe2/python/net_drawer.py +++ b/caffe2/python/net_drawer.py @@ -79,31 +79,38 @@ def GetOpNodeProducer(append_output, **kwargs): return ReallyGetOpNode +def GetBlobNodeProducer(**kwargs): + def ReallyGetBlobNode(node_name, label): + return pydot.Node(node_name, label=label, **kwargs) + return ReallyGetBlobNode + def GetPydotGraph( operators_or_net, name=None, rankdir='LR', - node_producer=None + op_node_producer=None, + blob_node_producer=None ): - if node_producer is None: - node_producer = GetOpNodeProducer(False, **OP_STYLE) + if op_node_producer is None: + op_node_producer = GetOpNodeProducer(False, **OP_STYLE) + if blob_node_producer is None: + blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE) operators, name = _rectify_operator_and_name(operators_or_net, name) graph = pydot.Dot(name, rankdir=rankdir) pydot_nodes = {} pydot_node_counts = defaultdict(int) for op_id, op in enumerate(operators): - op_node = node_producer(op, op_id) + op_node = op_node_producer(op, op_id) graph.add_node(op_node) # print 'Op: %s' % op.name # print 'inputs: %s' % str(op.input) # print 'outputs: %s' % str(op.output) for input_name in op.input: if input_name not in pydot_nodes: - input_node = pydot.Node( + input_node = blob_node_producer( _escape_label( input_name + str(pydot_node_counts[input_name])), label=_escape_label(input_name), - **BLOB_STYLE ) pydot_nodes[input_name] = input_node else: @@ -114,11 +121,10 @@ def GetPydotGraph( if output_name in pydot_nodes: # we are overwriting an existing blob. need to updat the count. pydot_node_counts[output_name] += 1 - output_node = pydot.Node( + output_node = blob_node_producer( _escape_label( output_name + str(pydot_node_counts[output_name])), label=_escape_label(output_name), - **BLOB_STYLE ) pydot_nodes[output_name] = output_node graph.add_node(output_node) @@ -131,7 +137,7 @@ def GetPydotGraphMinimal( name=None, rankdir='LR', minimal_dependency=False, - node_producer=None, + op_node_producer=None, ): """Different from GetPydotGraph, hide all blob nodes and only show op nodes. @@ -140,8 +146,8 @@ def GetPydotGraphMinimal( op a and b, and op b depends on a, then only the edge b->c will be drawn because a->c will be implied. """ - if node_producer is None: - node_producer = GetOpNodeProducer(False, **OP_STYLE) + if op_node_producer is None: + op_node_producer = GetOpNodeProducer(False, **OP_STYLE) operators, name = _rectify_operator_and_name(operators_or_net, name) graph = pydot.Dot(name, rankdir=rankdir) # blob_parents maps each blob name to its generating op. @@ -149,7 +155,7 @@ def GetPydotGraphMinimal( # op_ancestry records the ancestors of each op. op_ancestry = defaultdict(set) for op_id, op in enumerate(operators): - op_node = node_producer(op, op_id) + op_node = op_node_producer(op, op_id) graph.add_node(op_node) # Get parents, and set up op ancestry. parents = [ |