author    Vitaly Fedyunin <vitalyf@fb.com>  2019-04-02 08:44:27 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-02 08:48:19 -0700
commit    c484cf43a02863efd2f4a76aad43246fb0191ab5
tree      5afc7b9dbf8c325300e3c49248a3b189ee2e43c0 /test
parent    aed7c9bc96fe35fce6508e19dab1e2cfbc766968
Adding pin_memory kwarg to zeros, ones, empty, ... tensor constructors. (#18455)
Summary:
Make it possible to construct a pinned-memory tensor without creating a storage first and without calling the `pin_memory()` function. It is also faster, as the copy operation is unnecessary.
Supported functions:
```python
torch.rand_like(t, pin_memory=True)
torch.randn_like(t, pin_memory=True)
torch.empty_like(t, pin_memory=True)
torch.full_like(t, 4, pin_memory=True)
torch.zeros_like(t, pin_memory=True)
torch.ones_like(t, pin_memory=True)
torch.tensor([10, 11], pin_memory=True)
torch.randn(3, 5, pin_memory=True)
torch.rand(3, pin_memory=True)
torch.zeros(3, pin_memory=True)
torch.randperm(3, pin_memory=True)
torch.empty(6, pin_memory=True)
torch.ones(6, pin_memory=True)
torch.eye(6, pin_memory=True)
torch.arange(3, 5, pin_memory=True)
```
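For context, here is a minimal sketch (not part of the PR) comparing the new keyword argument with the pre-existing allocate-then-pin pattern. The shapes and the device transfer are illustrative assumptions, and pinning host memory requires a CUDA-capable build in either case:

```python
import torch

# Old pattern: allocate pageable memory, then copy it into a pinned buffer.
x = torch.empty(1024, 1024).pin_memory()

# New pattern (this PR): allocate directly in pinned memory, no extra copy.
y = torch.empty(1024, 1024, pin_memory=True)

assert x.is_pinned() and y.is_pinned()

# Pinned host memory allows asynchronous host-to-device copies.
if torch.cuda.is_available():
    y_gpu = y.to('cuda', non_blocking=True)
```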
Part of the bigger `Remove Storage` plan.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18455
Reviewed By: ezyang
Differential Revision: D14672084
Pulled By: VitalyFedyunin
fbshipit-source-id: 9d0997ec00f59500ee018f8b851934d334012124
Diffstat (limited to 'test')
 test/test_torch.py | 34
 1 file changed, 34 insertions, 0 deletions
```diff
diff --git a/test/test_torch.py b/test/test_torch.py
index 855e8f9ba9..5de9045543 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -9815,6 +9815,40 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
         self.assertEqual(pinned, x)
         self.assertNotEqual(pinned.data_ptr(), x.data_ptr())

+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_pin_memory_from_constructor(self):
+
+        def _get_like(t, **kwargs):
+            return [
+                torch.rand_like(t, **kwargs),
+                torch.randn_like(t, **kwargs),
+                torch.empty_like(t, **kwargs),
+                torch.full_like(t, 4, **kwargs),
+                torch.zeros_like(t, **kwargs),
+                torch.ones_like(t, **kwargs),
+            ]
+
+        def _get_tensors(**kwargs):
+            return [
+                torch.tensor([10,11], **kwargs),
+                torch.randn(3, 5, **kwargs),
+                torch.rand(3, **kwargs),
+                # torch.randint(3,5, **kwargs), // unsupported
+                torch.zeros(3, **kwargs),
+                torch.randperm(3, **kwargs),
+                torch.empty(6, **kwargs),
+                torch.ones(6, **kwargs),
+                torch.eye(6, **kwargs),
+                torch.arange(3, 5, **kwargs),]
+
+        pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True)
+        for x in pinned_tensors:
+            self.assertTrue(x.is_pinned())
+
+        tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True))
+        for x in tensors:
+            self.assertFalse(x.is_pinned())
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_numpy_unresizable(self):
         x = np.zeros((2, 2))
```
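To see the summary's "no extra copy" claim in practice, a rough timing sketch along these lines can be used. The buffer size, iteration count, and use of `time.perf_counter` are illustrative assumptions, not part of the change, and it again assumes a CUDA-capable machine:

```python
import time
import torch

def bench(make_tensor, iters=100):
    # Time repeated allocations of a pinned host buffer.
    start = time.perf_counter()
    for _ in range(iters):
        make_tensor()
    return time.perf_counter() - start

if torch.cuda.is_available():
    n = 1 << 20  # about one million floats
    copied = bench(lambda: torch.empty(n).pin_memory())      # allocate, then copy into pinned memory
    direct = bench(lambda: torch.empty(n, pin_memory=True))  # allocate directly in pinned memory
    print('pin_memory() copy: {:.3f}s, pin_memory=True: {:.3f}s'.format(copied, direct))
```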