diff options
author | Tongzhou Wang <tongzhou.wang.1994@gmail.com> | 2018-10-09 09:51:42 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-10-09 09:54:19 -0700 |
commit | 11c31aef047776d0ffca2a3838f511bd0196a81c (patch) | |
tree | 58166e64fbc1b449ae4e62aec324072cb6601386 /torch | |
parent | 1a0d82e4f43ecce9d13c90d9219f0b419a4852c9 (diff) | |
download | pytorch-11c31aef047776d0ffca2a3838f511bd0196a81c.tar.gz pytorch-11c31aef047776d0ffca2a3838f511bd0196a81c.tar.bz2 pytorch-11c31aef047776d0ffca2a3838f511bd0196a81c.zip |
Prevent hanging in data loader altogether
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11985
Differential Revision: D10202374
Pulled By: SsnL
fbshipit-source-id: 1ab1a07185f78a104f9b05930a87ef5a32f431e4
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/DataLoader.cpp | 21 | ||||
-rw-r--r-- | torch/utils/data/dataloader.py | 374 |
2 files changed, 331 insertions, 64 deletions
diff --git a/torch/csrc/DataLoader.cpp b/torch/csrc/DataLoader.cpp index f8c25e56dc..c5cdf64544 100644 --- a/torch/csrc/DataLoader.cpp +++ b/torch/csrc/DataLoader.cpp @@ -1,15 +1,14 @@ #include "DataLoader.h" -// In cases like DataLoader, if a worker process die due to bus error/segfault -// or just hang, the main process, if implemented with -// multiprocessing.queue.SimpleQueue, will hang waiting for data. This is -// difficult to avoid on PyTorch side as it can be caused by limited shm, or -// other libraries users call in the workers. The following methods is an effort -// to do our best provide some error message to users when such unfortunate -// events happen. +// In cases like DataLoader, if a worker process dies due to bus error/segfault +// or just hang, the main process will hang waiting for data. This is difficult +// to avoid on PyTorch side as it can be caused by limited shm, or other +// libraries users call in the workers. The following methods is an effort to do +// our best to provide some error message to users when such unfortunate events +// happen. // TODO: The following don't work on Windows. Specifically, sigaction, waitid -// calls ,and SIGCHLD handler. Currently, dummy implementations are provided +// calls, and SIGCHLD handler. Currently, dummy implementations are provided // for Windows. #ifndef _WIN32 @@ -63,6 +62,7 @@ static inline void setSignalHandler(int signal, void(*handler)(int, siginfo_t *, SIGNAL_HANDLER(SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. 
" "This might be caused by insufficient shared memory (shm).\n"); SIGNAL_HANDLER(SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n"); +SIGNAL_HANDLER(SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n"); // When an error happend in DataLoader methods and Python starts to exit, the // error trace will keep the loader alive, and Python may kill the children @@ -92,6 +92,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *a setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr); setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr); setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr); + setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -130,9 +131,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) { } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") is killed " - << "by signal: " << strsignal(infop.si_status) << ". " - << "Details are lost due to multiprocessing. Rerunning with " - << "num_workers=0 may give better error trace."; + << "by signal: " << strsignal(infop.si_status) << ". "; // This is necessary. Otherwise, the runtime error will kill the other // workers, and trigger this again. 
pid_set->clear(); diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index bdf27ad897..e6b2744cee 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -26,10 +26,20 @@ else: import queue +# NOTE [ Python Traceback Reference Cycle Problem ] +# +# When using sys.exc_info(), it is important to **not** store the exc_info[2], +# which is the traceback, because otherwise you will run into the traceback +# reference cycle problem, i.e., the traceback holding reference to the frame, +# and the frame (which holds reference to all the object in its temporary scope) +# holding reference the traceback. + + class ExceptionWrapper(object): r"""Wraps an exception plus traceback to communicate across threads""" - def __init__(self, exc_info): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] self.exc_type = exc_info[0] self.exc_msg = "".join(traceback.format_exception(*exc_info)) @@ -37,7 +47,11 @@ class ExceptionWrapper(object): _use_shared_memory = False r"""Whether to use shared memory in default_collate""" -MANAGER_STATUS_CHECK_INTERVAL = 5.0 +MP_STATUS_CHECK_INTERVAL = 5.0 +r"""Interval (in seconds) to check status of processes to avoid hanging in + multiprocessing data loading. 
This is mainly used in getting data from + another process, in which case we need to periodically check whether the + sender is alive to prevent hanging.""" if IS_WINDOWS: # On Windows, the parent ID of the worker process remains unchanged when the manager process @@ -60,19 +74,29 @@ if IS_WINDOWS: if not self.manager_handle: raise ctypes.WinError(ctypes.get_last_error()) + self.manager_dead = False + def is_alive(self): - # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx - return self.kernel32.WaitForSingleObject(self.manager_handle, 0) != 0 + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + return not self.manager_dead else: class ManagerWatchdog(object): def __init__(self): self.manager_pid = os.getppid() + self.manager_dead = False def is_alive(self): - return os.getppid() == self.manager_pid + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + try: global _use_shared_memory _use_shared_memory = True @@ -87,9 +111,6 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, random.seed(seed) torch.manual_seed(seed) - # Do not wait for putting thread to join when this worker exits. - # Otherwise, this worker may always be waiting to put and doesn't check - # index_queue and done_event for termination signal. 
data_queue.cancel_join_thread() if init_fn is not None: @@ -97,22 +118,26 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, watchdog = ManagerWatchdog() - while True: + while watchdog.is_alive(): try: - r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL) + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) except queue.Empty: - if watchdog.is_alive() and not done_event.is_set(): - continue - else: - break - # use done_event so that we can get faster exiting signal even if there - # are still indices in index_queue - if r is None or done_event.is_set(): - break + continue + if r is None: + # Received the final signal + assert done_event.is_set() + return + elif done_event.is_set(): + # Done event is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue idx, batch_indices = r try: samples = collate_fn([dataset[i] for i in batch_indices]) except Exception: + # It is important that we don't store exc_info in a variable, + # see NOTE [ Python Traceback Reference Cycle Problem ] data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) else: data_queue.put((idx, samples)) @@ -122,30 +147,38 @@ def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, pass -def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory, device_id): - if pin_memory: - torch.cuda.set_device(device_id) +def _pin_memory_loop(in_queue, out_queue, device_id, done_event): + torch.cuda.set_device(device_id) + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. while True: try: - r = in_queue.get() + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue except Exception: if done_event.is_set(): - return + # Weird things can happen when shutting down, e.g., fd being + # closed when tensors are shared via fds. 
+ break raise - if r is None or done_event.is_set(): - break - if isinstance(r[1], ExceptionWrapper): - out_queue.put(r) + if r is None: + assert done_event.is_set() + return + elif done_event.is_set(): + # Haven't seen the final signal yet. Keep getting until None. continue - idx, batch = r - try: - if pin_memory: - batch = pin_memory_batch(batch) - except Exception: - out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + elif isinstance(r[1], ExceptionWrapper): + out_queue.put(r) else: - out_queue.put((idx, batch)) + idx, batch = r + try: + batch = pin_memory_batch(batch) + except Exception: + out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + out_queue.put((idx, batch)) numpy_type_map = { 'float64': torch.DoubleTensor, @@ -230,6 +263,8 @@ def _set_SIGCHLD_handler(): return previous_handler = signal.getsignal(signal.SIGCHLD) if not callable(previous_handler): + # This doesn't catch default handler, but SIGCHLD default handler is a + # no-op. previous_handler = None def handler(signum, frame): @@ -246,6 +281,207 @@ def _set_SIGCHLD_handler(): class _DataLoaderIter(object): r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone. 
+ # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exit unexpectedly (e.g., SIGKILL-ed). + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # When a process ends, it shuts all its daemonic children down with + # a SIGTERM (instead of joining them without a timeout). Similarly for + # threads, but by a different mechanism. This fact, together with a few + # implementation details of multiprocessing, forces us to make workers + # daemonic. All of our problems arise when a DataLoader is used in a + # subprocess, and are caused by multiprocessing code which looks more + # or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. 
+ # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # Another choice is to just shutdown workers with logic in 1 above + # whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. All processes exit if any of them die unexpectedly (e.g., error, + # fatal signals). + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we register SIGCHLD handler on main process, + # which checks if any of the workers fail in the (Python) handler. + # See DataLoader.cpp. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `ManagerWatchdog` class checks the main + # process status. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. 
A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # exits. + # + # However, in case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. Therefore, + # for both `worker_result_queue` (worker -> main process/pin_memory_thread) + # and each `index_queue` (main process -> worker), we use + # `q.cancel_join_thread()` in sender process before any `q.put` to + # prevent this automatic join. + # + # Moreover, having all queues call `cancel_join_thread` makes + # implementing graceful shutdown logic in `__del__` much easier. + # It won't need to get from any queue, which would also need to be + # guarded by periodic status checks. + # + # Note that this may leave corrupted data in the queue, but we + # don't care about the data anyways once we are shutting down. + # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # [worker processes] + # While loader process is alive: + # Get from index_queue. + # If got a `None`, exit. + # If get anything else, + # Check `done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data. + # If timed out, + # No matter `done_event` is set (still need to see `None`) or not, + # must continue to next iteration. + # + # [pin_memory_thread] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While True: + # Get from index_queue. + # If got a `None`, exit. + # If get anything else, + # Check `done_event`. 
+ # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [main process] + # In the DataLoader Iter's `__del__` + # a. Set `done_event` (shared with `pin_memory_thread` and workers). + # + # Note: from here on, the workers & `pin_memory_thread` may exit at + # any time after they receive `None`. + # + # b. Exit `pin_memory_thread` + # i. Put `None` in `worker_result_queue`. + # ii. Join the `pin_memory_thread`. + # + # c. Exit the workers. + # i. Put `None` in each worker's `index_queue`. + # ii. Join the workers. + # + # NOTE: This has to be after (b) because it may leave corrupted data + # in `worker_result_queue`, which `pin_memory_thread` reads + # from. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. 
+ def __init__(self, loader): self.dataset = loader.dataset self.collate_fn = loader.collate_fn @@ -274,18 +510,19 @@ class _DataLoaderIter(object): self.workers = [] for i in range(self.num_workers): index_queue = multiprocessing.Queue() + index_queue.cancel_join_thread() w = multiprocessing.Process( target=_worker_loop, args=(self.dataset, index_queue, self.worker_result_queue, self.done_event, self.collate_fn, base_seed + i, self.worker_init_fn, i)) - w.daemon = True # ensure that the worker exits on process exit - # Process.start() actually take some time as it needs to start a - # process and pass the arguments over via a pipe. Therefore, we - # only add a worker to self.workers list after it started, so - # that we do not call .join() if program dies before it starts, - # and __del__ tries to join it but will get: + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. + # Therefore, we only add a worker to self.workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: # AssertionError: can only join a started process. w.start() self.index_queues.append(index_queue) @@ -295,8 +532,8 @@ class _DataLoaderIter(object): self.data_queue = queue.Queue() pin_memory_thread = threading.Thread( target=_pin_memory_loop, - args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, - torch.cuda.current_device())) + args=(self.worker_result_queue, self.data_queue, + torch.cuda.current_device(), self.done_event)) pin_memory_thread.daemon = True pin_memory_thread.start() # Similar to workers (see comment above), we only register @@ -317,11 +554,25 @@ class _DataLoaderIter(object): return len(self.batch_sampler) def _get_batch(self): + # In the non-timeout case, worker exit is covered by SIGCHLD handler. 
+ # But if `pin_memory=True`, we still need account for the possibility + # that `pin_memory_thread` dies. if self.timeout > 0: try: return self.data_queue.get(timeout=self.timeout) except queue.Empty: raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) + elif self.pin_memory: + while self.pin_memory_thread.is_alive(): + try: + return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError('Pin memory thread exited unexpectedly') + # In this case, `self.data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. else: return self.data_queue.get() @@ -383,29 +634,46 @@ class _DataLoaderIter(object): raise NotImplementedError("_DataLoaderIter cannot be pickled") def _shutdown_workers(self): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on + # the logic of this function. if not self.shutdown: self.shutdown = True - # removes pids from the C side data structure first so worker + # Removes pids from the C side data structure first so worker # termination afterwards won't trigger false positive error report. if self.worker_pids_set: _remove_worker_pids(id(self)) self.worker_pids_set = False + self.done_event.set() - if self.pin_memory: - # Sending `None` to `pin_memory_thread` must be before - # stopping worker processes because the workers may leave - # corrupted data in `worker_result_queue`, causing - # `pin_memory_thread` unable to read and terminate properly. + + # Exit `pin_memory_thread` first because exiting workers may leave + # corrupted data in `worker_result_queue` which `pin_memory_thread` + # reads from. + if hasattr(self, 'pin_memory_thread'): + # Use hasattr in case error happens before we set the attribute. + # First time do `worker_result_queue.put` in this process. + + # `cancel_join_thread` in case that `pin_memory_thread` exited. 
+ self.worker_result_queue.cancel_join_thread() self.worker_result_queue.put(None) - # Workers can't be waiting to put be cause their output queue - # is a multiprocessing.Queue and its .put is non-blocking. - # They can only be waiting to get, so we put `None` here. + self.pin_memory_thread.join() + + # Indicate that no more data will be put on this queue by the + # current process. This **must** be called after + # `pin_memory_thread` is joined because that thread shares the + # same pipe handles with this loader thread. If the handle is + # closed, Py3 will error in this case, but Py2 will just time + # out even if there is data in the queue. + self.worker_result_queue.close() + + # Exit workers now. for q in self.index_queues: q.put(None) + # Indicate that no more data will be put on this queue by the + # current process. + q.close() for w in self.workers: w.join() - if hasattr(self, 'pin_memory_thread'): - self.pin_memory_thread.join() def __del__(self): if self.num_workers > 0: |