-rw-r--r-- | test/expect/TestJit.test_concat_fusion_invariant_cuda.expect | 17
-rw-r--r-- | test/test_jit.py | 18
-rw-r--r-- | torch/csrc/jit/passes/graph_fuser.cpp | 17
3 files changed, 51 insertions, 1 deletion
diff --git a/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect b/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
new file mode 100644
index 0000000000..bf45946ecd
--- /dev/null
+++ b/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
@@ -0,0 +1,17 @@
+graph(%0 : Float(2, 2)
+      %1 : Float(2, 2)
+      %2 : Float(4, 2)) {
+  %3 : int = prim::Constant[value=1]()
+  %4 : Float(2, 2) = aten::sub(%0, %1, %3)
+  %5 : Float(4, 2) = prim::FusionGroup_0[device=0](%4, %0, %1)
+  %6 : Float(4, 2) = aten::add(%5, %2, %3)
+  return (%6);
+}
+with prim::FusionGroup_0 = graph(%1 : Float(2, 2)
+      %3 : Float(2, 2)
+      %4 : Float(2, 2)) {
+  %5 : int = prim::Constant[value=1]()
+  %6 : Float(2, 2) = aten::add(%3, %4, %5)
+  %2 : Float(4, 2) = prim::FusedConcat[dim=0](%6, %1)
+  return (%2);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index 01914ec4a1..91f329df28
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -451,6 +451,24 @@ class TestJit(JitTestCase):
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
+    def test_concat_fusion_invariant_cuda(self):
+        # Invariant: the output of prim::FusedConcat may
+        # not be an input to any node inside the FusionGroup.
+        def fn(x, y, z):
+            x1 = x + y
+            y1 = x - y
+            w = torch.cat([x1, y1])
+            return w + z
+
+        x = torch.randn(2, 2, dtype=torch.float, device='cuda')
+        y = torch.randn(2, 2, dtype=torch.float, device='cuda')
+        z = torch.randn(4, 2, dtype=torch.float, device='cuda')
+        ge = self.checkTrace(fn, (x, y, z))
+        self.assertExpectedGraph(ge.graph_for(x, y, z))
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
     def test_fusion_distribute(self):
         def f(x, y):
             z1, z2 = (x + y).chunk(2, dim=1)
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 071ca6d57c..4de21160fa
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -207,7 +207,11 @@ struct GraphFuser {
   // because it is not a simple map, can be put in a fusion group
   // as long as no items in the group read the output of concat
   bool isFusableAsExitNode(Node * node) {
-    return isFusable(node) || isFusableCatNode(node);
+    return isFusable(node) || isFusableOnlyAsExitNode(node);
+  }
+
+  bool isFusableOnlyAsExitNode(Node * node) {
+    return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
   }
 
   // necessary condition for fusion. If all of the uses of producer are consumer
@@ -236,6 +240,15 @@ struct GraphFuser {
     return true;
   }
 
+  bool mustRemainAsFusionGroupOutput(Value * producer) {
+    if (producer->node()->kind() != prim::FusionGroup) {
+      return false;
+    }
+    auto subgraph = producer->node()->g(attr::Subgraph);
+    auto * node = subgraph->outputs().at(producer->offset())->node();
+    return isFusableOnlyAsExitNode(node);
+  }
+
   bool shouldFuse(Node * consumer, Value * producer) {
     // this handles cases where producer can be moved _into_ the fusion group of consumer.
     // TODO: extend to fusion of consumer into _producer's_ fusion blob
@@ -559,6 +572,8 @@ struct GraphFuser {
     for(auto producer : inputs) {
       // Don't fuse accross stage boundaries
       if (producer->stage() != consumer->stage()) continue;
+      // Don't fuse if producer must come from a FusionGroup exit node
+      if (mustRemainAsFusionGroupOutput(producer)) continue;
       if(tryToMoveChunk(consumer,producer)) {
         // the chunk before this consumer was re-arranged to allow fusion,
         // we scan this consumer again to perform the fusion
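
Note (not part of the diff): a minimal standalone sketch of the invariant the patch enforces. It mirrors the new test but drives the fuser with a plain torch.jit.trace call instead of the JitTestCase harness; graph_for is an internal debugging hook and the tracing API differed at the time of this commit, so treat the snippet as illustrative rather than a supported recipe. It assumes a CUDA build, since the fuser only kicks in for CUDA tensors here.

import torch

# Pattern under test: torch.cat becomes prim::FusedConcat inside the
# FusionGroup, and `w + z` reads the concat output, so the add must be
# kept *outside* the group (see the .expect graph above).
def fn(x, y, z):
    x1 = x + y
    y1 = x - y
    w = torch.cat([x1, y1])
    return w + z

x = torch.randn(2, 2, device='cuda')
y = torch.randn(2, 2, device='cuda')
z = torch.randn(4, 2, device='cuda')

traced = torch.jit.trace(fn, (x, y, z))
traced(x, y, z)                   # warm-up run so the fuser can specialize the graph
print(traced.graph_for(x, y, z))  # the aten::add consuming the concat result should
                                  # appear after prim::FusionGroup_0, never inside it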