diff options
author | James Reed <jamesreed@fb.com> | 2019-04-22 16:54:19 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-22 16:57:18 -0700 |
commit | 5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5 (patch) | |
tree | 88a18b63b945eb952199e65caa28b3e6e258160b /torch | |
parent | 969af4315a30e96205e125a16e67bd6e3c03e218 (diff) | |
download | pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.tar.gz pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.tar.bz2 pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.zip |
Don't create FusionGroups for known-CPU producer values (#19342)
Summary:
I believe the existing check in FuseGraph was only `false` if PyTorch was built with NO_CUDA=1. Otherwise, we would create fusion groups even if we're on a CPU-only machine running CPU code. This is confusing. Instead I've made it so that the decision to fuse or not is dependent on if the producer Value is a known CPU tensor. If it is, we skip fusion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19342
Differential Revision: D15038351
Pulled By: jamesr66a
fbshipit-source-id: fce9d83929309a7bf14346833f84b996f3e7f6db
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/init.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/passes/graph_fuser.cpp | 33 |
2 files changed, 22 insertions, 13 deletions
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 87014abf22..06aa9ca7ed 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -213,8 +213,6 @@ void initJITBindings(PyObject* module) { .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops) .def("_jit_pass_canonicalize_ops", CanonicalizeOps) .def("_jit_pass_specialize_autogradzero", specializeAutogradZero) - .def("_jit_can_fuse_on_cpu", canFuseOnCPU) - .def("_jit_can_fuse_on_gpu", canFuseOnGPU) .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU) .def( "_jit_differentiate", diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 0c02c93e82..7debb1675c 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -532,6 +532,19 @@ struct GraphFuser { return group; } + bool isFusableDevice(Value *v) { + auto tensor_type = v->type()->cast<DimensionedTensorType>(); + if (!tensor_type) { + return true; + } + if (tensor_type->device().is_cpu()) { + return canFuseOnCPU(); + } else if (tensor_type->device().is_cuda()) { + return canFuseOnGPU(); + } + throw std::runtime_error("Unknown device"); + } + at::optional<Node*> tryFuse(Node* consumer, Value* producer) { // this handles cases where producer can be moved _into_ the fusion group of // consumer. @@ -540,7 +553,7 @@ struct GraphFuser { // we can move the consumer up into the producer. // but this requires better handling of merging fusion groups so it is not // done now - bool shouldFuse = isFusable(producer->node()) && + bool shouldFuse = isFusableDevice(producer) && isFusable(producer->node()) && // Rearrange nodes such that all uses of producer are after the // consumer. Fusion will rewrite those later uses to use the version of // producer generated by the fused blob. In this case, producer becomes @@ -1426,16 +1439,14 @@ bool trackSingleGradSumToSizeToOutputs( } void FuseGraph(std::shared_ptr<Graph>& graph) { - if (canFuseOnCPU() || canFuseOnGPU()) { - GraphFuser(graph->block(), graph).run(); - // After FuseGraph some common subexpressions may come back - EliminateCommonSubexpression(graph); - // We might have emitted a fair amount of useless shape propagating code, so - // remove it - EliminateDeadCode(graph); - // Improve the quality of shape propagation code that was left - PeepholeOptimizeShapeExpressions(graph->block()); - } + GraphFuser(graph->block(), graph).run(); + // After FuseGraph some common subexpressions may come back + EliminateCommonSubexpression(graph); + // We might have emitted a fair amount of useless shape propagating code, so + // remove it + EliminateDeadCode(graph); + // Improve the quality of shape propagation code that was left + PeepholeOptimizeShapeExpressions(graph->block()); } } // namespace jit |