path: root/torch/utils
author     Michael Carilli <mcarilli@nvidia.com>  2018-11-23 08:08:35 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-11-23 08:12:43 -0800
commit     7557a993ab9b38583d5fdbebd25d927f31ffb7a9 (patch)
tree       19c9eb7ccd3598ba5dd71a926cbc66c85fb54c35 /torch/utils
parent     c36156eded4da10673496b707f4f92c8e53d1358 (diff)
Allow dataloader to accept a custom memory pinning function (#14171)
Allow dataloader to accept a custom memory pinning function (#14171)

Summary: Currently, the `pin_memory_batch` function in the dataloader will return a batch comprised of any unrecognized type without pinning the data, because it doesn't know how. This behavior was preventing us from overlapping data prefetching in Mask-RCNN, whose custom `collate_fn` returns a custom batch type. The present PR adds the ability for the user to pass a `pin_fn` alongside any custom `collate_fn` to handle such custom types.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14171
Differential Revision: D13166669
Pulled By: soumith
fbshipit-source-id: ca965f9841d4a259b3ca4413c8bd0d8743d433ab
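As a sketch of the pattern this PR enables (the CustomBatch class, its fields, and the function names below are hypothetical, not part of this patch): a custom collate_fn can return a batch type the default pinning logic does not recognize, and a matching pin_fn tells the loader how to pin it.

    import torch

    class CustomBatch(object):
        def __init__(self, samples):
            # samples is a list of (image, target) tensor pairs
            self.images = torch.stack([s[0] for s in samples])
            self.targets = torch.stack([s[1] for s in samples])

    def custom_collate_fn(samples):
        # Returns a type the default pin_memory_batch does not recognize
        return CustomBatch(samples)

    def custom_pin_fn(batch):
        # Pin each tensor the batch holds, then return the batch
        batch.images = batch.images.pin_memory()
        batch.targets = batch.targets.pin_memory()
        return batch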
Diffstat (limited to 'torch/utils')
-rw-r--r--  torch/utils/data/dataloader.py | 20
1 file changed, 14 insertions(+), 6 deletions(-)
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index c1ee0eb92e..012efd30bf 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -148,7 +148,7 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
pass
-def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event, pin_fn):
torch.cuda.set_device(device_id)
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
@@ -175,7 +175,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
else:
idx, batch = r
try:
- batch = pin_memory_batch(batch)
+ batch = pin_fn(batch)
except Exception:
out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
@@ -521,6 +521,7 @@ class _DataLoaderIter(object):
self.batch_sampler = loader.batch_sampler
self.num_workers = loader.num_workers
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
+ self.pin_fn = loader.pin_fn
self.timeout = loader.timeout
self.sample_iter = iter(self.batch_sampler)
@@ -566,7 +567,7 @@ class _DataLoaderIter(object):
pin_memory_thread = threading.Thread(
target=_pin_memory_loop,
args=(self.worker_result_queue, self.data_queue,
- torch.cuda.current_device(), self.done_event))
+ torch.cuda.current_device(), self.done_event, self.pin_fn))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
@@ -614,7 +615,7 @@ class _DataLoaderIter(object):
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
- batch = pin_memory_batch(batch)
+ batch = self.pin_fn(batch)
return batch
# check if the next sample has already been generated
@@ -739,6 +740,12 @@ class DataLoader(object):
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
into CUDA pinned memory before returning them.
+ pin_fn (callable, optional): If the default pinning logic sees a batch that is a custom class
+ (or whose elements are a custom class) that it does not recognize, it will return that batch
+ (or those elements) without pinning them. pin_fn gives control of memory pinning
+ to the user, telling the dataloader how to pin the memory for custom classes.
+ It should accept a batch, and return that batch with its memory pinned.
+ If pin_memory is False (or not supplied), pin_fn is ignored.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
@@ -766,13 +773,14 @@ class DataLoader(object):
__initialized = False
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
- num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None):
+ num_workers=0, collate_fn=default_collate, pin_memory=False, pin_fn=pin_memory_batch,
+ drop_last=False, timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
+ self.pin_fn = pin_fn
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
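
For completeness, a hedged sketch of how the new argument might be wired up (dataset, custom_collate_fn, and custom_pin_fn are assumed from the earlier sketch, not defined by this patch):

    loader = torch.utils.data.DataLoader(
        dataset,                      # hypothetical dataset of (image, target) pairs
        batch_size=32,
        collate_fn=custom_collate_fn,
        pin_memory=True,              # pin_fn is ignored when pin_memory is False
        pin_fn=custom_pin_fn)

    for batch in loader:
        # Pinned host memory permits asynchronous host-to-device copies,
        # which is what enables the prefetch overlap described in the summary.
        images = batch.images.cuda(non_blocking=True)
        targets = batch.targets.cuda(non_blocking=True)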