diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2019-01-08 07:20:22 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-08 07:22:55 -0800 |
commit | 29a9d6af45769afd304005be8af03f36f74a14b2 (patch) | |
tree | 0fa5faf726b357545fd77a1746d1c39b6f9d425a /test | |
parent | 5e1b35bf2827d24d626739d14f462f4c87875892 (diff) | |
download | pytorch-29a9d6af45769afd304005be8af03f36f74a14b2.tar.gz pytorch-29a9d6af45769afd304005be8af03f36f74a14b2.tar.bz2 pytorch-29a9d6af45769afd304005be8af03f36f74a14b2.zip |
Stop leaving garbage files after running test_jit.py
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15404
Differential Revision: D13548316
Pulled By: zou3519
fbshipit-source-id: fe8731d8add59777781d34d9c3f3314f11467b23
Diffstat (limited to 'test')
-rw-r--r-- | test/test_jit.py | 39 |
1 files changed, 25 insertions, 14 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index f43d0a4877..1374f1e1f9 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -77,6 +77,25 @@ PY35 = sys.version_info >= (3, 5) WINDOWS = sys.platform == 'win32' +if WINDOWS: + @contextmanager + def TemporaryFileName(): + # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile + # opens the file, and it cannot be opened multiple times in Windows. To support Windows, + # close the file after creation and try to remove it manually + f = tempfile.NamedTemporaryFile(delete=False) + try: + f.close() + yield f.name + finally: + os.unlink(f.name) +else: + @contextmanager + def TemporaryFileName(): + with tempfile.NamedTemporaryFile() as f: + yield f.name + + def LSTMCellF(input, hx, cx, *params): return LSTMCell(input, (hx, cx), *params) @@ -282,18 +301,9 @@ class JitTestCase(TestCase): if not also_test_file: return imported - # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile - # opens the file, and it cannot be opened multiple times in Windows. To support Windows, - # close the file after creation and try to remove it manually - f = tempfile.NamedTemporaryFile(delete=False) - try: - f.close() - imported.save(f.name) - result = torch.jit.load(f.name, map_location=map_location) - finally: - os.unlink(f.name) - - return result + with TemporaryFileName() as fname: + imported.save(fname) + return torch.jit.load(fname, map_location=map_location) def assertGraphContains(self, graph, kind): self.assertTrue(any(n.kind() == kind for n in graph.nodes())) @@ -554,8 +564,9 @@ class TestJit(JitTestCase): self.assertFalse(m2.b0.is_cuda) def test_model_save_error(self): - with self.assertRaisesRegex(pickle.PickleError, "not supported"): - torch.save(FooToPickle(), "will_fail") + with TemporaryFileName() as fname: + with self.assertRaisesRegex(pickle.PickleError, "not supported"): + torch.save(FooToPickle(), fname) def test_single_tuple_trace(self): x = torch.tensor(2.) |