author    Tongzhou Wang <tongzhou.wang.1994@gmail.com>    2018-10-09 09:51:42 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2018-10-09 09:54:19 -0700
commit    11c31aef047776d0ffca2a3838f511bd0196a81c (patch)
tree      58166e64fbc1b449ae4e62aec324072cb6601386 /torch
parent    1a0d82e4f43ecce9d13c90d9219f0b419a4852c9 (diff)
Prevent hanging in data loader altogether
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11985
Differential Revision: D10202374
Pulled By: SsnL
fbshipit-source-id: 1ab1a07185f78a104f9b05930a87ef5a32f431e4
Diffstat (limited to 'torch')
-rw-r--r--    torch/csrc/DataLoader.cpp         |  21
-rw-r--r--    torch/utils/data/dataloader.py    | 374
2 files changed, 331 insertions, 64 deletions
diff --git a/torch/csrc/DataLoader.cpp b/torch/csrc/DataLoader.cpp
index f8c25e56dc..c5cdf64544 100644
--- a/torch/csrc/DataLoader.cpp
+++ b/torch/csrc/DataLoader.cpp
@@ -1,15 +1,14 @@
#include "DataLoader.h"
-// In cases like DataLoader, if a worker process die due to bus error/segfault
-// or just hang, the main process, if implemented with
-// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is
-// difficult to avoid on PyTorch side as it can be caused by limited shm, or
-// other libraries users call in the workers. The following methods is an effort
-// to do our best provide some error message to users when such unfortunate
-// events happen.
+// In cases like DataLoader, if a worker process dies due to a bus error/segfault
+// or just hangs, the main process will hang waiting for data. This is difficult
+// to avoid on the PyTorch side as it can be caused by limited shm, or other
+// libraries users call in the workers. The following methods are an effort to
+// provide some error message to users when such unfortunate events happen.
// TODO: The following don't work on Windows. Specifically, sigaction, waitid
-// calls ,and SIGCHLD handler. Currently, dummy implementations are provided
+// calls, and SIGCHLD handler. Currently, dummy implementations are provided
// for Windows.
#ifndef _WIN32
@@ -63,6 +62,7 @@ static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *,
SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. "
"This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n");
+SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n");
// When an error happened in DataLoader methods and Python starts to exit, the
// error trace will keep the loader alive, and Python may kill the children
@@ -92,6 +92,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *a
setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
+ setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@@ -130,9 +131,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
} else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal
std::ostringstream oss;
oss << "DataLoader worker (pid " << worker_pid << ") is killed "
- << "by signal: " << strsignal(infop.si_status) << ". "
- << "Details are lost due to multiprocessing. Rerunning with "
- << "num_workers=0 may give better error trace.";
+ << "by signal: " << strsignal(infop.si_status) << ". ";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
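As an aside (illustrative only, not part of the patch): on Linux, a hypothetical dataset whose worker dies from a fatal signal makes the main process raise a RuntimeError built from the message above. The class name and exact error wording below are assumptions.

    import os
    import signal
    from torch.utils.data import DataLoader, Dataset

    class SelfKillingDataset(Dataset):
        # Hypothetical dataset used only to make its worker die abruptly.
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            os.kill(os.getpid(), signal.SIGKILL)  # simulate a crashed worker
            return idx  # never reached

    if __name__ == '__main__':
        loader = DataLoader(SelfKillingDataset(), num_workers=1)
        # Iterating is expected to fail with something like:
        #   RuntimeError: DataLoader worker (pid 12345) is killed by signal: Killed.
        for _ in loader:
            pass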
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index bdf27ad897..e6b2744cee 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -26,10 +26,20 @@ else:
import queue
+# NOTE [ Python Traceback Reference Cycle Problem ]
+#
+# When using sys.exc_info(), it is important to **not** store the exc_info[2],
+# which is the traceback, because otherwise you will run into the traceback
+# reference cycle problem, i.e., the traceback holding reference to the frame,
+# and the frame (which holds references to all the objects in its temporary
+# scope) holding reference to the traceback.
+
+
class ExceptionWrapper(object):
r"""Wraps an exception plus traceback to communicate across threads"""
-
def __init__(self, exc_info):
+ # It is important that we don't store exc_info, see
+ # NOTE [ Python Traceback Reference Cycle Problem ]
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
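For illustration only (not part of the patch), a minimal sketch of the reference cycle that the note above warns about; `LeakyWrapper`, `SafeWrapper`, and `load_sample` are hypothetical names.

    import sys
    import traceback

    class LeakyWrapper(object):
        # Anti-pattern: storing exc_info[2] (the traceback) keeps the raising
        # frame -- and every object in its local scope -- reachable for as long
        # as this wrapper lives, and can create traceback <-> frame cycles.
        def __init__(self, exc_info):
            self.exc_type, self.exc_value, self.tb = exc_info

    class SafeWrapper(object):
        # What ExceptionWrapper does: keep only the type and a formatted string.
        def __init__(self, exc_info):
            self.exc_type = exc_info[0]
            self.exc_msg = "".join(traceback.format_exception(*exc_info))

    def load_sample():
        big_buffer = bytearray(100 * 1024 * 1024)  # lives in this frame's locals
        try:
            raise ValueError("boom")
        except ValueError:
            # LeakyWrapper(sys.exc_info()) would pin `big_buffer` via the traceback;
            # SafeWrapper lets the frame (and the buffer) be freed promptly.
            return SafeWrapper(sys.exc_info())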
@@ -37,7 +47,11 @@ class ExceptionWrapper(object):
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""
-MANAGER_STATUS_CHECK_INTERVAL = 5.0
+MP_STATUS_CHECK_INTERVAL = 5.0
+r"""Interval (in seconds) to check status of processes to avoid hanging in
+ multiprocessing data loading. This is mainly used in getting data from
+ another process, in which case we need to periodically check whether the
+ sender is alive to prevent hanging."""
if IS_WINDOWS:
# On Windows, the parent ID of the worker process remains unchanged when the manager process
@@ -60,19 +74,29 @@ if IS_WINDOWS:
if not self.manager_handle:
raise ctypes.WinError(ctypes.get_last_error())
+ self.manager_dead = False
+
def is_alive(self):
- # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
- return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0
+ if not self.manager_dead:
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
+ self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+ return not self.manager_dead
else:
class ManagerWatchdog(object):
def __init__(self):
self.manager_pid = os.getppid()
+ self.manager_dead = False
def is_alive(self):
- return os.getppid() == self.manager_pid
+ if not self.manager_dead:
+ self.manager_dead = os.getppid() != self.manager_pid
+ return not self.manager_dead
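The polling pattern the workers switch to below can be condensed into a small hypothetical helper, shown here for illustration only; `poll_index_queue` is not part of the patch, and 5.0 stands in for MP_STATUS_CHECK_INTERVAL.

    import queue  # imported as `Queue` on Python 2, as at the top of this module

    def poll_index_queue(index_queue, watchdog, interval=5.0):
        # Poll with a timeout so a dead manager process is noticed instead of
        # blocking forever on an empty queue.
        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=interval)
            except queue.Empty:
                continue
            if r is None:  # final signal from the main process
                return
            yield r        # hand the (idx, batch_indices) pair to the caller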
def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+ # logic of this function.
+
try:
global _use_shared_memory
_use_shared_memory = True
@@ -87,9 +111,6 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
random.seed(seed)
torch.manual_seed(seed)
- # Do not wait for putting thread to join when this worker exits.
- # Otherwise, this worker may always be waiting to put and doesn't check
- # index_queue and done_event for termination signal.
data_queue.cancel_join_thread()
if init_fn is not None:
@@ -97,22 +118,26 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
watchdog = ManagerWatchdog()
- while True:
+ while watchdog.is_alive():
try:
- r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
- if watchdog.is_alive() and not done_event.is_set():
- continue
- else:
- break
- # use done_event so that we can get faster exiting signal even if there
- # are still indices in index_queue
- if r is None or done_event.is_set():
- break
+ continue
+ if r is None:
+ # Received the final signal
+ assert done_event.is_set()
+ return
+ elif done_event.is_set():
+ # Done event is set. But I haven't received the final signal
+ # (None) yet. I will keep getting until I see it, and skip the
+ # processing steps.
+ continue
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
+ # It is important that we don't store exc_info in a variable,
+ # see NOTE [ Python Traceback Reference Cycle Problem ]
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
@@ -122,30 +147,38 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed,
pass
-def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory, device_id):
- if pin_memory:
- torch.cuda.set_device(device_id)
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+ torch.cuda.set_device(device_id)
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+ # logic of this function.
while True:
try:
- r = in_queue.get()
+ r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+ except queue.Empty:
+ continue
except Exception:
if done_event.is_set():
- return
+ # Weird things can happen when shutting down, e.g., fd being
+ # closed when tensors are shared via fds.
+ break
raise
- if r is None or done_event.is_set():
- break
- if isinstance(r[1], ExceptionWrapper):
- out_queue.put(r)
+ if r is None:
+ assert done_event.is_set()
+ return
+ elif done_event.is_set():
+ # Haven't seen the final signal yet. Keep getting until None.
continue
- idx, batch = r
- try:
- if pin_memory:
- batch = pin_memory_batch(batch)
- except Exception:
- out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+ elif isinstance(r[1], ExceptionWrapper):
+ out_queue.put(r)
else:
- out_queue.put((idx, batch))
+ idx, batch = r
+ try:
+ batch = pin_memory_batch(batch)
+ except Exception:
+ out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+ else:
+ out_queue.put((idx, batch))
numpy_type_map = {
'float64': torch.DoubleTensor,
@@ -230,6 +263,8 @@ def _set_SIGCHLD_handler():
return
previous_handler = signal.getsignal(signal.SIGCHLD)
if not callable(previous_handler):
+ # This doesn't catch the default handler, but the SIGCHLD default
+ # handler is a no-op.
previous_handler = None
def handler(signum, frame):
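Aside (not part of the patch): `signal.getsignal` returns `signal.SIG_DFL` when no Python-level handler has ever been installed, and that sentinel is not callable, so the check above simply drops it; this is safe because the default SIGCHLD disposition is to ignore the signal.

    import signal  # POSIX only; SIGCHLD does not exist on Windows

    prev = signal.getsignal(signal.SIGCHLD)
    print(prev, callable(prev))  # on a fresh Python 3 interpreter: Handlers.SIG_DFL False
    if not callable(prev):
        prev = None              # same normalization as _set_SIGCHLD_handler above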
@@ -246,6 +281,207 @@ def _set_SIGCHLD_handler():
class _DataLoaderIter(object):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
+ #
+ # Preliminary:
+ #
+ # Our data model looks like this (queues are indicated with curly brackets):
+ #
+ # main process ||
+ # | ||
+ # {index_queue} ||
+ # | ||
+ # worker processes || DATA
+ # | ||
+ # {worker_result_queue} || FLOW
+ # | ||
+ # pin_memory_thread of main process || DIRECTION
+ # | ||
+ # {data_queue} ||
+ # | ||
+ # data output \/
+ #
+ # P.S. The `worker_result_queue` and `pin_memory_thread` parts may be omitted if
+ # `pin_memory=False`.
+ #
+ #
+ # Terminating multiprocessing logic requires very careful design. In
+ # particular, we need to make sure that
+ #
+ # 1. The iterator gracefully exits the workers when its last reference is
+ # gone.
+ #
+ # In this case, the workers should be gracefully exited because the
+ # main process may still need to continue to run, and we want the
+ # cleanup code in the workers to be executed (e.g., releasing GPU memory).
+ # Naturally, we implement the shutdown logic in `__del__` of
+ # `_DataLoaderIter`.
+ #
+ # We delay the discussion on the logic in this case until later.
+ #
+ # 2. The iterator exits the workers when the loader process and/or worker
+ # processes exit unexpectedly (e.g., SIGKILL-ed).
+ #
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
+ #
+ # You may ask, why can't we make the workers non-daemonic, and
+ # gracefully exit using the same logic as we have in `__del__` when the
+ # iterator gets deleted (see 1 above)?
+ #
+ # When a process ends, it shuts all its daemonic children down with
+ # a SIGTERM (instead of joining them without a timeout). Similarly for
+ # threads, but by a different mechanism. This fact, together with a few
+ # implementation details of multiprocessing, forces us to make workers
+ # daemonic. All of our problems arise when a DataLoader is used in a
+ # subprocess, and are caused by multiprocessing code which looks more
+ # or less like this:
+ #
+ # try:
+ # your_function_using_a_dataloader()
+ # finally:
+ # multiprocessing.util._exit_function()
+ #
+ # The joining/termination mentioned above happens inside
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
+ # throws, the stack trace stored in the exception will prevent the
+ # frame which uses `DataLoaderIter` from being freed. If the frame has any
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
+ # its `__del__`, which starts the shutdown procedure, will not be
+ # called. That, in turn, means that workers aren't notified. Attempting
+ # to join in `_exit_function` will then result in a hang.
+ #
+ # For context, `_exit_function` is also registered as an `atexit` call.
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
+ # The code dates back to 2008 and there is no comment on the original
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
+ # the finally block and the `atexit` registration) that explains this.
+ #
+ # Another choice is to just shut down workers with the logic in 1 above
+ # whenever we see an error in `next`. This isn't ideal because
+ # a. It prevents users from using try-catch to resume data loading.
+ # b. It doesn't prevent hanging if users have references to the
+ # iterator.
+ #
+ # 3. All processes exit if any of them die unexpectedly (e.g., error,
+ # fatal signals).
+ #
+ # As shown above, the workers are set as daemonic children of the main
+ # process. However, automatic cleaning-up of such child processes only
+ # happens if the parent process exits gracefully (e.g., not via fatal
+ # signals like SIGKILL). So we must ensure that each process will exit
+ # even if the process that should send/receive data to/from it
+ # was killed, i.e.,
+ #
+ # a. A process won't hang when getting from a queue.
+ #
+ # Even with carefully designed data dependencies (i.e., a `put()`
+ # always corresponding to a `get()`), hanging on `get()` can still
+ # happen when data in the queue is corrupted (e.g., due to
+ # `cancel_join_thread` or unexpected exit).
+ #
+ # For child exit, we register a SIGCHLD handler on the main process;
+ # the (Python) handler checks whether any of the workers have failed.
+ # See DataLoader.cpp.
+ #
+ # For `.get()` calls where the sender(s) is not the workers, we
+ # guard them with timeouts, and check the status of the sender
+ # when a timeout happens:
+ # + in the workers, the `ManagerWatchdog` class checks the main
+ # process status.
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
+ # check `pin_memory_thread` status periodically until `.get()`
+ # returns or we see that `pin_memory_thread` has died.
+ #
+ # b. A process won't hang when putting into a queue;
+ #
+ # We use `mp.Queue`, which has a separate background thread that puts
+ # objects from an unbounded internal buffer. The background thread is
+ # daemonic and usually automatically joined when the process
+ # exits.
+ #
+ # However, if the receiver has ended abruptly while reading from
+ # the pipe, the join will hang forever. Therefore, for both
+ # `worker_result_queue` (worker -> main process/pin_memory_thread)
+ # and each `index_queue` (main process -> worker), we call
+ # `q.cancel_join_thread()` in the sender process before any `q.put` to
+ # prevent this automatic join.
+ #
+ # Moreover, having all queues call `cancel_join_thread` makes
+ # implementing graceful shutdown logic in `__del__` much easier.
+ # It won't need to get from any queue, which would also need to be
+ # guarded by periodic status checks.
+ #
+ # Note that this may leave corrupted data in the queue, but we
+ # don't care about the data anyway once we are shutting down.
+ #
+ #
+ # Now let's get back to 1:
+ # how we gracefully exit the workers when the last reference to the
+ # iterator is gone.
+ #
+ # To achieve this, we implement the following logic along with the design
+ # choices mentioned above:
+ #
+ # [worker processes]
+ # While loader process is alive:
+ # Get from index_queue.
+ # If got a `None`, exit.
+ # If got anything else,
+ # Check `done_event`.
+ # If set, continue to next iteration
+ # i.e., keep getting until we see the `None`, then exit.
+ # Otherwise, process data.
+ # If timed out,
+ # Whether or not `done_event` is set (we still need to see the `None`),
+ # continue to the next iteration.
+ #
+ # [pin_memory_thread]
+ # # No need to check main thread. If this thread is alive, the main loader
+ # # thread must be alive, because this thread is set as daemonic.
+ # While True:
+ # Get from worker_result_queue.
+ # If got a `None`, exit.
+ # If got anything else,
+ # Check `done_event`.
+ # If set, continue to next iteration
+ # i.e., keep getting until we see the `None`, then exit.
+ # Otherwise, process data.
+ #
+ # NOTE: we don't check the status of the main thread because
+ # 1. if the process is killed by a fatal signal, `pin_memory_thread`
+ # ends.
+ # 2. in other cases, either the cleaning-up in __del__ or the
+ # automatic exit of the daemonic thread will take care of it.
+ # This won't busy-wait either because `.get(timeout)` does not
+ # busy-wait.
+ #
+ # [main process]
+ # In the `_DataLoaderIter`'s `__del__` (see the sketch after this note):
+ # a. Set `done_event` (shared with `pin_memory_thread` and workers).
+ #
+ # Note: from here on, the workers & `pin_memory_thread` may exit at
+ # any time after they receive `None`.
+ #
+ # b. Exit `pin_memory_thread`
+ # i. Put `None` in `worker_result_queue`.
+ # ii. Join the `pin_memory_thread`.
+ #
+ # c. Exit the workers.
+ # i. Put `None` in each worker's `index_queue`.
+ # ii. Join the workers.
+ #
+ # NOTE: (c) has to be after (b) because exiting the workers may leave
+ # corrupted data in `worker_result_queue`, which `pin_memory_thread`
+ # reads from.
+ #
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
+ # can be omitted.
+ #
+ # NB: `done_event` isn't strictly needed. E.g., we can just check for
+ # `None` from `index_queue`, but it allows us to skip wasting resources
+ # processing indices already in `index_queue` if we are already shutting
+ # down.
+
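A heavily condensed sketch of the shutdown handshake described in the note above; this is an illustrative, hypothetical helper, not the actual `_shutdown_workers` method further down.

    def _shutdown(workers, index_queues, worker_result_queue, done_event,
                  pin_memory_thread=None):
        # (a) Tell workers and pin_memory_thread to skip any remaining work.
        done_event.set()

        # (b) Unblock and join pin_memory_thread first: exiting workers may
        #     leave corrupted data in worker_result_queue, which it reads from.
        if pin_memory_thread is not None:
            worker_result_queue.cancel_join_thread()  # don't wait on the feeder thread
            worker_result_queue.put(None)             # final signal
            pin_memory_thread.join()
            worker_result_queue.close()

        # (c) Then unblock and join the workers.
        for q in index_queues:
            q.put(None)  # each worker exits once it sees the None
            q.close()
        for w in workers:
            w.join()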
def __init__(self, loader):
self.dataset = loader.dataset
self.collate_fn = loader.collate_fn
@@ -274,18 +510,19 @@ class _DataLoaderIter(object):
self.workers = []
for i in range(self.num_workers):
index_queue = multiprocessing.Queue()
+ index_queue.cancel_join_thread()
w = multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, index_queue,
self.worker_result_queue, self.done_event,
self.collate_fn, base_seed + i,
self.worker_init_fn, i))
- w.daemon = True # ensure that the worker exits on process exit
- # Process.start() actually take some time as it needs to start a
- # process and pass the arguments over via a pipe. Therefore, we
- # only add a worker to self.workers list after it started, so
- # that we do not call .join() if program dies before it starts,
- # and __del__ tries to join it but will get:
+ w.daemon = True
+ # NB: Process.start() actually takes some time as it needs to
+ # start a process and pass the arguments over via a pipe.
+ # Therefore, we only add a worker to the self.workers list after
+ # it has started, so that we do not call .join() if the program dies
+ # before it starts, and __del__ tries to join it but will get:
# AssertionError: can only join a started process.
w.start()
self.index_queues.append(index_queue)
@@ -295,8 +532,8 @@ class _DataLoaderIter(object):
self.data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
target=_pin_memory_loop,
- args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
- torch.cuda.current_device()))
+ args=(self.worker_result_queue, self.data_queue,
+ torch.cuda.current_device(), self.done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
@@ -317,11 +554,25 @@ class _DataLoaderIter(object):
return len(self.batch_sampler)
def _get_batch(self):
+ # In the non-timeout case, worker exit is covered by the SIGCHLD handler.
+ # But if `pin_memory=True`, we still need to account for the possibility
+ # that `pin_memory_thread` dies.
if self.timeout > 0:
try:
return self.data_queue.get(timeout=self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
+ elif self.pin_memory:
+ while self.pin_memory_thread.is_alive():
+ try:
+ return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+ except queue.Empty:
+ continue
+ else:
+ # The while condition is false, i.e., pin_memory_thread died.
+ raise RuntimeError('Pin memory thread exited unexpectedly')
+ # In this case, `self.data_queue` is a `queue.Queue`. But we don't
+ # need to call `.task_done()` because we don't use `.join()`.
else:
return self.data_queue.get()
@@ -383,29 +634,46 @@ class _DataLoaderIter(object):
raise NotImplementedError("_DataLoaderIter cannot be pickled")
def _shutdown_workers(self):
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
+ # the logic of this function.
if not self.shutdown:
self.shutdown = True
- # removes pids from the C side data structure first so worker
+ # Removes pids from the C side data structure first so worker
# termination afterwards won't trigger false positive error report.
if self.worker_pids_set:
_remove_worker_pids(id(self))
self.worker_pids_set = False
+
self.done_event.set()
- if self.pin_memory:
- # Sending `None` to `pin_memory_thread` must be before
- # stopping worker processes because the workers may leave
- # corrupted data in `worker_result_queue`, causing
- # `pin_memory_thread` unable to read and terminate properly.
+
+ # Exit `pin_memory_thread` first because exiting workers may leave
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
+ # reads from.
+ if hasattr(self, 'pin_memory_thread'):
+ # Use hasattr in case an error happens before we set the attribute.
+ # This is the first time this process puts into `worker_result_queue`.
+
+ # Call `cancel_join_thread` in case `pin_memory_thread` has exited.
+ self.worker_result_queue.cancel_join_thread()
self.worker_result_queue.put(None)
- # Workers can't be waiting to put be cause their output queue
- # is a multiprocessing.Queue and its .put is non-blocking.
- # They can only be waiting to get, so we put `None` here.
+ self.pin_memory_thread.join()
+
+ # Indicate that no more data will be put on this queue by the
+ # current process. This **must** be called after
+ # `pin_memory_thread` is joined because that thread shares the
+ # same pipe handles with this loader thread. If the handle is
+ # closed, Py3 will error in this case, but Py2 will just time
+ # out even if there is data in the queue.
+ self.worker_result_queue.close()
+
+ # Exit workers now.
for q in self.index_queues:
q.put(None)
+ # Indicate that no more data will be put on this queue by the
+ # current process.
+ q.close()
for w in self.workers:
w.join()
- if hasattr(self, 'pin_memory_thread'):
- self.pin_memory_thread.join()
def __del__(self):
if self.num_workers > 0: