summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRichard Zou <zou3519@gmail.com>2018-12-18 16:13:39 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-12-18 16:33:59 -0800
commit5667af3880af35663e5eee55b657e18e437f52e6 (patch)
treea23ae701be5b7d9aa9e2968539135517656ad49b
parent3681bf7cff5c41f7e177837d55c898065f1aaee8 (diff)
downloadpytorch-5667af3880af35663e5eee55b657e18e437f52e6.tar.gz
pytorch-5667af3880af35663e5eee55b657e18e437f52e6.tar.bz2
pytorch-5667af3880af35663e5eee55b657e18e437f52e6.zip
Minor cleanup for TestFuser tests (#15134)
Summary: Changelog: - change some expect tests that didn't have to be expect tests, instead use self.assertAllFused - Some of the fuser tests weren't using self.assertAllFused. - Minor test renames cc apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/15134 Differential Revision: D13507481 Pulled By: zou3519 fbshipit-source-id: dd0788530a60bb5ed2f42b961fae3db2b4404b64
-rw-r--r--test/expect/TestFuser.test_last_device_cuda.expect16
-rw-r--r--test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect11
-rw-r--r--test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect8
-rw-r--r--test/test_jit.py30
4 files changed, 11 insertions, 54 deletions
diff --git a/test/expect/TestFuser.test_last_device_cuda.expect b/test/expect/TestFuser.test_last_device_cuda.expect
deleted file mode 100644
index b2ef06bcbe..0000000000
--- a/test/expect/TestFuser.test_last_device_cuda.expect
+++ /dev/null
@@ -1,16 +0,0 @@
-graph(%x : Float(*)
- %y : Float(*)) {
- %2 : Float(*) = prim::FusionGroup_0(%x, %y)
- return (%2);
-}
-with prim::FusionGroup_0 = graph(%0 : Float(*)
- %1 : Float(*)) {
- %2 : int = prim::Constant[value=1]()
- %3 : Float(*) = aten::add(%0, %1, %2)
- %4 : Float(*) = aten::mul(%0, %3)
- %5 : int = prim::Constant[value=1]()
- %6 : Float(*) = aten::add(%4, %0, %5)
- %7 : Float(*) = aten::tanh(%6)
- %8 : Float(*) = aten::sigmoid(%7)
- return (%8);
-}
diff --git a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect b/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect
deleted file mode 100644
index 60ccea5666..0000000000
--- a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect
+++ /dev/null
@@ -1,11 +0,0 @@
-graph(%x : Float(*, *)) {
- %1 : Float(*, *) = prim::FusionGroup_0(%x)
- return (%1);
-}
-with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
- %z : float = prim::Constant[value=3]()
- %2 : int = prim::Constant[value=1]()
- %y : Float(*, *) = aten::add(%0, %z, %2)
- %4 : Float(*, *) = aten::mul(%0, %y)
- return (%4);
-}
diff --git a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect b/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect
deleted file mode 100644
index 07487929d7..0000000000
--- a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect
+++ /dev/null
@@ -1,8 +0,0 @@
-graph(%x : Float(*, *)
- %z : Float()) {
- %2 : int = prim::Constant[value=1]()
- %3 : int = prim::Int(%z)
- %y : Float(*, *) = aten::add(%x, %3, %2)
- %5 : Float(*, *) = aten::mul(%x, %y)
- return (%5);
-}
diff --git a/test/test_jit.py b/test/test_jit.py
index 8d88e8c3de..f8d946a256 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10070,11 +10070,10 @@ class TestFuser(JitTestCase):
# XXX: This assumes that the same kernel isn't already used by another test
self.assertEqual(new_cache_size - prev_cache_size, 1)
- # TODO: This test doesn't offer anything valuable, maybe we should delete it
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
@skipIfRocm
- def test_last_device_cuda(self):
+ def test_nonzero_device_cuda(self):
device = 'cuda:' + str(1)
x = torch.tensor([0.4], dtype=torch.float, device=device)
y = torch.tensor([0.7], dtype=torch.float, device=device)
@@ -10083,7 +10082,7 @@ class TestFuser(JitTestCase):
return torch.sigmoid(torch.tanh(x * (x + y) + x))
ge = self.checkTrace(doit, (x, y))
- self.assertExpectedGraph(ge.graph_for(x, y))
+ self.assertAllFused(ge.graph_for(x, y))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -10212,6 +10211,7 @@ class TestFuser(JitTestCase):
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
ge = self.checkTrace(self.fn_test_relu, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
@staticmethod
def fn_test_erf(x):
@@ -10266,14 +10266,15 @@ class TestFuser(JitTestCase):
inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
ge = self.checkScript(should_fuse, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')
+ self.assertAllFused(ge.graph_for(*inputs))
inputs = [
torch.randn(2, 2, dtype=torch.float, device='cuda'),
torch.tensor(3., dtype=torch.float, device='cuda'),
]
ge = self.checkScript(should_not_fuse, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')
+ self.assertGraphContainsExactly(
+ ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
@enable_cpu_fuser
@@ -10294,30 +10295,21 @@ class TestFuser(JitTestCase):
self.assertEqual(result2, expected2)
self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
- # TODO: This test seems dead
- @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
+ @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_windows(self):
+ def test_windows_cuda(self):
def scaleshift(x, scale, shift):
return x * scale + shift
- graph = torch.jit.script(scaleshift).graph
-
inputs = [
torch.randn(4, 4, dtype=torch.float, device='cuda'),
torch.randn(4, dtype=torch.float, device='cuda'),
torch.randn(4, dtype=torch.float, device='cuda'),
]
- ge = self.checkTrace(scaleshift, inputs)
- fuse_graph = ge.graph_for(*inputs)
-
- def run_graph(graph, inputs):
- m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", graph)
- return m(*inputs)
-
- self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs))
+ ge = self.checkScript(scaleshift, inputs)
+ self.assertGraphContainsExactly(
+ ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
# NB: torch.jit.script, when used as a function, uses the current scope