diff options
author | SsnL <tongzhou.wang.1994@gmail.com> | 2019-01-10 08:44:32 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-10 08:47:27 -0800 |
commit | 9b5ec2a076982c57033f2e345cee3051b55de996 (patch) | |
tree | ddbc4ecf1fb3c6887598e229998e7d752edf30ba /test | |
parent | 0ed3f766e9eb803e8f37f728af9494d756aec9a7 (diff) | |
download | pytorch-9b5ec2a076982c57033f2e345cee3051b55de996.tar.gz pytorch-9b5ec2a076982c57033f2e345cee3051b55de996.tar.bz2 pytorch-9b5ec2a076982c57033f2e345cee3051b55de996.zip |
Fix TestDataLoader.test_proper_exit (#15665)
Summary:
Currently, in `test_proper_exit`,
1. we do not kill the correct input `pid` in the `kill_pid` function
https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L325-L329
2. the Windows command that detects process status doesn't actually work
https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L641-L646
3. `worker_error` and `worker_kill` cases (sometimes?) are not tested because the workers may exit naturally due to the pre-fetching mechanism and a too small `dataset size / batch size`.
In this PR, I, in separate commits:
1. Install `psutil` (a python package specifically built for process monitoring) on some CI builds. (Linux builds installation are done in https://github.com/pietern/pytorch-dockerfiles/pull/29 https://github.com/pietern/pytorch-dockerfiles/pull/30 https://github.com/pytorch/ossci-job-dsl/pull/36 and https://github.com/pytorch/pytorch/pull/15795).
2. Rewrite `test_proper_exit` with `psutil` so we
1. do not rely on the hacky `is_process_alive` https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L640-L653
2. increase the #task per worker so `worker_error` and `worker_kill` properly trigger
3. test error message content to ensure that the loader exits with correct message corresponding to each exiting scenario.
3. Fix Windows data loader not having any mechanism to detect worker failures.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15665
Differential Revision: D13615527
Pulled By: soumith
fbshipit-source-id: cfb2f67837d2d87928a53f00b4d20f09754b7949
Diffstat (limited to 'test')
-rw-r--r-- | test/test_dataloader.py | 176 |
1 files changed, 97 insertions, 79 deletions
diff --git a/test/test_dataloader.py b/test/test_dataloader.py index cdd8139942..291b9c09c0 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -11,6 +11,7 @@ import traceback import unittest import subprocess import itertools +import warnings from torch import multiprocessing as mp from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL @@ -18,6 +19,16 @@ from torch.utils.data.dataset import random_split from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests) +try: + import psutil + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + warnings.warn( + "psutil not found. Some crucial data loader tests relying on it (e.g., " + "TestDataLoader.test_proper_exit) will not run.") + + # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -34,7 +45,7 @@ if not NO_MULTIPROCESSING_SPAWN: mp = mp.get_context(method='spawn') -JOIN_TIMEOUT = 17.0 if IS_WINDOWS or IS_PPC else 8.5 +JOIN_TIMEOUT = 17.0 if (IS_WINDOWS or IS_PPC) else 11.0 class TestDatasetRandomSplit(TestCase): @@ -304,42 +315,58 @@ class TestProperExitDataset(object): # See TestDataLoader.test_proper_exit for usage def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference, - worker_pids, setup_event): + loader_setup_event, tester_setup_event): num_workers = 2 if use_workers else 0 if exit_method == 'worker_error' or exit_method == 'worker_kill': assert use_workers is True - ds = TestProperExitDataset(10, setup_event if exit_method == 'worker_error' else None) + if exit_method == 'worker_error': + worker_error_event = mp.Event() + else: + worker_error_event = None - loader = DataLoader(ds, batch_size=2, shuffle=False, + ds = TestProperExitDataset(12, worker_error_event) + + loader = DataLoader(ds, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) - error_it = 4 - assert len(loader) > error_it + error_it = 2 + + if use_workers: + # 2 is the magical per-worker prefetch number... + # FIXME: change this after the number becomes configurable. + assert len(loader) > (error_it + 2 + 1) * num_workers it = iter(loader) if use_workers: - for i, w in enumerate(it.workers): - worker_pids[i] = w.pid + workers = it.workers def kill_pid(pid): - if IS_WINDOWS: - os.system('taskkill /PID ' + str(os.getpid()) + ' /F') - else: - os.kill(os.getpid(), signal.SIGKILL) + psutil_p = psutil.Process(pid) + psutil_p.kill() + psutil_p.wait(JOIN_TIMEOUT) + assert not psutil_p.is_running() for i, _ in enumerate(it): if i == 0: if not hold_iter_reference: del it - setup_event.set() + loader_setup_event.set() + tester_setup_event.wait() + # ensure that the workers are still alive + if use_workers: + for w in workers: + assert w.is_alive() + if worker_error_event is not None: + worker_error_event.set() + if i == error_it: - if exit_method == 'main_error': - raise RuntimeError('Error') - elif exit_method == 'main_kill': + if exit_method == 'loader_error': + raise RuntimeError('Loader error') + elif exit_method == 'loader_kill': kill_pid(os.getpid()) elif exit_method == 'worker_kill': - kill_pid(worker_pids[0]) + kill_pid(workers[0].pid) if not hold_iter_reference: # Tries to trigger the __del__ clean-up rather than the automatic @@ -637,22 +664,8 @@ class TestDataLoader(TestCase): pin_memory_thread.join(JOIN_TIMEOUT) self.assertFalse(pin_memory_thread.is_alive()) - @staticmethod - def _is_process_alive(pid, pname): - # There is a chance of a terminated child process's pid being reused by a new unrelated process, - # but since we are looping this check very frequently, we will know that the child process dies - # before the new unrelated process starts. - if IS_WINDOWS: - command = 'tasklist | find "{}" /i'.format(pid) - else: - command = 'ps -p {} -o comm='.format(pid) - p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) - (output, err) = p.communicate() - p_status = p.wait() - output = output.decode('utf-8') - return pname in output - @skipIfRocm + @unittest.skipIf(not HAS_PSUTIL, "psutil not found") def test_proper_exit(self): (r'''There might be ConnectionResetError or leaked semaphore warning ''' r'''(due to dirty process exit), but they are all safe to ignore''') @@ -660,28 +673,6 @@ class TestDataLoader(TestCase): # TODO: test the case where the pin_memory_thread triggers an # error/fatal signal. I haven't found out how to properly do that. - # Array to store the worker pids. - worker_pids = mp.Array('i', [-1 for _ in range(10)]) - - def wait_pids(pids, timeout): - r"""Wait for all process specified in pids to exit in given timeout.""" - exit_status = [False for _ in pids] - start_time = time.time() - pname = 'python' - while True: - for i in range(len(pids)): - pid = pids[i] - if not exit_status[i]: - if not TestDataLoader._is_process_alive(pid, pname): - exit_status[i] = True - if all(exit_status): - break - else: - if time.time() - start_time > timeout: - break - time.sleep(0.5) - return exit_status - for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3): # `hold_iter_reference` specifies whether we hold a reference to the # iterator. This is interesting because Python3 error traces holds a @@ -700,46 +691,73 @@ class TestDataLoader(TestCase): # - `None` means that no error happens. # In all cases, all processes should end properly. if use_workers: - exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error'] + exit_methods = [None, 'loader_error', 'loader_kill', 'worker_kill', 'worker_error'] else: - exit_methods = [None, 'main_error', 'main_kill'] + exit_methods = [None, 'loader_error', 'loader_kill'] for exit_method in exit_methods: - # clear pids array first - for i in range(len(worker_pids)): - worker_pids[i] = -1 + desc = [] + desc.append('use_workers={}'.format(use_workers)) + desc.append('pin_memory={}'.format(pin_memory)) + desc.append('hold_iter_reference={}'.format(hold_iter_reference)) + desc.append('exit_method={}'.format(exit_method)) + desc = 'test_proper_exit with ' + ', '.join(desc) # Event that the loader process uses to signal testing process # that various things are setup, including that the worker pids # are specified in `worker_pids` array. - setup_event = mp.Event() - - p = ErrorTrackingProcess(target=_test_proper_exit, - args=(use_workers, pin_memory, exit_method, - hold_iter_reference, worker_pids, setup_event)) - p.start() + loader_setup_event = mp.Event() + + # Event that this process has finished setting up, and the + # loader process can now proceed to trigger error events or + # finish normally. + tester_setup_event = mp.Event() + + loader_p = ErrorTrackingProcess(target=_test_proper_exit, + args=(use_workers, pin_memory, exit_method, + hold_iter_reference, loader_setup_event, + tester_setup_event)) + loader_p.start() + + # Wait for loader process to set everything up, e.g., starting + # workers. + loader_setup_event.wait(timeout=JOIN_TIMEOUT) + if not loader_setup_event.is_set(): + fail_msg = desc + ': loader process failed to setup with given time' + if loader_p.exception is not None: + self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception)) + elif not loader_p.is_alive(): + self.fail(fail_msg + ', and exited with code {} but no exception'.format(loader_p.exitcode)) + else: + self.fail(fail_msg + ', and is still alive.') - # Wait for loader process to set everything up, i.e., filling - # worker pids in `worker_pids`. - setup_event.wait(timeout=JOIN_TIMEOUT) - self.assertTrue(setup_event.is_set(), 'loader process setup timed out') + worker_psutil_p = psutil.Process(loader_p.pid).children() - pids = [pid for pid in worker_pids if pid > 0] + tester_setup_event.set() try: - exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT)) - if not all(exit_status): - self.fail('subprocess (pid(s) {}) not terminated'.format( - ', '.join(p for p, exited in zip(pids, exit_status) if not exited))) - p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL) - self.assertFalse(p.is_alive(), 'loader process not terminated') + loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL) + self.assertFalse(loader_p.is_alive(), desc + ': loader process not terminated') + _, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT)) + if len(alive) > 0: + self.fail(desc + ': worker process (pid(s) {}) not terminated'.format( + ', '.join(str(p.pid) for p in alive))) if exit_method is None: - self.assertEqual(p.exitcode, 0) + self.assertEqual(loader_p.exitcode, 0) else: - self.assertNotEqual(p.exitcode, 0) + self.assertNotEqual(loader_p.exitcode, 0) + if exit_method == 'loader_error': + self.assertIsInstance(loader_p.exception, RuntimeError, desc) + self.assertIn('Loader error', str(loader_p.exception), desc) + elif exit_method == 'worker_kill': + self.assertIsInstance(loader_p.exception, RuntimeError, desc) + self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc) + elif exit_method == 'worker_error': + self.assertIsInstance(loader_p.exception, RuntimeError, desc) + self.assertIn('Worker error', str(loader_p.exception), desc) finally: - p.terminate() + loader_p.terminate() def test_len(self): def check_len(dl, expected): |