| field | value | date |
|---|---|---|
| author | James Reed <jamesreed@fb.com> | 2019-04-22 16:54:19 -0700 |
| committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-22 16:57:18 -0700 |
| commit | 5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5 (patch) | |
| tree | 88a18b63b945eb952199e65caa28b3e6e258160b | |
| parent | 969af4315a30e96205e125a16e67bd6e3c03e218 (diff) | |
| download | pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.tar.gz pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.tar.bz2 pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.zip | |
Don't create FusionGroups for known-CPU producer values (#19342)
Summary:
I believe the existing check in FuseGraph was only `false` if PyTorch was built with NO_CUDA=1; otherwise we would create fusion groups even on a CPU-only machine running CPU code, which is confusing. Instead, I've made the decision to fuse depend on whether the producer Value is a known CPU tensor: if it is, we skip fusion unless the CPU fuser has been explicitly enabled (see the sketch below).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19342
Differential Revision: D15038351
Pulled By: jamesr66a
fbshipit-source-id: fce9d83929309a7bf14346833f84b996f3e7f6db
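As a rough illustration of the behavior this patch produces, the sketch below scripts a small pointwise function on CPU tensors and checks whether a `prim::FusionGroup` node shows up in its optimized graph, first with the default flags (CPU fusion off) and then after opting in through `torch._C._jit_override_can_fuse_on_cpu`, the binding the patch keeps. The helper function, the use of the internal `graph_for` inspector, and the expected outputs are assumptions about the JIT of this era, not part of the patch.

```python
import torch

def check_cpu_fusion(enable_cpu_fusion):
    # Toggle the CPU fuser via the override binding kept by this patch
    # (the _jit_can_fuse_on_cpu/_jit_can_fuse_on_gpu query bindings were removed).
    torch._C._jit_override_can_fuse_on_cpu(enable_cpu_fusion)

    def pointwise(x, y):
        # Elementwise chain the fuser can collapse into a single FusionGroup.
        return (x + y) * torch.sigmoid(y)

    fn = torch.jit.script(pointwise)
    x, y = torch.randn(4, 4), torch.randn(4, 4)
    fn(x, y)  # run once so an optimized graph exists
    # graph_for is an internal inspection helper used by the JIT tests (assumed here).
    return 'prim::FusionGroup' in str(fn.graph_for(x, y))

print(check_cpu_fusion(False))  # expected: False -- known-CPU producers skip fusion
print(check_cpu_fusion(True))   # expected: True  -- CPU fuser explicitly opted in
```

This mirrors what the test changes below do with the `enable_cpu_fuser` decorator: the autograd tests opt in so that fusion-group expectations stay consistent despite the new per-device check.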
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | test/test_jit.py | 19 |
| -rw-r--r-- | test/test_jit_fuser.py | 2 |
| -rw-r--r-- | torch/csrc/jit/init.cpp | 2 |
| -rw-r--r-- | torch/csrc/jit/passes/graph_fuser.cpp | 33 |

4 files changed, 35 insertions, 21 deletions
```diff
diff --git a/test/test_jit.py b/test/test_jit.py
index f1310519b8..5457edb2fb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -69,10 +69,6 @@ except ImportError:
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 
-# Note: creating FusionGroups is currently device-independent.
-# FusionGroup creation with CPU is disabled.
-FUSION_ENABLED = torch._C._jit_can_fuse_on_cpu() or torch._C._jit_can_fuse_on_gpu()
-
 RUN_CUDA = torch.cuda.is_available()
 RUN_CUDA_HALF = RUN_CUDA
 if torch.cuda.is_available():
@@ -438,9 +434,6 @@ class JitTestCase(TestCase):
         self.assertExpected(str(graph), *args, **kwargs)
 
     def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
-        if not FUSION_ENABLED:
-            nonfusible_nodes = nonfusible_nodes + fusible_nodes
-            fusible_nodes = []
         diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
         diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
@@ -13138,6 +13131,10 @@ def add_autograd_test(
     # we want to close over in some way
     def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
                 check_ad=check_ad, output_process_fn=output_process_fn):
+        # We enable the CPU fuser during these checks for more consistent
+        # behavior. Otherwise, we are going to have to analyze the graph to
+        # see if producer values are Dimension
+        @enable_cpu_fuser
         def check(name):
             set_rng_seed(2)
             is_magic_method = name[:2] == '__' and name[-2:] == '__'
@@ -13169,6 +13166,10 @@ def add_autograd_test(
                 check_against_reference(self, traced_fn,
                                         fn, (self_variable,) + args_variable,
                                         kwargs_variable, check_types=check_types)
+                # Fuser not supported on windows
+                if IS_WINDOWS:
+                    autodiff_nodes = autodiff_nodes + fusible_nodes
+                    fusible_nodes = []
                 self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
                 if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
@@ -13177,6 +13178,10 @@ def add_autograd_test(
                                         fn, (self_variable,) + args_variable,
                                         kwargs_variable, check_types=check_types)
+                # Fuser not supported on windows
+                if IS_WINDOWS:
+                    autodiff_nodes = autodiff_nodes + fusible_nodes
+                    fusible_nodes = []
                 self.assertAutodiffNode(script_fn.last_graph,
                                         should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
                                         autodiff_nodes,
diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py
index f19398b62d..2bfdebdea1 100644
--- a/test/test_jit_fuser.py
+++ b/test/test_jit_fuser.py
@@ -269,7 +269,7 @@ class TestFuser(JitTestCase):
         a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
         b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        nan = torch.tensor(float('nan'))
+        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')
 
         funcs = (func2, funcInf, funcOptMin, funcOptMax)
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 87014abf22..06aa9ca7ed 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -213,8 +213,6 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
       .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
       .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
-      .def("_jit_can_fuse_on_cpu", canFuseOnCPU)
-      .def("_jit_can_fuse_on_gpu", canFuseOnGPU)
       .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
       .def(
           "_jit_differentiate",
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 0c02c93e82..7debb1675c 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -532,6 +532,19 @@ struct GraphFuser {
     return group;
   }
 
+  bool isFusableDevice(Value *v) {
+    auto tensor_type = v->type()->cast<DimensionedTensorType>();
+    if (!tensor_type) {
+      return true;
+    }
+    if (tensor_type->device().is_cpu()) {
+      return canFuseOnCPU();
+    } else if (tensor_type->device().is_cuda()) {
+      return canFuseOnGPU();
+    }
+    throw std::runtime_error("Unknown device");
+  }
+
   at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
     // this handles cases where producer can be moved _into_ the fusion group of
     // consumer.
@@ -540,7 +553,7 @@ struct GraphFuser {
     // we can move the consumer up into the producer.
     // but this requires better handling of merging fusion groups so it is not
     // done now
-    bool shouldFuse = isFusable(producer->node()) &&
+    bool shouldFuse = isFusableDevice(producer) && isFusable(producer->node()) &&
         // Rearrange nodes such that all uses of producer are after the
         // consumer. Fusion will rewrite those later uses to use the version of
         // producer generated by the fused blob. In this case, producer becomes
@@ -1426,16 +1439,14 @@ bool trackSingleGradSumToSizeToOutputs(
 }
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
-  if (canFuseOnCPU() || canFuseOnGPU()) {
-    GraphFuser(graph->block(), graph).run();
-    // After FuseGraph some common subexpressions may come back
-    EliminateCommonSubexpression(graph);
-    // We might have emitted a fair amount of useless shape propagating code, so
-    // remove it
-    EliminateDeadCode(graph);
-    // Improve the quality of shape propagation code that was left
-    PeepholeOptimizeShapeExpressions(graph->block());
-  }
+  GraphFuser(graph->block(), graph).run();
+  // After FuseGraph some common subexpressions may come back
+  EliminateCommonSubexpression(graph);
+  // We might have emitted a fair amount of useless shape propagating code, so
+  // remove it
+  EliminateDeadCode(graph);
+  // Improve the quality of shape propagation code that was left
+  PeepholeOptimizeShapeExpressions(graph->block());
 }
 
 } // namespace jit
```
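To summarize the core of the C++ change, here is a rough Python rendering of the device gate that `isFusableDevice` introduces; the function name, the flag defaults, and the use of `torch.device` are illustrative assumptions, not the actual JIT API.

```python
import torch

def is_fusable_device(producer_device, can_fuse_on_cpu=False, can_fuse_on_gpu=True):
    """Illustrative mirror of GraphFuser::isFusableDevice (not the real API).

    A producer without a DimensionedTensorType has no known device and stays
    fusable; known CPU and CUDA producers defer to the per-backend fuser flags
    (CPU fusion is off by default, GPU fusion is on by default).
    """
    if producer_device is None:   # device unknown -> leave the decision to other checks
        return True
    device = torch.device(producer_device)
    if device.type == 'cpu':
        return can_fuse_on_cpu
    if device.type == 'cuda':
        return can_fuse_on_gpu
    raise RuntimeError('Unknown device')

print(is_fusable_device('cpu'))    # False: known-CPU producer, fuser not opted in
print(is_fusable_device('cuda'))   # True:  GPU fusion stays on by default
print(is_fusable_device(None))     # True:  unknown device falls through to isFusable
```

Because `FuseGraph` now runs unconditionally, this per-producer check, rather than the removed global `canFuseOnCPU() || canFuseOnGPU()` guard, is what keeps CPU-only graphs from growing FusionGroups.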