summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorMichael Suo <suo@fb.com>2019-04-19 12:48:39 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-19 13:02:06 -0700
commit9245eaf3f0f45583b1d44d221e700b88c0ad3b9d (patch)
tree90159d41f3c2ec145901b2b459850c9d5621498c /torch
parent73c166a5ed83ecccb1d45a37e1e5ef58e4f56bcf (diff)
downloadpytorch-9245eaf3f0f45583b1d44d221e700b88c0ad3b9d.tar.gz
pytorch-9245eaf3f0f45583b1d44d221e700b88c0ad3b9d.tar.bz2
pytorch-9245eaf3f0f45583b1d44d221e700b88c0ad3b9d.zip
Allow for segmented printing in PythonPrint (#19238)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19238 ghimport-source-id: 469d33cd187fa68840b201d625800a0f4fead547 Differential Revision: D14928291 Reviewed By: zdevito Pulled By: suo fbshipit-source-id: 257fce3dd1601ba192092d3fc318374e3752907e
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/passes/python_print.cpp85
1 files changed, 45 insertions, 40 deletions
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index aacc4ae859..0395ed9085 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -196,7 +196,7 @@ const static std::unordered_set<std::string> reserved_names = {
};
struct PythonPrintPass {
- std::ostream& out;
+ std::ostringstream body_;
// constants are written to this table, and given then named CONSTANTS.cN
// where N is the index into this table.
@@ -430,9 +430,9 @@ struct PythonPrintPass {
// indent to the current indent level
std::ostream& indent() {
for (size_t i = 0; i < level; ++i) {
- out << " ";
+ body_ << " ";
}
- return out;
+ return body_;
}
ResourceGuard WithIndented() {
@@ -492,10 +492,10 @@ struct PythonPrintPass {
void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
if (lhs.size() > 0) {
indent();
- printValueList(out, lhs);
- out << " = ";
- printValueList(out, rhs);
- out << "\n";
+ printValueList(body_, lhs);
+ body_ << " = ";
+ printValueList(body_, rhs);
+ body_ << "\n";
}
}
@@ -572,14 +572,14 @@ struct PythonPrintPass {
// Loop header
if (emit_as_for_loop) {
indent();
- out << "for " << useOf(stmt.currentTripCount()) << " in range("
- << useOf(stmt.maxTripCount()) << "):\n";
+ body_ << "for " << useOf(stmt.currentTripCount()) << " in range("
+ << useOf(stmt.maxTripCount()) << "):\n";
} else {
// note: trip_count_in_block is unused because this is a while loop,
// so we reuse the Value* as a stand-in for the loop condition
printAssignment(stmt.currentTripCount(), stmt.inputCond());
indent();
- out << "while " << useOf(stmt.currentTripCount()) << ":\n";
+ body_ << "while " << useOf(stmt.currentTripCount()) << ":\n";
}
// Loop body
{
@@ -646,10 +646,10 @@ struct PythonPrintPass {
indent();
// Print outputs
if (node->outputs().size() > 0) {
- printValueList(out, node->outputs());
- out << " = ";
+ printValueList(body_, node->outputs());
+ body_ << " = ";
}
- out << str << "\n";
+ body_ << str << "\n";
}
// Recursively check contained types for any class dependencies
@@ -679,7 +679,7 @@ struct PythonPrintPass {
if (enforce_importable_ && value->ignore_on_export) {
// Op has been marked as ignored, so insert an error in its place
indent();
- out << "ops.prim.IgnoredPythonOp()\n";
+ body_ << "ops.prim.IgnoredPythonOp()\n";
return;
}
}
@@ -693,9 +693,9 @@ struct PythonPrintPass {
}
if (node->inputs().size() > 0) {
indent();
- out << "return ";
- printValueList(out, node->inputs());
- out << "\n";
+ body_ << "return ";
+ printValueList(body_, node->inputs());
+ body_ << "\n";
}
break;
case prim::Loop:
@@ -713,9 +713,9 @@ struct PythonPrintPass {
// a, b, = unpacked
// a, = unpacked # trailing comma forces an unpack to happen
if (node->outputs().size() > 0) {
- printValueList(out, node->outputs(), "", ", = ");
+ printValueList(body_, node->outputs(), "", ", = ");
}
- out << useOf(node->input()) << "\n";
+ body_ << useOf(node->input()) << "\n";
break;
case prim::SetAttr: {
const auto obj = node->inputs().at(0);
@@ -723,7 +723,8 @@ struct PythonPrintPass {
const auto type = obj->type()->expect<ClassType>();
const auto& attrname = node->s(attr::name);
indent();
- out << useOf(obj) << "." << attrname << " = " << useOf(newVal) << "\n";
+ body_ << useOf(obj) << "." << attrname << " = " << useOf(newVal)
+ << "\n";
} break;
default:
std::stringstream ss;
@@ -974,12 +975,12 @@ struct PythonPrintPass {
if (!block_has_other_statements &&
root->nodes().begin() == root->nodes().end()) {
indent();
- out << "pass\n";
+ body_ << "pass\n";
}
for (auto* node : root->nodes()) {
printNode(node, /*print_const=*/false);
}
- return out;
+ return body_;
}
void printDefaultValue(
@@ -1027,7 +1028,7 @@ struct PythonPrintPass {
auto defaults_offset = defaults.begin();
indent();
- out << "def " << name << "(";
+ body_ << "def " << name << "(";
auto input_iter = true_inputs.begin();
// Print the `self` argument
@@ -1035,24 +1036,24 @@ struct PythonPrintPass {
// If this is a class, print the self var without a type annotation,
// following Python convention
AT_ASSERT(true_inputs.size() > 0);
- out << useOf(*input_iter);
+ body_ << useOf(*input_iter);
++input_iter;
AT_ASSERT(!defaults_offset->has_value());
++defaults_offset;
} else {
// If this is not a class, then we need to insert a "self".
- out << "self";
+ body_ << "self";
}
// Print the rest of the arguments
for (; input_iter != true_inputs.end(); ++input_iter) {
auto input = *input_iter;
- out << ",\n " << useOf(input) << ": " << input->type()->python_str();
+ body_ << ",\n " << useOf(input) << ": " << input->type()->python_str();
if (defaults_offset != defaults.end()) {
const c10::optional<IValue>& def = *defaults_offset++;
if (def) {
- printDefaultValue(input->type(), out, *def);
+ printDefaultValue(input->type(), body_, *def);
}
}
}
@@ -1060,7 +1061,7 @@ struct PythonPrintPass {
// have we use all the provided defaults?
AT_ASSERT(defaults_offset == defaults.end());
- out << ") -> " << resultType(graph)->python_str() << ":\n";
+ body_ << ") -> " << resultType(graph)->python_str() << ":\n";
{
auto guard = WithIndented();
// Print initial constant table (most are just inlined into their use,
@@ -1077,12 +1078,10 @@ struct PythonPrintPass {
public:
PythonPrintPass(
- std::ostream& out_,
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable)
- : out(out_),
- tensor_table_(tensor_table),
+ : tensor_table_(tensor_table),
class_table_(class_table),
enforce_importable_(enforce_importable) {}
@@ -1105,7 +1104,7 @@ struct PythonPrintPass {
const std::vector<std::string>& param_names = {}) {
printFunctionDefinition(graph, name, is_class, defaults, param_names);
while (!worklist.empty()) {
- out << "\n\n";
+ body_ << "\n\n";
auto work = worklist.back();
worklist.pop_back();
work();
@@ -1133,9 +1132,7 @@ struct PythonPrintPass {
[](const Argument& arg) { return arg.default_value(); });
printFunction(graph, name, is_class, defaults, ivalue_names);
}
- void printFunction(
- script::Function& method,
- bool is_class) {
+ void printFunction(script::Function& method, bool is_class) {
const std::string& name = method.name();
Graph& graph = *method.graph();
auto defaults = fmap(
@@ -1160,7 +1157,7 @@ struct PythonPrintPass {
}
void printClass(const ClassTypePtr& classType) {
- out << "class " << classType->name() << ":\n";
+ body_ << "class " << classType->name() << ":\n";
{
const auto guard = WithIndented();
for (auto& method : classType->methods()) {
@@ -1168,6 +1165,10 @@ struct PythonPrintPass {
}
}
}
+
+ void print(std::ostream& out) {
+ out << body_.str();
+ }
};
TORCH_API void PythonPrint(
@@ -1176,9 +1177,10 @@ TORCH_API void PythonPrint(
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
- PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
+ PythonPrintPass pp(tensor_table, class_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printFunction(const_cast<Graph&>(graph), "graph", /*is_class=*/false);
+ pp.print(out);
}
TORCH_API void PythonPrint(
@@ -1187,9 +1189,10 @@ TORCH_API void PythonPrint(
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
- PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
+ PythonPrintPass pp(tensor_table, class_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printMethod(const_cast<script::Method&>(method));
+ pp.print(out);
}
TORCH_API void PythonPrint(
@@ -1198,9 +1201,10 @@ TORCH_API void PythonPrint(
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
- PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
+ PythonPrintPass pp(tensor_table, class_table, enforce_importable);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
pp.printModule(const_cast<script::Module&>(module));
+ pp.print(out);
}
TORCH_API void PythonPrint(
@@ -1209,8 +1213,9 @@ TORCH_API void PythonPrint(
std::vector<at::Tensor>& tensor_table,
std::vector<ClassTypePtr>& class_table,
bool enforce_importable) {
- PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
+ PythonPrintPass pp(tensor_table, class_table, enforce_importable);
pp.printClass(classType);
+ pp.print(out);
}
TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {