summaryrefslogtreecommitdiff
path: root/caffe2/opt
diff options
context:
space:
mode:
authorBram Wasti <bwasti@fb.com>2018-10-16 20:58:21 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-10-16 21:03:28 -0700
commit84edd4a48b6e4d273faaeeca96b98c3efde433c3 (patch)
treed9f0e380bb160f83dd1d5c3d5127d4cd66c1b7d5 /caffe2/opt
parent1bf642800d72bfdff17c66f8f13d929e1118d139 (diff)
downloadpytorch-84edd4a48b6e4d273faaeeca96b98c3efde433c3.tar.gz
pytorch-84edd4a48b6e4d273faaeeca96b98c3efde433c3.tar.bz2
pytorch-84edd4a48b6e4d273faaeeca96b98c3efde433c3.zip
Enable mapping from operatordef to converted node for debugging
Summary: Add a mapping for conversion -- this will help with debugging as well but is directly used by the TUI stacked on top of this Reviewed By: duc0 Differential Revision: D10396130 fbshipit-source-id: cdd39278f0ed563bb828b1aebbbd228f486d89c8
Diffstat (limited to 'caffe2/opt')
-rw-r--r--caffe2/opt/converter.cc8
-rw-r--r--caffe2/opt/converter.h9
2 files changed, 15 insertions, 2 deletions
diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc
index 0774034337..27dbdbb6ff 100644
--- a/caffe2/opt/converter.cc
+++ b/caffe2/opt/converter.cc
@@ -264,7 +264,10 @@ std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
/// \brief Ingest a caffe2 protobuf model and output an NNModule.
/// \param net The caffe2 protobuf NetDef
-repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict) {
+repr::NNModule convertToNNModule(
+ caffe2::NetDef& net,
+ bool strict,
+ std::vector<repr::NNGraph::NodeRef>* opNodeVec) {
repr::NNModule module;
repr::NNGraph& dfg = module.dataFlow;
repr::NNCFGraph& cfg = module.controlFlow;
@@ -315,6 +318,9 @@ repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict) {
}
opNode->resetData(convertToNeuralNetOperator(op));
+ if (opNodeVec) {
+ opNodeVec->emplace_back(opNode);
+ }
auto currentBasicBlock = bbNode->mutableData();
currentBasicBlock->pushInstructionNode(opNode);
}
diff --git a/caffe2/opt/converter.h b/caffe2/opt/converter.h
index 9666739d14..ab43b4033b 100644
--- a/caffe2/opt/converter.h
+++ b/caffe2/opt/converter.h
@@ -16,7 +16,14 @@ namespace caffe2 {
CAFFE2_API void injectDataEdgeIndicators(caffe2::NetDef* net);
CAFFE2_API void removeDataEdgeIndicators(caffe2::NetDef* net);
-CAFFE2_API nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict = false);
+// Default conversion to a NNModule
+// Optionally strict -- which checks for various input and output conditions.
+// Optionally this function will update a vector that maps operators in the
+// netdef positionally to NodeRefs in the resultant NNModule.
+CAFFE2_API nom::repr::NNModule convertToNNModule(
+ caffe2::NetDef& net,
+ bool strict = false,
+ std::vector<nom::repr::NNGraph::NodeRef>* = nullptr);
CAFFE2_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&);
// Pass in an oldNet to copy all the attributes of that network.