summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  test/expect/TestJit.test_concat_fusion_invariant_cuda.expect  17
-rw-r--r--  test/test_jit.py  18
-rw-r--r--  torch/csrc/jit/passes/graph_fuser.cpp  17
3 files changed, 51 insertions(+), 1 deletion(-)
diff --git a/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect b/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
new file mode 100644
index 0000000000..bf45946ecd
--- /dev/null
+++ b/test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
@@ -0,0 +1,17 @@
+graph(%0 : Float(2, 2)
+ %1 : Float(2, 2)
+ %2 : Float(4, 2)) {
+ %3 : int = prim::Constant[value=1]()
+ %4 : Float(2, 2) = aten::sub(%0, %1, %3)
+ %5 : Float(4, 2) = prim::FusionGroup_0[device=0](%4, %0, %1)
+ %6 : Float(4, 2) = aten::add(%5, %2, %3)
+ return (%6);
+}
+with prim::FusionGroup_0 = graph(%1 : Float(2, 2)
+ %3 : Float(2, 2)
+ %4 : Float(2, 2)) {
+ %5 : int = prim::Constant[value=1]()
+ %6 : Float(2, 2) = aten::add(%3, %4, %5)
+ %2 : Float(4, 2) = prim::FusedConcat[dim=0](%6, %1)
+ return (%2);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index 01914ec4a1..91f329df28 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -451,6 +451,24 @@ class TestJit(JitTestCase):
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
+ def test_concat_fusion_invariant_cuda(self):
+ # Invariant: the output of prim::FusedConcat may
+ # not be an input to any node inside the FusionGroup.
+ def fn(x, y, z):
+ x1 = x + y
+ y1 = x - y
+ w = torch.cat([x1, y1])
+ return w + z
+
+ x = torch.randn(2, 2, dtype=torch.float, device='cuda')
+ y = torch.randn(2, 2, dtype=torch.float, device='cuda')
+ z = torch.randn(4, 2, dtype=torch.float, device='cuda')
+ ge = self.checkTrace(fn, (x, y, z))
+ self.assertExpectedGraph(ge.graph_for(x, y, z))
+
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
def test_fusion_distribute(self):
def f(x, y):
z1, z2 = (x + y).chunk(2, dim=1)
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 071ca6d57c..4de21160fa 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -207,7 +207,11 @@ struct GraphFuser {
// because it is not a simple map, can be put in a fusion group
// as long as no items in the group read the output of concat
bool isFusableAsExitNode(Node * node) {
- return isFusable(node) || isFusableCatNode(node);
+ return isFusable(node) || isFusableOnlyAsExitNode(node);
+ }
+
+ bool isFusableOnlyAsExitNode(Node * node) {
+ return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
}
// necessary condition for fusion. If all of the uses of producer are consumer
@@ -236,6 +240,15 @@ struct GraphFuser {
return true;
}
+ bool mustRemainAsFusionGroupOutput(Value * producer) {
+ if (producer->node()->kind() != prim::FusionGroup) {
+ return false;
+ }
+ auto subgraph = producer->node()->g(attr::Subgraph);
+ auto * node = subgraph->outputs().at(producer->offset())->node();
+ return isFusableOnlyAsExitNode(node);
+ }
+
bool shouldFuse(Node * consumer, Value * producer) {
// this handles cases where producer can be moved _into_ the fusion group of consumer.
// TODO: extend to fusion of consumer into _producer's_ fusion blob
@@ -559,6 +572,8 @@ struct GraphFuser {
for(auto producer : inputs) {
// Don't fuse accross stage boundaries
if (producer->stage() != consumer->stage()) continue;
+ // Don't fuse if producer must come from a FusionGroup exit node
+ if (mustRemainAsFusionGroupOutput(producer)) continue;
if(tryToMoveChunk(consumer,producer)) {
// the chunk before this consumer was re-arranged to allow fusion,
// we scan this consumer again to perform the fusion