diff options
author | Bram Wasti <bwasti@fb.com> | 2018-10-16 20:58:21 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-10-16 21:03:28 -0700 |
commit | 84edd4a48b6e4d273faaeeca96b98c3efde433c3 (patch) | |
tree | d9f0e380bb160f83dd1d5c3d5127d4cd66c1b7d5 /caffe2/opt | |
parent | 1bf642800d72bfdff17c66f8f13d929e1118d139 (diff) | |
download | pytorch-84edd4a48b6e4d273faaeeca96b98c3efde433c3.tar.gz pytorch-84edd4a48b6e4d273faaeeca96b98c3efde433c3.tar.bz2 pytorch-84edd4a48b6e4d273faaeeca96b98c3efde433c3.zip |
Enable mapping from operatordef to converted node for debugging
Summary: Add a mapping for conversion -- this will help with debugging as well but is directly used by the TUI stacked on top of this
Reviewed By: duc0
Differential Revision: D10396130
fbshipit-source-id: cdd39278f0ed563bb828b1aebbbd228f486d89c8
Diffstat (limited to 'caffe2/opt')
-rw-r--r-- | caffe2/opt/converter.cc | 8 | ||||
-rw-r--r-- | caffe2/opt/converter.h | 9 |
2 files changed, 15 insertions, 2 deletions
diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc index 0774034337..27dbdbb6ff 100644 --- a/caffe2/opt/converter.cc +++ b/caffe2/opt/converter.cc @@ -264,7 +264,10 @@ std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator( /// \brief Ingest a caffe2 protobuf model and output an NNModule. /// \param net The caffe2 protobuf NetDef -repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict) { +repr::NNModule convertToNNModule( + caffe2::NetDef& net, + bool strict, + std::vector<repr::NNGraph::NodeRef>* opNodeVec) { repr::NNModule module; repr::NNGraph& dfg = module.dataFlow; repr::NNCFGraph& cfg = module.controlFlow; @@ -315,6 +318,9 @@ repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict) { } opNode->resetData(convertToNeuralNetOperator(op)); + if (opNodeVec) { + opNodeVec->emplace_back(opNode); + } auto currentBasicBlock = bbNode->mutableData(); currentBasicBlock->pushInstructionNode(opNode); } diff --git a/caffe2/opt/converter.h b/caffe2/opt/converter.h index 9666739d14..ab43b4033b 100644 --- a/caffe2/opt/converter.h +++ b/caffe2/opt/converter.h @@ -16,7 +16,14 @@ namespace caffe2 { CAFFE2_API void injectDataEdgeIndicators(caffe2::NetDef* net); CAFFE2_API void removeDataEdgeIndicators(caffe2::NetDef* net); -CAFFE2_API nom::repr::NNModule convertToNNModule(caffe2::NetDef &net, bool strict = false); +// Default conversion to a NNModule +// Optionally strict -- which checks for various input and output conditions. +// Optionally this function will update a vector that maps operators in the +// netdef positionally to NodeRefs in the resultant NNModule. +CAFFE2_API nom::repr::NNModule convertToNNModule( + caffe2::NetDef& net, + bool strict = false, + std::vector<nom::repr::NNGraph::NodeRef>* = nullptr); CAFFE2_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule&); // Pass in an oldNet to copy all the attributes of that network. |