author    James Reed <jamesreed@fb.com>  2019-04-22 16:54:19 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-22 16:57:18 -0700
commit    5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5 (patch)
tree      88a18b63b945eb952199e65caa28b3e6e258160b /torch
parent    969af4315a30e96205e125a16e67bd6e3c03e218 (diff)
Don't create FusionGroups for known-CPU producer values (#19342)
Summary: I believe the existing check in FuseGraph was only `false` if PyTorch was built with NO_CUDA=1. Otherwise, we would create fusion groups even on a CPU-only machine running CPU code. This is confusing. Instead, I've made the decision to fuse depend on whether the producer Value is a known CPU tensor; if it is, we skip fusion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19342
Differential Revision: D15038351
Pulled By: jamesr66a
fbshipit-source-id: fce9d83929309a7bf14346833f84b996f3e7f6db
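As a quick illustration of the intended effect (a hedged sketch, not part of this commit; it assumes a CPU-only run of a build containing this change, and uses the JIT debugging helper graph_for):

import torch

@torch.jit.script
def f(x, y):
    # Element-wise chain the fuser would previously have merged
    # into a prim::FusionGroup even on a CPU-only machine.
    return (x + y) * y

x = torch.randn(4, 4)  # known-CPU producer values
y = torch.randn(4, 4)
f(x, y)  # run once so the fusion pass executes
# After this change, the optimized graph should contain no
# prim::FusionGroup nodes for these CPU inputs.
print(f.graph_for(x, y))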
Diffstat (limited to 'torch')
-rw-r--r--  torch/csrc/jit/init.cpp                |  2
-rw-r--r--  torch/csrc/jit/passes/graph_fuser.cpp  | 33
2 files changed, 22 insertions, 13 deletions
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 87014abf22..06aa9ca7ed 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -213,8 +213,6 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
.def("_jit_pass_canonicalize_ops", CanonicalizeOps)
.def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
- .def("_jit_can_fuse_on_cpu", canFuseOnCPU)
- .def("_jit_can_fuse_on_gpu", canFuseOnGPU)
.def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
.def(
"_jit_differentiate",
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 0c02c93e82..7debb1675c 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -532,6 +532,19 @@ struct GraphFuser {
     return group;
   }
 
+  bool isFusableDevice(Value *v) {
+    auto tensor_type = v->type()->cast<DimensionedTensorType>();
+    if (!tensor_type) {
+      return true;
+    }
+    if (tensor_type->device().is_cpu()) {
+      return canFuseOnCPU();
+    } else if (tensor_type->device().is_cuda()) {
+      return canFuseOnGPU();
+    }
+    throw std::runtime_error("Unknown device");
+  }
+
   at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
     // this handles cases where producer can be moved _into_ the fusion group of
     // consumer.
@@ -540,7 +553,7 @@ struct GraphFuser {
     // we can move the consumer up into the producer.
     // but this requires better handling of merging fusion groups so it is not
     // done now
-    bool shouldFuse = isFusable(producer->node()) &&
+    bool shouldFuse = isFusableDevice(producer) && isFusable(producer->node()) &&
         // Rearrange nodes such that all uses of producer are after the
         // consumer. Fusion will rewrite those later uses to use the version of
         // producer generated by the fused blob. In this case, producer becomes
@@ -1426,16 +1439,14 @@ bool trackSingleGradSumToSizeToOutputs(
 }
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
-  if (canFuseOnCPU() || canFuseOnGPU()) {
-    GraphFuser(graph->block(), graph).run();
-    // After FuseGraph some common subexpressions may come back
-    EliminateCommonSubexpression(graph);
-    // We might have emitted a fair amount of useless shape propagating code, so
-    // remove it
-    EliminateDeadCode(graph);
-    // Improve the quality of shape propagation code that was left
-    PeepholeOptimizeShapeExpressions(graph->block());
-  }
+  GraphFuser(graph->block(), graph).run();
+  // After FuseGraph some common subexpressions may come back
+  EliminateCommonSubexpression(graph);
+  // We might have emitted a fair amount of useless shape propagating code, so
+  // remove it
+  EliminateDeadCode(graph);
+  // Improve the quality of shape propagation code that was left
+  PeepholeOptimizeShapeExpressions(graph->block());
 }
 
 } // namespace jit
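Note: the _jit_override_can_fuse_on_cpu binding kept in the init.cpp hunk above remains the way to opt back in to CPU fusion from Python, e.g. for exercising the fuser on a CPU-only machine (a sketch, using only the binding shown in this diff):

import torch

# Force the fuser to consider known-CPU producer values again.
torch._C._jit_override_can_fuse_on_cpu(True)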