summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Reed <jamesreed@fb.com>2019-04-22 16:54:19 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-22 16:57:18 -0700
commit5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5 (patch)
tree88a18b63b945eb952199e65caa28b3e6e258160b
parent969af4315a30e96205e125a16e67bd6e3c03e218 (diff)
downloadpytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.tar.gz
pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.tar.bz2
pytorch-5be4bee4ff58d04c8ee73163e2cb2e06bd69b2a5.zip
Don't create FusionGroups for known-CPU producer values (#19342)
Summary: I believe the existing check in FuseGraph was only `false` if PyTorch was built with NO_CUDA=1. Otherwise, we would create fusion groups even if we're on a CPU-only machine running CPU code. This is confusing. Instead I've made it so that the decision to fuse or not is dependent on if the producer Value is a known CPU tensor. If it is, we skip fusion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19342 Differential Revision: D15038351 Pulled By: jamesr66a fbshipit-source-id: fce9d83929309a7bf14346833f84b996f3e7f6db
-rw-r--r--test/test_jit.py19
-rw-r--r--test/test_jit_fuser.py2
-rw-r--r--torch/csrc/jit/init.cpp2
-rw-r--r--torch/csrc/jit/passes/graph_fuser.cpp33
4 files changed, 35 insertions, 21 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index f1310519b8..5457edb2fb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -69,10 +69,6 @@ except ImportError:
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
-# Note: creating FusionGroups is currently device-independent.
-# FusionGroup creation with CPU is disabled.
-FUSION_ENABLED = torch._C._jit_can_fuse_on_cpu() or torch._C._jit_can_fuse_on_gpu()
-
RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_HALF = RUN_CUDA
if torch.cuda.is_available():
@@ -438,9 +434,6 @@ class JitTestCase(TestCase):
self.assertExpected(str(graph), *args, **kwargs)
def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
- if not FUSION_ENABLED:
- nonfusible_nodes = nonfusible_nodes + fusible_nodes
- fusible_nodes = []
diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
@@ -13138,6 +13131,10 @@ def add_autograd_test(
# we want to close over in some way
def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
check_ad=check_ad, output_process_fn=output_process_fn):
+ # We enable the CPU fuser during these checks for more consistent
+ # behavior. Otherwise, we are going to have to analyze the graph to
+ # see if producer values are Dimension
+ @enable_cpu_fuser
def check(name):
set_rng_seed(2)
is_magic_method = name[:2] == '__' and name[-2:] == '__'
@@ -13169,6 +13166,10 @@ def add_autograd_test(
check_against_reference(self, traced_fn,
fn, (self_variable,) + args_variable, kwargs_variable,
check_types=check_types)
+ # Fuser not supported on windows
+ if IS_WINDOWS:
+ autodiff_nodes = autodiff_nodes + fusible_nodes
+ fusible_nodes = []
self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
@@ -13177,6 +13178,10 @@ def add_autograd_test(
fn, (self_variable,) + args_variable, kwargs_variable,
check_types=check_types)
+ # Fuser not supported on windows
+ if IS_WINDOWS:
+ autodiff_nodes = autodiff_nodes + fusible_nodes
+ fusible_nodes = []
self.assertAutodiffNode(script_fn.last_graph,
should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
autodiff_nodes,
diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py
index f19398b62d..2bfdebdea1 100644
--- a/test/test_jit_fuser.py
+++ b/test/test_jit_fuser.py
@@ -269,7 +269,7 @@ class TestFuser(JitTestCase):
a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
b = torch.randn(4, 4, dtype=torch.float, device='cuda')
- nan = torch.tensor(float('nan'))
+ nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')
funcs = (func2, funcInf, funcOptMin, funcOptMax)
for f, inputs in product(funcs, [[a, b], [a, nan]]):
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 87014abf22..06aa9ca7ed 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -213,8 +213,6 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
.def("_jit_pass_canonicalize_ops", CanonicalizeOps)
.def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
- .def("_jit_can_fuse_on_cpu", canFuseOnCPU)
- .def("_jit_can_fuse_on_gpu", canFuseOnGPU)
.def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
.def(
"_jit_differentiate",
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 0c02c93e82..7debb1675c 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -532,6 +532,19 @@ struct GraphFuser {
return group;
}
+ bool isFusableDevice(Value *v) {
+ auto tensor_type = v->type()->cast<DimensionedTensorType>();
+ if (!tensor_type) {
+ return true;
+ }
+ if (tensor_type->device().is_cpu()) {
+ return canFuseOnCPU();
+ } else if (tensor_type->device().is_cuda()) {
+ return canFuseOnGPU();
+ }
+ throw std::runtime_error("Unknown device");
+ }
+
at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
// this handles cases where producer can be moved _into_ the fusion group of
// consumer.
@@ -540,7 +553,7 @@ struct GraphFuser {
// we can move the consumer up into the producer.
// but this requires better handling of merging fusion groups so it is not
// done now
- bool shouldFuse = isFusable(producer->node()) &&
+ bool shouldFuse = isFusableDevice(producer) && isFusable(producer->node()) &&
// Rearrange nodes such that all uses of producer are after the
// consumer. Fusion will rewrite those later uses to use the version of
// producer generated by the fused blob. In this case, producer becomes
@@ -1426,16 +1439,14 @@ bool trackSingleGradSumToSizeToOutputs(
}
void FuseGraph(std::shared_ptr<Graph>& graph) {
- if (canFuseOnCPU() || canFuseOnGPU()) {
- GraphFuser(graph->block(), graph).run();
- // After FuseGraph some common subexpressions may come back
- EliminateCommonSubexpression(graph);
- // We might have emitted a fair amount of useless shape propagating code, so
- // remove it
- EliminateDeadCode(graph);
- // Improve the quality of shape propagation code that was left
- PeepholeOptimizeShapeExpressions(graph->block());
- }
+ GraphFuser(graph->block(), graph).run();
+ // After FuseGraph some common subexpressions may come back
+ EliminateCommonSubexpression(graph);
+ // We might have emitted a fair amount of useless shape propagating code, so
+ // remove it
+ EliminateDeadCode(graph);
+ // Improve the quality of shape propagation code that was left
+ PeepholeOptimizeShapeExpressions(graph->block());
}
} // namespace jit