diff options
author | Richard Zou <zou3519@gmail.com> | 2018-08-13 20:55:56 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-08-13 21:09:25 -0700 |
commit | fed05cf4cf1063df009fc2280a0b5040606fde89 (patch) | |
tree | 8ef915737b037569d032290a45b3affff2443078 /test/test_jit.py | |
parent | 099a545376f805ce4da10bcb5cfc7bc71bbcba7c (diff) | |
download | pytorch-fed05cf4cf1063df009fc2280a0b5040606fde89.tar.gz pytorch-fed05cf4cf1063df009fc2280a0b5040606fde89.tar.bz2 pytorch-fed05cf4cf1063df009fc2280a0b5040606fde89.zip |
Fix prim::FusedConcat bug (#10466)
Summary:
Fixes #10456
The graph fuser was fusing together groups with prim::FusedConcat (the producer) with other ops (the consumer) if the consumer is fusable. For example,
```
import torch
torch.jit.script
def fn(x, y, z):
x1 = x + y
y1 = x - y
w = torch.cat([x1, y1])
return w + z
x = torch.randn(2, 2, dtype=torch.float, device='cpu')
y = torch.randn(2, 2, dtype=torch.float, device='cpu')
z = torch.randn(4, 2, dtype=torch.float, device='cpu')
fn(x, y, z)
fn.graph_for(x, y, z)
```
produced the following graph:
```
graph(%x : Float(2, 2)
%y : Float(2, 2)
%z : Float(4, 2)) {
%3 : int = prim::Constant[value=1]()
%y1 : Float(2, 2) = aten::sub(%x, %y, %3)
%8 : int = prim::Constant[value=0]()
%14 : Float(4, 2) = prim::FusionGroup_0[device=-1](%z, %y1, %x, %y)
return (%14);
}
with prim::FusionGroup_0 = graph(%1 : Float(4, 2)
%5 : Float(2, 2)
%7 : Float(2, 2)
%8 : Float(2, 2)) {
%11 : int = prim::Constant[value=1]()
%9 : int = prim::Constant[value=1]()
%x1 : Float(2, 2) = aten::add(%7, %8, %9)
%w : Float(4, 2) = prim::FusedConcat[dim=0](%x1, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(4, 2) = aten::add(%w, %1, %2)
return (%3);
}
```
this is a problem because it violates two invariants:
1) all inputs to the FusionGroup must have the same size
2) prim::FusedConcat's output must not be used inside the FusionGroup
This PR fixes this problem by checking if the output to a FusionGroup came from a prim::FusedConcat node when deciding whether to fuse the consumer and producer.
If the producer is a value that came from a prim::FusedConcat node in a FusionGroup, then consumer & producer do not get fused.
cc apaszke zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10466
Differential Revision: D9296686
Pulled By: zou3519
fbshipit-source-id: ed826fa9c436b42c04ca7d4d790cece804c162bd
Diffstat (limited to 'test/test_jit.py')
-rw-r--r-- | test/test_jit.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index 01914ec4a1..91f329df28 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -451,6 +451,24 @@ class TestJit(JitTestCase): @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @skipIfRocm + def test_concat_fusion_invariant_cuda(self): + # Invariant: the output of prim::FusedConcat may + # not be an input to any node inside the FusionGroup. + def fn(x, y, z): + x1 = x + y + y1 = x - y + w = torch.cat([x1, y1]) + return w + z + + x = torch.randn(2, 2, dtype=torch.float, device='cuda') + y = torch.randn(2, 2, dtype=torch.float, device='cuda') + z = torch.randn(4, 2, dtype=torch.float, device='cuda') + ge = self.checkTrace(fn, (x, y, z)) + self.assertExpectedGraph(ge.graph_for(x, y, z)) + + @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") + @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") + @skipIfRocm def test_fusion_distribute(self): def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) |