-rw-r--r--  test/test_dataloader.py                | 41
-rw-r--r--  torch/utils/data/_utils/pin_memory.py  |  2
-rw-r--r--  torch/utils/data/dataloader.py         | 40
3 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 1a76459f58..ad7cfdc609 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -956,6 +956,47 @@ class TestDictDataLoader(TestCase):
             self.assertTrue(sample['another_dict']['a_number'].is_pinned())
 
 
+class SimpleCustomBatch:
+    def __init__(self, data):
+        transposed_data = list(zip(*data))
+        self.inp = torch.stack(transposed_data[0], 0)
+        self.tgt = torch.stack(transposed_data[1], 0)
+
+    def pin_memory(self):
+        self.inp = self.inp.pin_memory()
+        self.tgt = self.tgt.pin_memory()
+        return self
+
+
+def collate_wrapper(batch):
+    return SimpleCustomBatch(batch)
+
+
+class TestCustomPinFn(TestCase):
+    def setUp(self):
+        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+        self.dataset = TensorDataset(inps, tgts)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_custom_batch_pin(self):
+        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
+                            pin_memory=True)
+        for batch_ndx, sample in enumerate(loader):
+            self.assertTrue(sample.inp.is_pinned())
+            self.assertTrue(sample.tgt.is_pinned())
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @skipIfRocm
+    def test_custom_batch_pin_worker(self):
+        loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
+                            pin_memory=True, num_workers=1)
+        for batch_ndx, sample in enumerate(loader):
+            self.assertTrue(sample.inp.is_pinned())
+            self.assertTrue(sample.tgt.is_pinned())
+
+
 class TestWorkerQueueDataset(Dataset):
     def __init__(self, data):
         self.data = data
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
index 8403b423ac..07022d1cd0 100644
--- a/torch/utils/data/_utils/pin_memory.py
+++ b/torch/utils/data/_utils/pin_memory.py
@@ -53,5 +53,7 @@ def pin_memory_batch(batch):
         return {k: pin_memory_batch(sample) for k, sample in batch.items()}
     elif isinstance(batch, container_abcs.Sequence):
         return [pin_memory_batch(sample) for sample in batch]
+    elif hasattr(batch, "pin_memory"):
+        return batch.pin_memory()
     else:
         return batch
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 013b35a403..9d79990364 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -42,7 +42,9 @@ class DataLoader(object):
             (default: ``0``)
         collate_fn (callable, optional): merges a list of samples to form a mini-batch.
         pin_memory (bool, optional): If ``True``, the data loader will copy tensors
-            into CUDA pinned memory before returning them.
+            into CUDA pinned memory before returning them. If your data elements
+            are a custom type, or your ``collate_fn`` returns a batch that is a custom type,
+            see the example below.
         drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
             if the dataset size is not divisible by the batch size. If ``False`` and
             the size of dataset is not divisible by the batch size, then the last batch
@@ -65,6 +67,42 @@ class DataLoader(object):
 
     .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn`
                  cannot be an unpicklable object, e.g., a lambda function.
+
+    The default memory pinning logic only recognizes Tensors and maps and iterables
+    containing Tensors. By default, if the pinning logic sees a batch that is a custom type
+    (which will occur if you have a ``collate_fn`` that returns a custom batch type),
+    or if each element of your batch is a custom type, the pinning logic will not
+    recognize them, and it will return that batch (or those elements)
+    without pinning the memory. To enable memory pinning for custom batch or data types,
+    define a ``pin_memory`` method on your custom type(s).
+
+    Example::
+
+        class SimpleCustomBatch:
+            def __init__(self, data):
+                transposed_data = list(zip(*data))
+                self.inp = torch.stack(transposed_data[0], 0)
+                self.tgt = torch.stack(transposed_data[1], 0)
+
+            def pin_memory(self):
+                self.inp = self.inp.pin_memory()
+                self.tgt = self.tgt.pin_memory()
+                return self
+
+        def collate_wrapper(batch):
+            return SimpleCustomBatch(batch)
+
+        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+        dataset = TensorDataset(inps, tgts)
+
+        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
+                            pin_memory=True)
+
+        for batch_ndx, sample in enumerate(loader):
+            print(sample.inp.is_pinned())
+            print(sample.tgt.is_pinned())
+
     """

     __initialized = False
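The two added lines in pin_memory_batch are the core of the patch: after the existing Tensor, mapping, and sequence cases, any object that exposes a pin_memory method is pinned by calling that method. Below is a minimal sketch (not part of the patch) that exercises pin_memory_batch directly, outside a DataLoader; it assumes a CUDA-capable machine and imports the private helper from the module path shown in this diff, which is not a public API.

    # Sketch: exercising the new ``hasattr(batch, "pin_memory")`` branch directly.
    # Assumes CUDA is available; pin_memory_batch is a private helper whose
    # location is taken from torch/utils/data/_utils/pin_memory.py in this patch.
    import torch
    from torch.utils.data._utils.pin_memory import pin_memory_batch


    class SimpleCustomBatch:
        def __init__(self, data):
            transposed_data = list(zip(*data))
            self.inp = torch.stack(transposed_data[0], 0)
            self.tgt = torch.stack(transposed_data[1], 0)

        def pin_memory(self):
            self.inp = self.inp.pin_memory()
            self.tgt = self.tgt.pin_memory()
            return self


    samples = [(torch.randn(5), torch.randn(5)) for _ in range(4)]

    # A bare custom object is pinned via its own pin_memory() method.
    batch = pin_memory_batch(SimpleCustomBatch(samples))
    print(batch.inp.is_pinned(), batch.tgt.is_pinned())  # True True

    # Because pin_memory_batch recurses into maps and sequences before the new
    # branch runs, custom objects nested inside containers are pinned as well.
    nested = pin_memory_batch({"wrapped": [SimpleCustomBatch(samples)]})
    print(nested["wrapped"][0].inp.is_pinned())  # True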