author     Michael Carilli <mcarilli@nvidia.com>  2018-11-23 08:08:35 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-11-23 08:12:43 -0800
commit     7557a993ab9b38583d5fdbebd25d927f31ffb7a9 (patch)
tree       19c9eb7ccd3598ba5dd71a926cbc66c85fb54c35 /torch/utils
parent     c36156eded4da10673496b707f4f92c8e53d1358 (diff)
Allow dataloader to accept a custom memory pinning function (#14171)
Summary:
Currently, the `pin_memory_batch` function in the dataloader returns any batch of an unrecognized type without pinning its data, because it does not know how to.
This behavior was preventing us from overlapping data prefetching in Mask-RCNN, whose custom `collate_fn` returns a custom batch type.
This PR adds the ability for the user to pass a `pin_fn` alongside any custom `collate_fn` to handle such custom types.
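For illustration, here is a minimal sketch of what a matching pair might look like. `SimpleBatch` and `pin_simple_batch` are hypothetical names invented for this example, not part of the PR:

```python
import torch

# Hypothetical custom batch type, standing in for whatever a custom
# collate_fn might return (e.g. the Mask-RCNN batch mentioned above).
class SimpleBatch(object):
    def __init__(self, images, targets):
        self.images = images
        self.targets = targets

def pin_simple_batch(batch):
    # A pin_fn receives the batch and must return it with its memory pinned.
    # Tensor.pin_memory() returns a copy of the tensor in pinned memory.
    batch.images = batch.images.pin_memory()
    batch.targets = batch.targets.pin_memory()
    return batch
```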
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14171
Differential Revision: D13166669
Pulled By: soumith
fbshipit-source-id: ca965f9841d4a259b3ca4413c8bd0d8743d433ab
Diffstat (limited to 'torch/utils')
-rw-r--r--   torch/utils/data/dataloader.py   20
1 file changed, 14 insertions, 6 deletions
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index c1ee0eb92e..012efd30bf 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -148,7 +148,7 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
         pass
 
 
-def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event, pin_fn):
     torch.cuda.set_device(device_id)
 
     # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
@@ -175,7 +175,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
         else:
             idx, batch = r
             try:
-                batch = pin_memory_batch(batch)
+                batch = pin_fn(batch)
             except Exception:
                 out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
             else:
@@ -521,6 +521,7 @@ class _DataLoaderIter(object):
         self.batch_sampler = loader.batch_sampler
         self.num_workers = loader.num_workers
         self.pin_memory = loader.pin_memory and torch.cuda.is_available()
+        self.pin_fn = loader.pin_fn
         self.timeout = loader.timeout
 
         self.sample_iter = iter(self.batch_sampler)
@@ -566,7 +567,7 @@ class _DataLoaderIter(object):
             pin_memory_thread = threading.Thread(
                 target=_pin_memory_loop,
                 args=(self.worker_result_queue, self.data_queue,
-                      torch.cuda.current_device(), self.done_event))
+                      torch.cuda.current_device(), self.done_event, self.pin_fn))
             pin_memory_thread.daemon = True
             pin_memory_thread.start()
             # Similar to workers (see comment above), we only register
@@ -614,7 +615,7 @@ class _DataLoaderIter(object):
             indices = next(self.sample_iter)  # may raise StopIteration
             batch = self.collate_fn([self.dataset[i] for i in indices])
             if self.pin_memory:
-                batch = pin_memory_batch(batch)
+                batch = self.pin_fn(batch)
             return batch
 
         # check if the next sample has already been generated
@@ -739,6 +740,12 @@ class DataLoader(object):
         collate_fn (callable, optional): merges a list of samples to form a mini-batch.
         pin_memory (bool, optional): If ``True``, the data loader will copy tensors
             into CUDA pinned memory before returning them.
+        pin_fn (callable, optional): If the default pinning logic sees a batch that is a custom class
+            (or whose elements are a custom class) that it does not recognize, it will return that batch
+            (or those elements) without pinning them. pin_fn gives the user control of memory
+            pinning, telling the dataloader how to pin the memory for custom classes.
+            It should accept a batch, and return that batch with its memory pinned.
+            If pin_memory is False (or not supplied), pin_fn is ignored.
         drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
             if the dataset size is not divisible by the batch size. If ``False`` and
             the size of dataset is not divisible by the batch size, then the last batch
@@ -766,13 +773,14 @@ class DataLoader(object):
     __initialized = False
 
     def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
-                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
-                 timeout=0, worker_init_fn=None):
+                 num_workers=0, collate_fn=default_collate, pin_memory=False, pin_fn=pin_memory_batch,
+                 drop_last=False, timeout=0, worker_init_fn=None):
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.collate_fn = collate_fn
         self.pin_memory = pin_memory
+        self.pin_fn = pin_fn
         self.drop_last = drop_last
         self.timeout = timeout
         self.worker_init_fn = worker_init_fn
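With the new keyword argument, wiring the earlier sketch into a loader would look roughly like this. It continues the hypothetical `SimpleBatch`/`pin_simple_batch` example above; `my_collate_fn` and `dataset` are likewise placeholders, with `dataset` assumed to yield `(image, target)` tensor pairs. Note that `pin_fn` only takes effect when `pin_memory=True`:

```python
def my_collate_fn(samples):
    # Hypothetical collate_fn returning the custom batch type.
    images = torch.stack([s[0] for s in samples])
    targets = torch.stack([s[1] for s in samples])
    return SimpleBatch(images, targets)

loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=4,
                                     num_workers=2,
                                     collate_fn=my_collate_fn,
                                     pin_memory=True,
                                     pin_fn=pin_simple_batch)
```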