summaryrefslogtreecommitdiff
path: root/test/run_test.py
diff options
context:
space:
mode:
authorTongzhou Wang <tongzhou.wang.1994@gmail.com>2018-11-01 19:04:17 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-11-01 19:08:06 -0700
commit6d2b3cc869d2d39832800e8127cb66abf3589311 (patch)
treee4d78e6ab4429cfcc99c474fa50c50ccba28b9f3 /test/run_test.py
parent0fd176fea45abb092c903b40008c7119e21ee9f5 (diff)
downloadpytorch-6d2b3cc869d2d39832800e8127cb66abf3589311.tar.gz
pytorch-6d2b3cc869d2d39832800e8127cb66abf3589311.tar.bz2
pytorch-6d2b3cc869d2d39832800e8127cb66abf3589311.zip
Fix pytest, make it work with run_test.py (#13416)
Summary: Fixes #13326 Also now you can use `run_test.py` with `pytest`. E.g., ``` python run_test.py -vci distributed -pt ``` Yes it works with `distributed` and `cpp_extension`. cc zou3519 vishwakftw Pull Request resolved: https://github.com/pytorch/pytorch/pull/13416 Differential Revision: D12895622 Pulled By: SsnL fbshipit-source-id: 2d18106f3a118d642a666bfb1318f41c859c3df7
Diffstat (limited to 'test/run_test.py')
-rw-r--r--test/run_test.py80
1 files changed, 52 insertions, 28 deletions
diff --git a/test/run_test.py b/test/run_test.py
index 40f874a20c..884c562f9f 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -5,7 +5,6 @@ from __future__ import print_function
import argparse
from datetime import datetime
import os
-import shlex
import shutil
import signal
import subprocess
@@ -13,6 +12,7 @@ import sys
import tempfile
import torch
+import torch._six
from torch.utils import cpp_extension
from common_utils import TEST_WITH_ROCM
import torch.distributed as dist
@@ -97,26 +97,45 @@ def print_to_stderr(message):
def shell(command, cwd=None):
sys.stdout.flush()
sys.stderr.flush()
- return subprocess.call(
- shlex.split(command), universal_newlines=True, cwd=cwd)
-
-
-def get_shell_output(command):
- return subprocess.check_output(shlex.split(command)).decode().strip()
+ # The folloing cool snippet is copied from Py3 core library subprocess.call
+ # only the with
+ # 1. `except KeyboardInterrupt` block added for SIGINT handling.
+ # 2. In Py2, subprocess.Popen doesn't return a context manager, so we do
+ # `p.wait()` in a `final` block for the code to be portable.
+ #
+ # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
+ assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
+ p = subprocess.Popen(command, universal_newlines=True, cwd=cwd)
+ try:
+ return p.wait()
+ except KeyboardInterrupt:
+ # Give `p` a chance to handle KeyboardInterrupt. Without this,
+ # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
+ exit_status = p.wait(timeout=5)
+ if exit_status is not None:
+ return exit_status
+ else:
+ p.kill()
+ raise
+ except: # noqa E722, copied from python core library
+ p.kill()
+ raise
+ finally:
+ # Always call p.wait() to ensure exit
+ p.wait()
-def run_test(python, test_module, test_directory, options):
+def run_test(executable, test_module, test_directory, options):
unittest_args = options.additional_unittest_args
if options.verbose:
unittest_args.append('--verbose')
- unittest_args = ' '.join(unittest_args)
# Can't call `python -m unittest test_*` here because it doesn't run code
# in `if __name__ == '__main__': `. So call `python test_*.py` instead.
- return shell('{} {}.py {}'.format(python, test_module, unittest_args),
- test_directory)
+ command = executable + [test_module + '.py'] + unittest_args
+ return shell(command, test_directory)
-def test_cpp_extensions(python, test_module, test_directory, options):
+def test_cpp_extensions(executable, test_module, test_directory, options):
try:
cpp_extension.verify_ninja_availability()
except RuntimeError:
@@ -124,7 +143,7 @@ def test_cpp_extensions(python, test_module, test_directory, options):
'Ninja is not available. Skipping C++ extensions test. '
"Install ninja with 'pip install ninja' or 'conda install ninja'.")
return 0
- return_code = shell('{} setup.py install --root ./install'.format(python),
+ return_code = shell([sys.executable, 'setup.py', 'install', '--root', './install'],
os.path.join(test_directory, 'cpp_extensions'))
if return_code != 0:
return return_code
@@ -141,12 +160,12 @@ def test_cpp_extensions(python, test_module, test_directory, options):
assert install_directory, 'install_directory must not be empty'
os.environ['PYTHONPATH'] = os.pathsep.join([install_directory, python_path])
- return run_test(python, test_module, test_directory, options)
+ return run_test(executable, test_module, test_directory, options)
finally:
os.environ['PYTHONPATH'] = python_path
-def test_distributed(python, test_module, test_directory, options):
+def test_distributed(executable, test_module, test_directory, options):
mpi_available = subprocess.call('command -v mpiexec', shell=True) == 0
if options.verbose and not mpi_available:
print_to_stderr(
@@ -184,12 +203,12 @@ def test_distributed(python, test_module, test_directory, options):
'mpiexec -n 1 --noprefix bash -c ""', shell=True,
stdout=devnull, stderr=subprocess.STDOUT) == 0 else ''
- mpiexec = 'mpiexec -n 3 {} {}'.format(noprefix_opt, python)
+ mpiexec = ['mpiexec', '-n', '3', noprefix_opt] + executable
return_code = run_test(mpiexec, test_module,
test_directory, options)
else:
- return_code = run_test(python, test_module, test_directory,
+ return_code = run_test(executable, test_module, test_directory,
options)
if return_code != 0:
return return_code
@@ -227,7 +246,10 @@ def parse_args():
action='store_true',
help='print verbose information and test-by-test results')
parser.add_argument(
- '-p', '--python', help='the python interpreter to execute tests with')
+ '-pt', '--pytest', action='store_true',
+ help='If true, use `pytest` to execute the tests. E.g., this runs '
+ 'TestTorch with pytest in verbose and coverage mode: '
+ 'python run_test.py -vci torch -pt')
parser.add_argument(
'-c', '--coverage', action='store_true', help='enable coverage')
parser.add_argument(
@@ -272,13 +294,14 @@ def parse_args():
return parser.parse_args()
-def get_python_command(options):
+def get_executable_command(options):
if options.coverage:
- return 'coverage run --parallel-mode --source torch'
- elif options.python:
- return options.python
+ executable = ['coverage', 'run', '--parallel-mode', '--source torch']
else:
- return os.environ.get('PYCMD', 'python')
+ executable = [sys.executable]
+ if options.pytest:
+ executable += ['-m', 'pytest']
+ return executable
def find_test_index(test, selected_tests, find_last_index=False):
@@ -358,7 +381,8 @@ def get_selected_tests(options):
def main():
options = parse_args()
- python = get_python_command(options)
+ executable = get_executable_command(options) # this is a list
+ print_to_stderr('Test executor: {}'.format(executable))
test_directory = os.path.dirname(os.path.abspath(__file__))
selected_tests = get_selected_tests(options)
@@ -366,7 +390,7 @@ def main():
print_to_stderr('Selected tests: {}'.format(', '.join(selected_tests)))
if options.coverage:
- shell('coverage erase')
+ shell(['coverage', 'erase'])
for test in selected_tests:
test_name = 'test_{}'.format(test)
@@ -375,7 +399,7 @@ def main():
# Printing the date here can help diagnose which tests are slow
print_to_stderr('Running {} ... [{}]'.format(test_name, datetime.now()))
handler = CUSTOM_HANDLERS.get(test_module, run_test)
- return_code = handler(python, test_name, test_directory, options)
+ return_code = handler(executable, test_name, test_directory, options)
assert isinstance(return_code, int) and not isinstance(
return_code, bool), 'Return code should be an integer'
if return_code != 0:
@@ -388,8 +412,8 @@ def main():
raise RuntimeError(message)
if options.coverage:
- shell('coverage combine')
- shell('coverage html')
+ shell(['coverage', 'combine'])
+ shell(['coverage', 'html'])
if __name__ == '__main__':