-rw-r--r--  test/test_dataloader.py                 41
-rw-r--r--  torch/utils/data/_utils/pin_memory.py    2
-rw-r--r--  torch/utils/data/dataloader.py           40
3 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 1a76459f58..ad7cfdc609 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -956,6 +956,47 @@ class TestDictDataLoader(TestCase):
self.assertTrue(sample['another_dict']['a_number'].is_pinned())
+class SimpleCustomBatch:
+ def __init__(self, data):
+ transposed_data = list(zip(*data))
+ self.inp = torch.stack(transposed_data[0], 0)
+ self.tgt = torch.stack(transposed_data[1], 0)
+
+ def pin_memory(self):
+ self.inp = self.inp.pin_memory()
+ self.tgt = self.tgt.pin_memory()
+ return self
+
+
+def collate_wrapper(batch):
+ return SimpleCustomBatch(batch)
+
+
+class TestCustomPinFn(TestCase):
+ def setUp(self):
+ inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ self.dataset = TensorDataset(inps, tgts)
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ @skipIfRocm
+ def test_custom_batch_pin(self):
+ loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
+ pin_memory=True)
+ for batch_ndx, sample in enumerate(loader):
+ self.assertTrue(sample.inp.is_pinned())
+ self.assertTrue(sample.tgt.is_pinned())
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ @skipIfRocm
+ def test_custom_batch_pin_worker(self):
+ loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_wrapper,
+ pin_memory=True, num_workers=1)
+ for batch_ndx, sample in enumerate(loader):
+ self.assertTrue(sample.inp.is_pinned())
+ self.assertTrue(sample.tgt.is_pinned())
+
+
class TestWorkerQueueDataset(Dataset):
def __init__(self, data):
self.data = data
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
index 8403b423ac..07022d1cd0 100644
--- a/torch/utils/data/_utils/pin_memory.py
+++ b/torch/utils/data/_utils/pin_memory.py
@@ -53,5 +53,7 @@ def pin_memory_batch(batch):
return {k: pin_memory_batch(sample) for k, sample in batch.items()}
elif isinstance(batch, container_abcs.Sequence):
return [pin_memory_batch(sample) for sample in batch]
+ elif hasattr(batch, "pin_memory"):
+ return batch.pin_memory()
else:
return batch
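
For context, a minimal sketch of the full ``pin_memory_batch`` dispatch after this patch. The Tensor and string branches and the ``torch._six`` import are assumptions reconstructed from the surrounding function (they sit outside this hunk); the new duck-typing fallback is checked only after the built-in container types, so Tensors nested in maps and sequences are still pinned by recursion first:

    import torch
    from torch._six import container_abcs, string_classes  # assumed module imports

    def pin_memory_batch(batch):
        # Tensors are pinned directly.
        if isinstance(batch, torch.Tensor):
            return batch.pin_memory()
        # Strings pass through untouched.
        elif isinstance(batch, string_classes):
            return batch
        # Maps and sequences are pinned recursively, element by element.
        elif isinstance(batch, container_abcs.Mapping):
            return {k: pin_memory_batch(sample) for k, sample in batch.items()}
        elif isinstance(batch, container_abcs.Sequence):
            return [pin_memory_batch(sample) for sample in batch]
        # New in this patch: any object exposing a pin_memory() method
        # is pinned via duck typing (e.g. a custom batch class).
        elif hasattr(batch, "pin_memory"):
            return batch.pin_memory()
        # Anything else is returned unchanged, unpinned.
        else:
            return batch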
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 013b35a403..9d79990364 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -42,7 +42,9 @@ class DataLoader(object):
(default: ``0``)
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
- into CUDA pinned memory before returning them.
+ into CUDA pinned memory before returning them. If your data elements
+ are a custom type, or your ``collate_fn`` returns a batch that is a
+ custom type, see the example below.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
@@ -65,6 +67,42 @@ class DataLoader(object):
.. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
unpicklable object, e.g., a lambda function.
+
+ The default memory pinning logic only recognizes Tensors and maps and
+ iterables containing Tensors. By default, if the pinning logic sees a batch
+ that is a custom type (which will occur if you have a ``collate_fn`` that
+ returns a custom batch type), or if each element of your batch is a custom
+ type, the pinning logic will not recognize them, and it will return that
+ batch (or those elements) without pinning the memory. To enable memory
+ pinning for custom batch or data types, define a ``pin_memory`` method
+ on your custom type(s).
+
+ Example::
+
+ class SimpleCustomBatch:
+ def __init__(self, data):
+ transposed_data = list(zip(*data))
+ self.inp = torch.stack(transposed_data[0], 0)
+ self.tgt = torch.stack(transposed_data[1], 0)
+
+ def pin_memory(self):
+ self.inp = self.inp.pin_memory()
+ self.tgt = self.tgt.pin_memory()
+ return self
+
+ def collate_wrapper(batch):
+ return SimpleCustomBatch(batch)
+
+ inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ dataset = TensorDataset(inps, tgts)
+
+ loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
+ pin_memory=True)
+
+ for batch_ndx, sample in enumerate(loader):
+ print(sample.inp.is_pinned())
+ print(sample.tgt.is_pinned())
+
"""
__initialized = False
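
Because the mapping and sequence branches recurse before the new duck-typing fallback, a custom batch nested inside a built-in container is pinned as well. A hedged sketch reusing ``SimpleCustomBatch`` and ``dataset`` from the docstring example above (``nested_collate_wrapper`` is a hypothetical name, not part of this patch, and the output assumes a machine with CUDA available):

    # Hypothetical wrapper, not part of this patch: nests the custom
    # batch inside a plain dict.
    def nested_collate_wrapper(batch):
        return {'custom': SimpleCustomBatch(batch)}

    loader = DataLoader(dataset, batch_size=2, collate_fn=nested_collate_wrapper,
                        pin_memory=True)

    for batch_ndx, sample in enumerate(loader):
        # The Mapping branch recurses into the dict, then the duck-typing
        # fallback calls pin_memory() on the nested SimpleCustomBatch.
        print(sample['custom'].inp.is_pinned())  # True with CUDA available
        print(sample['custom'].tgt.is_pinned())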