summaryrefslogtreecommitdiff
path: root/test/test_jit.py
diff options
context:
space:
mode:
authorRichard Zou <zou3519@gmail.com>2018-08-13 20:55:56 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-08-13 21:09:25 -0700
commitfed05cf4cf1063df009fc2280a0b5040606fde89 (patch)
tree8ef915737b037569d032290a45b3affff2443078 /test/test_jit.py
parent099a545376f805ce4da10bcb5cfc7bc71bbcba7c (diff)
downloadpytorch-fed05cf4cf1063df009fc2280a0b5040606fde89.tar.gz
pytorch-fed05cf4cf1063df009fc2280a0b5040606fde89.tar.bz2
pytorch-fed05cf4cf1063df009fc2280a0b5040606fde89.zip
Fix prim::FusedConcat bug (#10466)
Summary: Fixes #10456 The graph fuser was fusing together groups with prim::FusedConcat (the producer) with other ops (the consumer) if the consumer is fusable. For example, ``` import torch torch.jit.script def fn(x, y, z): x1 = x + y y1 = x - y w = torch.cat([x1, y1]) return w + z x = torch.randn(2, 2, dtype=torch.float, device='cpu') y = torch.randn(2, 2, dtype=torch.float, device='cpu') z = torch.randn(4, 2, dtype=torch.float, device='cpu') fn(x, y, z) fn.graph_for(x, y, z) ``` produced the following graph: ``` graph(%x : Float(2, 2) %y : Float(2, 2) %z : Float(4, 2)) { %3 : int = prim::Constant[value=1]() %y1 : Float(2, 2) = aten::sub(%x, %y, %3) %8 : int = prim::Constant[value=0]() %14 : Float(4, 2) = prim::FusionGroup_0[device=-1](%z, %y1, %x, %y) return (%14); } with prim::FusionGroup_0 = graph(%1 : Float(4, 2) %5 : Float(2, 2) %7 : Float(2, 2) %8 : Float(2, 2)) { %11 : int = prim::Constant[value=1]() %9 : int = prim::Constant[value=1]() %x1 : Float(2, 2) = aten::add(%7, %8, %9) %w : Float(4, 2) = prim::FusedConcat[dim=0](%x1, %5) %2 : int = prim::Constant[value=1]() %3 : Float(4, 2) = aten::add(%w, %1, %2) return (%3); } ``` this is a problem because it violates two invariants: 1) all inputs to the FusionGroup must have the same size 2) prim::FusedConcat's output must not be used inside the FusionGroup This PR fixes this problem by checking if the output to a FusionGroup came from a prim::FusedConcat node when deciding whether to fuse the consumer and producer. If the producer is a value that came from a prim::FusedConcat node in a FusionGroup, then consumer & producer do not get fused. cc apaszke zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/10466 Differential Revision: D9296686 Pulled By: zou3519 fbshipit-source-id: ed826fa9c436b42c04ca7d4d790cece804c162bd
Diffstat (limited to 'test/test_jit.py')
-rw-r--r--test/test_jit.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index 01914ec4a1..91f329df28 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -451,6 +451,24 @@ class TestJit(JitTestCase):
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
+ def test_concat_fusion_invariant_cuda(self):
+ # Invariant: the output of prim::FusedConcat may
+ # not be an input to any node inside the FusionGroup.
+ def fn(x, y, z):
+ x1 = x + y
+ y1 = x - y
+ w = torch.cat([x1, y1])
+ return w + z
+
+ x = torch.randn(2, 2, dtype=torch.float, device='cuda')
+ y = torch.randn(2, 2, dtype=torch.float, device='cuda')
+ z = torch.randn(4, 2, dtype=torch.float, device='cuda')
+ ge = self.checkTrace(fn, (x, y, z))
+ self.assertExpectedGraph(ge.graph_for(x, y, z))
+
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
def test_fusion_distribute(self):
def f(x, y):
z1, z2 = (x + y).chunk(2, dim=1)