summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorMikhail Zolotukhin <mvz@fb.com>2019-01-29 00:17:30 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-29 00:26:37 -0800
commitdbebb5322cb5306df7844bf2daf140d031392452 (patch)
tree2240734ac3eceb9753ede0b10847d95d4a94aa16 /torch
parent0e6123fb8af16e173f7c7e86e0d3e7f2a635b6d0 (diff)
downloadpytorch-dbebb5322cb5306df7844bf2daf140d031392452.tar.gz
pytorch-dbebb5322cb5306df7844bf2daf140d031392452.tar.bz2
pytorch-dbebb5322cb5306df7844bf2daf140d031392452.zip
Properly screen string literals when dumping JIT IR
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16056 Differential Revision: D13719444 Pulled By: ZolotukhinM fbshipit-source-id: 7113ee9328eff6263513476cdf9254a2e1116f4c
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/ir.cpp27
1 files changed, 15 insertions, 12 deletions
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 38ec0bbcf0..1d60a563e1 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -18,6 +18,9 @@
namespace torch {
namespace jit {
+
+void printQuotedString(std::ostream& stmt, const std::string& str);
+
// Constants relating to maintaining the topological index of nodes.
//
// Lower and upper bounds of the index. Inclusive range.
@@ -113,17 +116,17 @@ static void printPrimList(std::ostream& out, const std::vector<T>& items) {
out << "]";
}
-static std::string escapeString(std::string s) {
- std::vector<char> search = {'\n', '\t', '\v'};
- std::vector<std::string> replace = {"\\n", "\\t", "\\v"};
- for (size_t i = 0; i < search.size(); i++) {
- size_t pos = s.find(search[i]);
- while (pos != std::string::npos) {
- s.replace(pos, 1, replace[i]);
- pos = s.find(search[i], pos + 1);
- }
+static void printStrList(
+ std::ostream& out,
+ const std::vector<std::string>& items) {
+ out << "[";
+ int i = 0;
+ for (auto& item : items) {
+ if (i++ > 0)
+ out << ", ";
+ printQuotedString(out, item);
}
- return s;
+ out << "]";
}
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
@@ -141,10 +144,10 @@ void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
printPrimList(out, is(name));
break;
case AttributeKind::s:
- out << "\"" << escapeString(s(name)) << "\"";
+ printQuotedString(out, s(name));
break;
case AttributeKind::ss:
- printPrimList(out, ss(name));
+ printStrList(out, ss(name));
break;
case AttributeKind::t: {
at::Tensor tensor = t(name);