diff options
author | Michael Suo <suo@fb.com> | 2019-04-19 12:48:39 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-19 13:02:06 -0700 |
commit | 9245eaf3f0f45583b1d44d221e700b88c0ad3b9d (patch) | |
tree | 90159d41f3c2ec145901b2b459850c9d5621498c /torch | |
parent | 73c166a5ed83ecccb1d45a37e1e5ef58e4f56bcf (diff) | |
download | pytorch-9245eaf3f0f45583b1d44d221e700b88c0ad3b9d.tar.gz pytorch-9245eaf3f0f45583b1d44d221e700b88c0ad3b9d.tar.bz2 pytorch-9245eaf3f0f45583b1d44d221e700b88c0ad3b9d.zip |
Allow for segmented printing in PythonPrint (#19238)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19238
ghimport-source-id: 469d33cd187fa68840b201d625800a0f4fead547
Differential Revision: D14928291
Reviewed By: zdevito
Pulled By: suo
fbshipit-source-id: 257fce3dd1601ba192092d3fc318374e3752907e
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/passes/python_print.cpp | 85 |
1 files changed, 45 insertions, 40 deletions
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index aacc4ae859..0395ed9085 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -196,7 +196,7 @@ const static std::unordered_set<std::string> reserved_names = { }; struct PythonPrintPass { - std::ostream& out; + std::ostringstream body_; // constants are written to this table, and given then named CONSTANTS.cN // where N is the index into this table. @@ -430,9 +430,9 @@ struct PythonPrintPass { // indent to the current indent level std::ostream& indent() { for (size_t i = 0; i < level; ++i) { - out << " "; + body_ << " "; } - return out; + return body_; } ResourceGuard WithIndented() { @@ -492,10 +492,10 @@ struct PythonPrintPass { void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) { if (lhs.size() > 0) { indent(); - printValueList(out, lhs); - out << " = "; - printValueList(out, rhs); - out << "\n"; + printValueList(body_, lhs); + body_ << " = "; + printValueList(body_, rhs); + body_ << "\n"; } } @@ -572,14 +572,14 @@ struct PythonPrintPass { // Loop header if (emit_as_for_loop) { indent(); - out << "for " << useOf(stmt.currentTripCount()) << " in range(" - << useOf(stmt.maxTripCount()) << "):\n"; + body_ << "for " << useOf(stmt.currentTripCount()) << " in range(" + << useOf(stmt.maxTripCount()) << "):\n"; } else { // note: trip_count_in_block is unused because this is a while loop, // so we reuse the Value* as a stand-in for the loop condition printAssignment(stmt.currentTripCount(), stmt.inputCond()); indent(); - out << "while " << useOf(stmt.currentTripCount()) << ":\n"; + body_ << "while " << useOf(stmt.currentTripCount()) << ":\n"; } // Loop body { @@ -646,10 +646,10 @@ struct PythonPrintPass { indent(); // Print outputs if (node->outputs().size() > 0) { - printValueList(out, node->outputs()); - out << " = "; + printValueList(body_, node->outputs()); + body_ << " = "; } - out << str << "\n"; + body_ << str << "\n"; } // Recursively check contained types for any class dependencies @@ -679,7 +679,7 @@ struct PythonPrintPass { if (enforce_importable_ && value->ignore_on_export) { // Op has been marked as ignored, so insert an error in its place indent(); - out << "ops.prim.IgnoredPythonOp()\n"; + body_ << "ops.prim.IgnoredPythonOp()\n"; return; } } @@ -693,9 +693,9 @@ struct PythonPrintPass { } if (node->inputs().size() > 0) { indent(); - out << "return "; - printValueList(out, node->inputs()); - out << "\n"; + body_ << "return "; + printValueList(body_, node->inputs()); + body_ << "\n"; } break; case prim::Loop: @@ -713,9 +713,9 @@ struct PythonPrintPass { // a, b, = unpacked // a, = unpacked # trailing comma forces an unpack to happen if (node->outputs().size() > 0) { - printValueList(out, node->outputs(), "", ", = "); + printValueList(body_, node->outputs(), "", ", = "); } - out << useOf(node->input()) << "\n"; + body_ << useOf(node->input()) << "\n"; break; case prim::SetAttr: { const auto obj = node->inputs().at(0); @@ -723,7 +723,8 @@ struct PythonPrintPass { const auto type = obj->type()->expect<ClassType>(); const auto& attrname = node->s(attr::name); indent(); - out << useOf(obj) << "." << attrname << " = " << useOf(newVal) << "\n"; + body_ << useOf(obj) << "." << attrname << " = " << useOf(newVal) + << "\n"; } break; default: std::stringstream ss; @@ -974,12 +975,12 @@ struct PythonPrintPass { if (!block_has_other_statements && root->nodes().begin() == root->nodes().end()) { indent(); - out << "pass\n"; + body_ << "pass\n"; } for (auto* node : root->nodes()) { printNode(node, /*print_const=*/false); } - return out; + return body_; } void printDefaultValue( @@ -1027,7 +1028,7 @@ struct PythonPrintPass { auto defaults_offset = defaults.begin(); indent(); - out << "def " << name << "("; + body_ << "def " << name << "("; auto input_iter = true_inputs.begin(); // Print the `self` argument @@ -1035,24 +1036,24 @@ struct PythonPrintPass { // If this is a class, print the self var without a type annotation, // following Python convention AT_ASSERT(true_inputs.size() > 0); - out << useOf(*input_iter); + body_ << useOf(*input_iter); ++input_iter; AT_ASSERT(!defaults_offset->has_value()); ++defaults_offset; } else { // If this is not a class, then we need to insert a "self". - out << "self"; + body_ << "self"; } // Print the rest of the arguments for (; input_iter != true_inputs.end(); ++input_iter) { auto input = *input_iter; - out << ",\n " << useOf(input) << ": " << input->type()->python_str(); + body_ << ",\n " << useOf(input) << ": " << input->type()->python_str(); if (defaults_offset != defaults.end()) { const c10::optional<IValue>& def = *defaults_offset++; if (def) { - printDefaultValue(input->type(), out, *def); + printDefaultValue(input->type(), body_, *def); } } } @@ -1060,7 +1061,7 @@ struct PythonPrintPass { // have we use all the provided defaults? AT_ASSERT(defaults_offset == defaults.end()); - out << ") -> " << resultType(graph)->python_str() << ":\n"; + body_ << ") -> " << resultType(graph)->python_str() << ":\n"; { auto guard = WithIndented(); // Print initial constant table (most are just inlined into their use, @@ -1077,12 +1078,10 @@ struct PythonPrintPass { public: PythonPrintPass( - std::ostream& out_, std::vector<at::Tensor>& tensor_table, std::vector<ClassTypePtr>& class_table, bool enforce_importable) - : out(out_), - tensor_table_(tensor_table), + : tensor_table_(tensor_table), class_table_(class_table), enforce_importable_(enforce_importable) {} @@ -1105,7 +1104,7 @@ struct PythonPrintPass { const std::vector<std::string>& param_names = {}) { printFunctionDefinition(graph, name, is_class, defaults, param_names); while (!worklist.empty()) { - out << "\n\n"; + body_ << "\n\n"; auto work = worklist.back(); worklist.pop_back(); work(); @@ -1133,9 +1132,7 @@ struct PythonPrintPass { [](const Argument& arg) { return arg.default_value(); }); printFunction(graph, name, is_class, defaults, ivalue_names); } - void printFunction( - script::Function& method, - bool is_class) { + void printFunction(script::Function& method, bool is_class) { const std::string& name = method.name(); Graph& graph = *method.graph(); auto defaults = fmap( @@ -1160,7 +1157,7 @@ struct PythonPrintPass { } void printClass(const ClassTypePtr& classType) { - out << "class " << classType->name() << ":\n"; + body_ << "class " << classType->name() << ":\n"; { const auto guard = WithIndented(); for (auto& method : classType->methods()) { @@ -1168,6 +1165,10 @@ struct PythonPrintPass { } } } + + void print(std::ostream& out) { + out << body_.str(); + } }; TORCH_API void PythonPrint( @@ -1176,9 +1177,10 @@ TORCH_API void PythonPrint( std::vector<at::Tensor>& tensor_table, std::vector<ClassTypePtr>& class_table, bool enforce_importable) { - PythonPrintPass pp(out, tensor_table, class_table, enforce_importable); + PythonPrintPass pp(tensor_table, class_table, enforce_importable); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) pp.printFunction(const_cast<Graph&>(graph), "graph", /*is_class=*/false); + pp.print(out); } TORCH_API void PythonPrint( @@ -1187,9 +1189,10 @@ TORCH_API void PythonPrint( std::vector<at::Tensor>& tensor_table, std::vector<ClassTypePtr>& class_table, bool enforce_importable) { - PythonPrintPass pp(out, tensor_table, class_table, enforce_importable); + PythonPrintPass pp(tensor_table, class_table, enforce_importable); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) pp.printMethod(const_cast<script::Method&>(method)); + pp.print(out); } TORCH_API void PythonPrint( @@ -1198,9 +1201,10 @@ TORCH_API void PythonPrint( std::vector<at::Tensor>& tensor_table, std::vector<ClassTypePtr>& class_table, bool enforce_importable) { - PythonPrintPass pp(out, tensor_table, class_table, enforce_importable); + PythonPrintPass pp(tensor_table, class_table, enforce_importable); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) pp.printModule(const_cast<script::Module&>(module)); + pp.print(out); } TORCH_API void PythonPrint( @@ -1209,8 +1213,9 @@ TORCH_API void PythonPrint( std::vector<at::Tensor>& tensor_table, std::vector<ClassTypePtr>& class_table, bool enforce_importable) { - PythonPrintPass pp(out, tensor_table, class_table, enforce_importable); + PythonPrintPass pp(tensor_table, class_table, enforce_importable); pp.printClass(classType); + pp.print(out); } TORCH_API bool printerHasSpecialCaseFor(Symbol sym) { |