diff options
author | Lu Fang <lufang@fb.com> | 2019-01-23 21:32:57 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-23 21:35:35 -0800 |
commit | 8ab4d348f464fe9bf28bf6cbf518f97a65efae1b (patch) | |
tree | 483429544df8e49f88e0f9023bad9e6c656aaa8b /test | |
parent | 3cba115abb750b74ecc246f76e0592ca3ca751a9 (diff) | |
download | pytorch-8ab4d348f464fe9bf28bf6cbf518f97a65efae1b.tar.gz pytorch-8ab4d348f464fe9bf28bf6cbf518f97a65efae1b.tar.bz2 pytorch-8ab4d348f464fe9bf28bf6cbf518f97a65efae1b.zip |
Fix the tensor deserialization problem of jit script module on CUDA (#16279)
Summary:
Now we create a temporary tensor for the whole record.
Fix https://github.com/pytorch/pytorch/issues/15271
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16279
Reviewed By: BIT-silence
Differential Revision: D13791442
Pulled By: houseroad
fbshipit-source-id: 6f52ca09627fb684f74121357cc42e4adadec36a
Diffstat (limited to 'test')
-rw-r--r-- | test/test_jit.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index 6e0ad79216..cd95236ff5 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -648,6 +648,21 @@ class TestJit(JitTestCase): self.assertEqual(origin_result, m3(input.cpu())) self.assertEqual(origin_result, m4(input.cuda(0))) + @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") + def test_restore_shared_storage_on_cuda(self): + whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu') + m = torch.jit.ScriptModule() + m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1)) + m.register_buffer('b0', whole_tensor.narrow(0, 3, 1)) + m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0')) + self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) + self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) + self.assertTrue(m2.p0.is_cuda) + self.assertTrue(m2.b0.is_cuda) + self.assertTrue(m2.p0.is_shared()) + self.assertTrue(m2.b0.is_shared()) + self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr()) + def test_typeas_trace_check(self): a = torch.tensor([0.4], requires_grad=True) b = torch.tensor([0.7], requires_grad=True) |