summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorElias Ellison <eellison@fb.com>2018-12-19 10:45:32 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-12-19 10:52:54 -0800
commit33018e4e09b16075440ea72a6929b15c7ae670f5 (patch)
tree056303280bc6fedc8a674793f77c1786e967bdf6
parent560530aeecdae8c37d47c8ee7967c8583843a0bf (diff)
downloadpytorch-33018e4e09b16075440ea72a6929b15c7ae670f5.tar.gz
pytorch-33018e4e09b16075440ea72a6929b15c7ae670f5.tar.bz2
pytorch-33018e4e09b16075440ea72a6929b15c7ae670f5.zip
centralize side effects ops as node method (#15188)
Summary: A number of different passes rely on whether a node has side effects. This centralizes the list of side effectful ops in one place. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15188 Differential Revision: D13508438 Pulled By: eellison fbshipit-source-id: 2143e782b787731ce007b6dcd50cbde30e1b8dd0
-rw-r--r--torch/csrc/jit/ir.cpp11
-rw-r--r--torch/csrc/jit/ir.h1
-rw-r--r--torch/csrc/jit/passes/common_subexpression_elimination.cpp3
-rw-r--r--torch/csrc/jit/passes/constant_propagation.cpp6
-rw-r--r--torch/csrc/jit/passes/dead_code_elimination.cpp4
-rw-r--r--torch/csrc/jit/passes/shape_analysis.cpp9
6 files changed, 20 insertions, 14 deletions
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 650d5ef55c..df038e32e4 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -686,6 +686,17 @@ bool Node::isNondeterministic() const {
return true;
}
+bool Node::hasSideEffects() const {
+ switch (kind_) {
+ case prim::PythonOp:
+ case prim::Print:
+ case prim::RaiseException:
+ case aten::warn:
+ return true;
+ }
+ return false;
+}
+
// Assign this node a topological position, to facilitate fast isBefore() and
// isAfter() queries. Must be called right after a node is inserted into the
// node list.
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 2a6c9cfb2a..71a536149e 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -353,6 +353,7 @@ public:
}
TORCH_API bool isNondeterministic() const;
+ TORCH_API bool hasSideEffects () const;
// Graphs
diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
index b96cfaadc6..cac8f6b42b 100644
--- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp
+++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp
@@ -23,8 +23,7 @@ void EliminateCommonSubexpression(
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
auto node = *it;
- if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
- node->kind() == aten::warn || node->isNondeterministic() ||
+ if (node->hasSideEffects() || node->isNondeterministic() ||
aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
// Do NOT have enough information to do CSE on these nodes.
continue;
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index a1a6c1a817..2446759ce3 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -16,10 +16,6 @@ namespace {
std::unordered_set<Symbol> skip_list = {
prim::If,
prim::Loop, //TODO: handle Loop
- prim::Print,
- prim::RaiseException,
- aten::warn,
- prim::PythonOp, //may have side effects
prim::Constant,
prim::Undefined,
prim::NoneGenerator,
@@ -125,7 +121,7 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
return v->node()->kind() == prim::Constant;
});
bool supported_node = !n->kind().is_onnx() &&
- skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
+ skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && !n->hasSideEffects() &&
!aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
auto run_blocks = [&]() {
if (recurse) {
diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp
index 0167d03d21..b7d606ce63 100644
--- a/torch/csrc/jit/passes/dead_code_elimination.cpp
+++ b/torch/csrc/jit/passes/dead_code_elimination.cpp
@@ -245,9 +245,7 @@ class DeadCodeEliminator {
auto it = memo_.find(node);
if (it != memo_.end())
return it->second;
- bool has_side_effects = node->kind() == prim::Print ||
- node->kind() == aten::warn || node->kind() == prim::RaiseException ||
- node->kind() == prim::PythonOp ||
+ bool has_side_effects = node->hasSideEffects() ||
std::any_of(node->blocks().begin(),
node->blocks().end(),
[&](Block* b) {
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 2d9677a2fd..85465d5828 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -434,10 +434,6 @@ class ShapePropagator {
}
return;
}
- case prim::PythonOp:
- case prim::Print:
- case prim::RaiseException:
- case aten::warn:
case prim::Undefined: {
setUnshapedType(node);
return;
@@ -445,6 +441,11 @@ class ShapePropagator {
default:
break; // fall-through
}
+
+ if (node->hasSideEffects()) {
+ return;
+ }
+
if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")
|| node->kind() == prim::FusedConcat) {
return PropagateCatShape(node);