summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2019-01-08 07:20:22 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-08 07:22:55 -0800
commit29a9d6af45769afd304005be8af03f36f74a14b2 (patch)
tree0fa5faf726b357545fd77a1746d1c39b6f9d425a /test
parent5e1b35bf2827d24d626739d14f462f4c87875892 (diff)
downloadpytorch-29a9d6af45769afd304005be8af03f36f74a14b2.tar.gz
pytorch-29a9d6af45769afd304005be8af03f36f74a14b2.tar.bz2
pytorch-29a9d6af45769afd304005be8af03f36f74a14b2.zip
Stop leaving garbage files after running test_jit.py
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15404 Differential Revision: D13548316 Pulled By: zou3519 fbshipit-source-id: fe8731d8add59777781d34d9c3f3314f11467b23
Diffstat (limited to 'test')
-rw-r--r--test/test_jit.py39
1 files changed, 25 insertions, 14 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index f43d0a4877..1374f1e1f9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -77,6 +77,25 @@ PY35 = sys.version_info >= (3, 5)
WINDOWS = sys.platform == 'win32'
+if WINDOWS:
+ @contextmanager
+ def TemporaryFileName():
+ # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+ # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+ # close the file after creation and try to remove it manually
+ f = tempfile.NamedTemporaryFile(delete=False)
+ try:
+ f.close()
+ yield f.name
+ finally:
+ os.unlink(f.name)
+else:
+ @contextmanager
+ def TemporaryFileName():
+ with tempfile.NamedTemporaryFile() as f:
+ yield f.name
+
+
def LSTMCellF(input, hx, cx, *params):
return LSTMCell(input, (hx, cx), *params)
@@ -282,18 +301,9 @@ class JitTestCase(TestCase):
if not also_test_file:
return imported
- # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
- # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
- # close the file after creation and try to remove it manually
- f = tempfile.NamedTemporaryFile(delete=False)
- try:
- f.close()
- imported.save(f.name)
- result = torch.jit.load(f.name, map_location=map_location)
- finally:
- os.unlink(f.name)
-
- return result
+ with TemporaryFileName() as fname:
+ imported.save(fname)
+ return torch.jit.load(fname, map_location=map_location)
def assertGraphContains(self, graph, kind):
self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
@@ -554,8 +564,9 @@ class TestJit(JitTestCase):
self.assertFalse(m2.b0.is_cuda)
def test_model_save_error(self):
- with self.assertRaisesRegex(pickle.PickleError, "not supported"):
- torch.save(FooToPickle(), "will_fail")
+ with TemporaryFileName() as fname:
+ with self.assertRaisesRegex(pickle.PickleError, "not supported"):
+ torch.save(FooToPickle(), fname)
def test_single_tuple_trace(self):
x = torch.tensor(2.)