diff options
author | Elias Ellison <eellison@fb.com> | 2018-12-19 10:45:32 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-19 10:52:54 -0800 |
commit | 33018e4e09b16075440ea72a6929b15c7ae670f5 (patch) | |
tree | 056303280bc6fedc8a674793f77c1786e967bdf6 | |
parent | 560530aeecdae8c37d47c8ee7967c8583843a0bf (diff) | |
download | pytorch-33018e4e09b16075440ea72a6929b15c7ae670f5.tar.gz pytorch-33018e4e09b16075440ea72a6929b15c7ae670f5.tar.bz2 pytorch-33018e4e09b16075440ea72a6929b15c7ae670f5.zip |
centralize side effects ops as node method (#15188)
Summary:
A number of different passes rely on whether a node has side effects. This centralizes the list of side effectful ops in one place.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15188
Differential Revision: D13508438
Pulled By: eellison
fbshipit-source-id: 2143e782b787731ce007b6dcd50cbde30e1b8dd0
-rw-r--r-- | torch/csrc/jit/ir.cpp | 11 | ||||
-rw-r--r-- | torch/csrc/jit/ir.h | 1 | ||||
-rw-r--r-- | torch/csrc/jit/passes/common_subexpression_elimination.cpp | 3 | ||||
-rw-r--r-- | torch/csrc/jit/passes/constant_propagation.cpp | 6 | ||||
-rw-r--r-- | torch/csrc/jit/passes/dead_code_elimination.cpp | 4 | ||||
-rw-r--r-- | torch/csrc/jit/passes/shape_analysis.cpp | 9 |
6 files changed, 20 insertions, 14 deletions
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 650d5ef55c..df038e32e4 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -686,6 +686,17 @@ bool Node::isNondeterministic() const { return true; } +bool Node::hasSideEffects() const { + switch (kind_) { + case prim::PythonOp: + case prim::Print: + case prim::RaiseException: + case aten::warn: + return true; + } + return false; +} + // Assign this node a topological position, to facilitate fast isBefore() and // isAfter() queries. Must be called right after a node is inserted into the // node list. diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 2a6c9cfb2a..71a536149e 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -353,6 +353,7 @@ public: } TORCH_API bool isNondeterministic() const; + TORCH_API bool hasSideEffects () const; // Graphs diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index b96cfaadc6..cac8f6b42b 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -23,8 +23,7 @@ void EliminateCommonSubexpression( std::unordered_set<Node*, HashNode, EqualNode> subexprs; for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) { auto node = *it; - if (node->kind() == prim::PythonOp || node->kind() == prim::Print || - node->kind() == aten::warn || node->isNondeterministic() || + if (node->hasSideEffects() || node->isNondeterministic() || aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) { // Do NOT have enough information to do CSE on these nodes. continue; diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index a1a6c1a817..2446759ce3 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -16,10 +16,6 @@ namespace { std::unordered_set<Symbol> skip_list = { prim::If, prim::Loop, //TODO: handle Loop - prim::Print, - prim::RaiseException, - aten::warn, - prim::PythonOp, //may have side effects prim::Constant, prim::Undefined, prim::NoneGenerator, @@ -125,7 +121,7 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) { return v->node()->kind() == prim::Constant; }); bool supported_node = !n->kind().is_onnx() && - skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && + skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && !n->hasSideEffects() && !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n); auto run_blocks = [&]() { if (recurse) { diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index 0167d03d21..b7d606ce63 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -245,9 +245,7 @@ class DeadCodeEliminator { auto it = memo_.find(node); if (it != memo_.end()) return it->second; - bool has_side_effects = node->kind() == prim::Print || - node->kind() == aten::warn || node->kind() == prim::RaiseException || - node->kind() == prim::PythonOp || + bool has_side_effects = node->hasSideEffects() || std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 2d9677a2fd..85465d5828 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -434,10 +434,6 @@ class ShapePropagator { } return; } - case prim::PythonOp: - case prim::Print: - case prim::RaiseException: - case aten::warn: case prim::Undefined: { setUnshapedType(node); return; @@ -445,6 +441,11 @@ class ShapePropagator { default: break; // fall-through } + + if (node->hasSideEffects()) { + return; + } + if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") || node->kind() == prim::FusedConcat) { return PropagateCatShape(node); |