summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2019-02-12 14:43:44 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-12 15:02:50 -0800
commitf435fb8290aeb00fbfa0191d5c42c59c5a772623 (patch)
treec6fd7dafb2a2d4d3cf598ae74596223ef93ec5ea /caffe2
parent65b49b46966abc1fd9b7ad6668a994e9a669be96 (diff)
downloadpytorch-f435fb8290aeb00fbfa0191d5c42c59c5a772623.tar.gz
pytorch-f435fb8290aeb00fbfa0191d5c42c59c5a772623.tar.bz2
pytorch-f435fb8290aeb00fbfa0191d5c42c59c5a772623.zip
Allow customization of blob node in net_drawer (#16915)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16915 TSIA Reviewed By: ipiszy Differential Revision: D14018010 fbshipit-source-id: df5ccc06fa37f08e7a02a8acc466c4ad47afe04e
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h8
-rw-r--r--caffe2/opt/backend_cutting.cc7
-rw-r--r--caffe2/opt/backend_cutting.h3
-rw-r--r--caffe2/opt/onnxifi_transformer.cc9
-rw-r--r--caffe2/python/net_drawer.py30
5 files changed, 34 insertions, 23 deletions
diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h b/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h
index e279378304..2a2800401c 100644
--- a/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h
+++ b/caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h
@@ -38,8 +38,7 @@ class DotGenerator {
const typename GraphT::SubgraphType& sg,
const std::vector<typename GraphT::SubgraphType*>& subgraphs) const {
std::ostringstream output;
- output << "digraph G {\n\
- ";
+ output << "digraph G {\nrankdir=LR\n";
for (const auto& node : sg.getNodes()) {
generateNode(node, sg, output);
}
@@ -60,8 +59,7 @@ class DotGenerator {
// Convert a subgraph to dot.
std::string convert(const typename GraphT::SubgraphType& sg) const {
std::ostringstream output;
- output << "digraph G {\n\
- ";
+ output << "digraph G {\nrankdir=LR\n";
for (const auto& node : sg.getNodes()) {
generateNode(node, sg, output);
}
@@ -82,7 +80,7 @@ class DotGenerator {
*/
std::string convertStruct(const typename GraphT::SubgraphType& sg) const {
std::ostringstream output;
- output << "digraph G {\n";
+ output << "digraph G {\nrankdir=LR\n";
// Get input nodes (nodes w/o parents)
std::unordered_map<typename GraphT::NodeRef, int>
diff --git a/caffe2/opt/backend_cutting.cc b/caffe2/opt/backend_cutting.cc
index 5715e5a922..c4dd792470 100644
--- a/caffe2/opt/backend_cutting.cc
+++ b/caffe2/opt/backend_cutting.cc
@@ -346,7 +346,8 @@ void PruneUnrefereredNodes(NNModule* nn) {
caffe2::NetDef OptimizeForBackend(
caffe2::NetDef& net,
std::function<bool(const caffe2::OperatorDef&)> supports,
- std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func) {
+ std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func,
+ bool debug) {
auto nn = convertToNNModule(net);
auto& dfg = nn.dataFlow;
@@ -413,6 +414,10 @@ caffe2::NetDef OptimizeForBackend(
// absorbed
PruneUnrefereredNodes(&nn);
+ if (debug) {
+ DumpGraph(&dfg);
+ }
+
auto new_net = convertToCaffe2Proto(nn);
new_net.set_name(net.name() + "_opt");
return new_net;
diff --git a/caffe2/opt/backend_cutting.h b/caffe2/opt/backend_cutting.h
index 8ea1413dc1..cf98c11ddc 100644
--- a/caffe2/opt/backend_cutting.h
+++ b/caffe2/opt/backend_cutting.h
@@ -12,6 +12,7 @@ namespace opt {
CAFFE2_API caffe2::NetDef OptimizeForBackend(
caffe2::NetDef& net,
std::function<bool(const caffe2::OperatorDef&)> supports,
- std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func);
+ std::function<caffe2::NetDef(const caffe2::NetDef&)> transform_func,
+ bool debug = false);
}
} // namespace caffe2
diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc
index 64e801cf9c..a1e089666a 100644
--- a/caffe2/opt/onnxifi_transformer.cc
+++ b/caffe2/opt/onnxifi_transformer.cc
@@ -285,10 +285,10 @@ int64_t GetBlob1stDimSize(
return shape_info.shape.dims(0);
}
-// Generates AdjustBatchOps for external inputs / outputs with type BATCH or
+// Generates AdjustBatchOps for external inputs/outputs with type BATCH or
// SEQ and adds them to input_ops and output_ops.
-// Meanwhile, modifies inputs / outputs of corresponding operators in the
-// onnxifi_net to use the new inputs / outputs of AdjustBatchOps.
+// Meanwhile, modifies inputs/outputs of corresponding operators in the
+// onnxifi_net to use the new inputs/outputs of AdjustBatchOps.
std::unordered_map<std::string, std::string> AddAdjustBatchOps(
const ShapeInfoMap& shape_hints,
NetDef* onnxifi_net,
@@ -979,7 +979,8 @@ NetDef OnnxifiTransformer::TransformViaOnnx(
net, weights, ws, &exporter2, shape_hints, &shape_hints_onnx);
};
- return opt::OptimizeForBackend(*pred_net, onnx_supports, onnx_converter);
+ return opt::OptimizeForBackend(
+ *pred_net, onnx_supports, onnx_converter, opts_.debug);
}
// Cutting off the runnable part and replace with ONNXIFI ops. Asssume the nets
diff --git a/caffe2/python/net_drawer.py b/caffe2/python/net_drawer.py
index ee124bc6c8..17f6c4b000 100644
--- a/caffe2/python/net_drawer.py
+++ b/caffe2/python/net_drawer.py
@@ -79,31 +79,38 @@ def GetOpNodeProducer(append_output, **kwargs):
return ReallyGetOpNode
+def GetBlobNodeProducer(**kwargs):
+ def ReallyGetBlobNode(node_name, label):
+ return pydot.Node(node_name, label=label, **kwargs)
+ return ReallyGetBlobNode
+
def GetPydotGraph(
operators_or_net,
name=None,
rankdir='LR',
- node_producer=None
+ op_node_producer=None,
+ blob_node_producer=None
):
- if node_producer is None:
- node_producer = GetOpNodeProducer(False, **OP_STYLE)
+ if op_node_producer is None:
+ op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
+ if blob_node_producer is None:
+ blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE)
operators, name = _rectify_operator_and_name(operators_or_net, name)
graph = pydot.Dot(name, rankdir=rankdir)
pydot_nodes = {}
pydot_node_counts = defaultdict(int)
for op_id, op in enumerate(operators):
- op_node = node_producer(op, op_id)
+ op_node = op_node_producer(op, op_id)
graph.add_node(op_node)
# print 'Op: %s' % op.name
# print 'inputs: %s' % str(op.input)
# print 'outputs: %s' % str(op.output)
for input_name in op.input:
if input_name not in pydot_nodes:
- input_node = pydot.Node(
+ input_node = blob_node_producer(
_escape_label(
input_name + str(pydot_node_counts[input_name])),
label=_escape_label(input_name),
- **BLOB_STYLE
)
pydot_nodes[input_name] = input_node
else:
@@ -114,11 +121,10 @@ def GetPydotGraph(
if output_name in pydot_nodes:
# we are overwriting an existing blob. need to updat the count.
pydot_node_counts[output_name] += 1
- output_node = pydot.Node(
+ output_node = blob_node_producer(
_escape_label(
output_name + str(pydot_node_counts[output_name])),
label=_escape_label(output_name),
- **BLOB_STYLE
)
pydot_nodes[output_name] = output_node
graph.add_node(output_node)
@@ -131,7 +137,7 @@ def GetPydotGraphMinimal(
name=None,
rankdir='LR',
minimal_dependency=False,
- node_producer=None,
+ op_node_producer=None,
):
"""Different from GetPydotGraph, hide all blob nodes and only show op nodes.
@@ -140,8 +146,8 @@ def GetPydotGraphMinimal(
op a and b, and op b depends on a, then only the edge b->c will be drawn
because a->c will be implied.
"""
- if node_producer is None:
- node_producer = GetOpNodeProducer(False, **OP_STYLE)
+ if op_node_producer is None:
+ op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
operators, name = _rectify_operator_and_name(operators_or_net, name)
graph = pydot.Dot(name, rankdir=rankdir)
# blob_parents maps each blob name to its generating op.
@@ -149,7 +155,7 @@ def GetPydotGraphMinimal(
# op_ancestry records the ancestors of each op.
op_ancestry = defaultdict(set)
for op_id, op in enumerate(operators):
- op_node = node_producer(op, op_id)
+ op_node = op_node_producer(op, op_id)
graph.add_node(op_node)
# Get parents, and set up op ancestry.
parents = [