author     Lu Fang <lufang@fb.com>  2019-01-23 21:32:57 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-01-23 21:35:35 -0800
commit     8ab4d348f464fe9bf28bf6cbf518f97a65efae1b (patch)
tree       483429544df8e49f88e0f9023bad9e6c656aaa8b /test
parent     3cba115abb750b74ecc246f76e0592ca3ca751a9 (diff)
Fix the tensor deserialization problem of jit script module on CUDA (#16279)
Summary: We now create a temporary tensor for the whole record during deserialization, so views that share a storage are restored correctly when mapped to CUDA. Fixes https://github.com/pytorch/pytorch/issues/15271

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16279
Reviewed By: BIT-silence
Differential Revision: D13791442
Pulled By: houseroad
fbshipit-source-id: 6f52ca09627fb684f74121357cc42e4adadec36a
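For context, the scenario this change addresses can be reproduced with the public serialization API. The sketch below is illustrative only and is not part of this commit; it assumes a CUDA-capable build, and torch.jit.save/load stand in for the JitTestCase helper getExportImportCopy used in the test added below.

# Minimal sketch (assumption: CUDA device 0 is available).
import io
import torch
import torch.nn as nn

# Two views into one CPU storage: a parameter and a buffer of a ScriptModule.
whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
m = torch.jit.ScriptModule()
m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))      # view of row 0
m.register_buffer('b0', whole_tensor.narrow(0, 3, 1))  # view of row 3

# Round-trip through serialization, mapping storages onto CUDA on load.
f = io.BytesIO()
torch.jit.save(m, f)
f.seek(0)
m2 = torch.jit.load(f, map_location=torch.device('cuda:0'))

# With the fix, both views land on the GPU and still share a single storage.
assert m2.p0.is_cuda and m2.b0.is_cuda
assert m2.p0.storage().data_ptr() == m2.b0.storage().data_ptr()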
Diffstat (limited to 'test')
-rw-r--r--  test/test_jit.py  15
1 file changed, 15 insertions(+), 0 deletions(-)
diff --git a/test/test_jit.py b/test/test_jit.py
index 6e0ad79216..cd95236ff5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -648,6 +648,21 @@ class TestJit(JitTestCase):
         self.assertEqual(origin_result, m3(input.cpu()))
         self.assertEqual(origin_result, m4(input.cuda(0)))
 
+    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
+    def test_restore_shared_storage_on_cuda(self):
+        whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
+        m = torch.jit.ScriptModule()
+        m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
+        m.register_buffer('b0', whole_tensor.narrow(0, 3, 1))
+        m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
+        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
+        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
+        self.assertTrue(m2.p0.is_cuda)
+        self.assertTrue(m2.b0.is_cuda)
+        self.assertTrue(m2.p0.is_shared())
+        self.assertTrue(m2.b0.is_shared())
+        self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
+
     def test_typeas_trace_check(self):
         a = torch.tensor([0.4], requires_grad=True)
         b = torch.tensor([0.7], requires_grad=True)