author     Michael Carilli <mcarilli@nvidia.com>  2019-02-10 19:31:23 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-10 19:37:53 -0800
commit     0742874643dc91fe7231e55f1ba9839ef0188a48 (patch)
tree       1246a97e647935ad1257dfebfb5ec0a744018947 /test
parent     73d7ecd18398cf121a9a2ff2faad302d429bef3e (diff)
download   pytorch-0742874643dc91fe7231e55f1ba9839ef0188a48.tar.gz
           pytorch-0742874643dc91fe7231e55f1ba9839ef0188a48.tar.bz2
           pytorch-0742874643dc91fe7231e55f1ba9839ef0188a48.zip
Allow dataloader to accept a custom memory pinning function (#16743)
Summary:
Renewed attempt at https://github.com/pytorch/pytorch/pull/14171
From the original PR:
> Currently, the `pin_memory_batch` function in the dataloader will return a batch composed of any unrecognized type without pinning the data, because it doesn't know how.
>
> This behavior was preventing us from overlapping data prefetching in Mask-RCNN, whose custom `collate_fn` returns a custom batch type.
The old PR allowed the user to implement batch pinning for custom batch and data types by passing a custom pin function to the dataloader. slayton58 suggested a cleaner approach: allow the user to define a `pin_memory` method on their custom types, and have `pin_memory_batch` [check for the presence of that method](https://github.com/pytorch/pytorch/pull/16743/files#diff-9f154cbd884fe654066b1621fad654f3R56) in the incoming batch as a fallback. I've updated the test and docstrings accordingly.
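For illustration, here is a minimal sketch of that fallback, assuming a simplified `pin_memory_batch` (the real function in `torch.utils.data` handles more container and string types; treat this as an outline of the dispatch, not the exact implementation):

```python
import torch


def pin_memory_batch(batch):
    """Sketch of the dispatch described above (simplified assumption)."""
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    elif isinstance(batch, (list, tuple)):
        return [pin_memory_batch(sample) for sample in batch]
    elif isinstance(batch, dict):
        return {k: pin_memory_batch(v) for k, v in batch.items()}
    elif hasattr(batch, 'pin_memory'):
        # New fallback: custom batch types opt in by defining pin_memory.
        return batch.pin_memory()
    else:
        # Unrecognized types without a pin_memory method still pass
        # through unpinned, as before this change.
        return batch
```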
The old PR was merged but then reverted due to odd CUDA OOM errors on Windows that may or may not have been related. I have no idea why my changes would cause such errors (then or now), but it's something to keep an eye out for.
fmassa and yf225 were my POCs on the old PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16743
Differential Revision: D13991745
Pulled By: ezyang
fbshipit-source-id: 74e71f62a03be453b4caa9f5524e9bc53467fa17
Diffstat (limited to 'test')
-rw-r--r--  test/test_dataloader.py  41
1 file changed, 41 insertions, 0 deletions
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 1a76459f58..ad7cfdc609 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -956,6 +956,47 @@ class TestDictDataLoader(TestCase):
         self.assertTrue(sample['another_dict']['a_number'].is_pinned())


+class SimpleCustomBatch:
+    def __init__(self, data):
+        transposed_data = list(zip(*data))
+        self.inp = torch.stack(transposed_data[0], 0)
+        self.tgt = torch.stack(transposed_data[1], 0)
+
+    def pin_memory(self):
+        self.inp = self.inp.pin_memory()
+        self.tgt = self.tgt.pin_memory()
+        return self
+
+
+def collate_wrapper(batch):
+    return SimpleCustomBatch(batch)
+
+
+class TestCustomPinFn(TestCase):
+    def setUp(self):
+        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+        self.dataset = TensorDataset(inps, tgts)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_custom_batch_pin(self):
+        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
+                            pin_memory=True)
+        for batch_ndx, sample in enumerate(loader):
+            self.assertTrue(sample.inp.is_pinned())
+            self.assertTrue(sample.tgt.is_pinned())
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_custom_batch_pin_worker(self):
+        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
+                            pin_memory=True, num_workers=1)
+        for batch_ndx, sample in enumerate(loader):
+            self.assertTrue(sample.inp.is_pinned())
+            self.assertTrue(sample.tgt.is_pinned())
+
+
 class TestWorkerQueueDataset(Dataset):
     def __init__(self, data):
         self.data = data
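For context, a hedged end-to-end sketch of how a user might exploit this feature (not part of the commit; names and shapes are illustrative, and it needs a CUDA device). A custom batch pinned this way is what enables genuinely asynchronous host-to-device copies:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class SimpleCustomBatch:
    """Same pattern as the test above: a custom batch type that opts
    into pinning by defining a pin_memory method."""
    def __init__(self, data):
        transposed = list(zip(*data))
        self.inp = torch.stack(transposed[0], 0)
        self.tgt = torch.stack(transposed[1], 0)

    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self


dataset = TensorDataset(torch.randn(16, 5), torch.randn(16, 5))
loader = DataLoader(dataset, batch_size=4,
                    collate_fn=lambda batch: SimpleCustomBatch(batch),
                    pin_memory=True)

for batch in loader:
    # Pinned memory is what makes non_blocking=True an actual async copy,
    # letting the next batch's transfer overlap with compute.
    inp = batch.inp.cuda(non_blocking=True)
    tgt = batch.tgt.cuda(non_blocking=True)
```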