-rw-r--r-- | .gitignore                              |   2
-rwxr-xr-x | .jenkins/pytorch/test.sh                |   4
-rw-r--r-- | setup.py                                |   1
-rw-r--r-- | test/run_test.py                        |   1
-rw-r--r-- | test/test_type_hints.py                 | 167
-rw-r--r-- | tools/autograd/gen_python_functions.py  |  65
-rw-r--r-- | tools/autograd/utils.py                 |   6
-rw-r--r-- | tools/pyi/__init__.py                   |   0
-rw-r--r-- | tools/pyi/gen_pyi.py                    | 529
-rw-r--r-- | torch/CMakeLists.txt                    |  18
-rw-r--r-- | torch/__init__.py                       |   1
-rw-r--r-- | torch/__init__.pyi.in                   | 106
-rw-r--r-- | torch/_six.py                           |   8
-rw-r--r-- | torch/_utils.py                         |  10
-rw-r--r-- | torch/csrc/utils/python_arg_parser.cpp  |   4
-rw-r--r-- | torch/functional.py                     |   3
-rw-r--r-- | torch/serialization.py                  |   2
-rw-r--r-- | torch/tensor.py                         |   4
18 files changed, 910 insertions, 21 deletions
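This patch teaches the build to generate a type stub (torch/__init__.pyi) for the torch module from Declarations.yaml, and adds a test that runs every docstring example through mypy. As a rough sketch of what the stub enables once installed — the file name and function below are illustrative, not part of this patch — annotated user code can now be checked against the partial type information for torch:

    # demo_stub_check.py -- hypothetical user code, not part of this patch
    import torch

    def scale(t: torch.Tensor, k: float) -> torch.Tensor:
        return t.mul(k)

    ok = scale(torch.ones(2, 3), 2.0)
    bad = scale("not a tensor", 2.0)  # mypy reports an incompatible "str" argument

Checked with the same mypy flags the new test uses, from the PyTorch source root:

    mypy --follow-imports silent --check-untyped-defs demo_stub_check.py

The stub can also be regenerated by hand with the new generator (the paths shown are the argparse defaults from tools/pyi/gen_pyi.py below):

    python -m tools.pyi.gen_pyi --declarations-path torch/share/ATen/Declarations.yaml --out .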
diff --git a/.gitignore b/.gitignore index f6ac0c66ce..7d7add5e70 100644 --- a/.gitignore +++ b/.gitignore @@ -35,11 +35,13 @@ test/data/gpu_tensors.pt test/data/legacy_modules.t7 test/data/legacy_serialized.pt test/data/linear.pt +test/generated_type_hints_smoketest.py test/htmlcov test/cpp_extensions/install/ third_party/build/ tools/shared/_utils_internal.py torch.egg-info/ +torch/__init__.pyi torch/csrc/autograd/generated/* torch/csrc/cudnn/cuDNN.cpp torch/csrc/generated diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index cd6d72563f..ef180b3407 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -34,6 +34,10 @@ if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]]; then # TODO: move this to Docker pip install -q hypothesis --user + + # mypy will fail to install on Python <3.4. In that case, + # we just won't run these tests. + pip install mypy --user || true fi # DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems @@ -731,6 +731,7 @@ if __name__ == '__main__': entry_points=entry_points, package_data={ 'torch': [ + '__init__.pyi', 'lib/*.so*', 'lib/*.dylib*', 'lib/*.dll', diff --git a/test/run_test.py b/test/run_test.py index 72e7b22061..24c1641798 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -42,6 +42,7 @@ TESTS = [ 'thd_distributed', 'torch', 'type_info', + 'type_hints', 'utils', ] diff --git a/test/test_type_hints.py b/test/test_type_hints.py new file mode 100644 index 0000000000..c2df94d811 --- /dev/null +++ b/test/test_type_hints.py @@ -0,0 +1,167 @@ +from __future__ import print_function +import unittest +from common_utils import TestCase, run_tests, download_file +import tempfile +import torch +import re +import os +import sys +import subprocess +import inspect + +try: + import mypy + HAVE_MYPY = True +except ImportError: + HAVE_MYPY = False + + +def get_examples_from_docstring(docstr): + """ + Extracts all runnable python code from the examples + in docstrings; returns a list of lines. + """ + # TODO: Figure out if there's a way to use doctest directly to + # implement this + example_file_lines = [] + # the detection is a bit hacky because there isn't a nice way of detecting + # where multiline commands end. Thus we keep track of how far we got in beginning + # and continue to add lines until we have a compileable Python statement. + exampleline_re = re.compile(r"^\s+(?:>>>|\.\.\.) (.*)$") + beginning = "" + for l in docstr.split('\n'): + if beginning: + m = exampleline_re.match(l) + if m: + beginning += m.group(1) + else: + beginning += l + else: + m = exampleline_re.match(l) + if m: + beginning += m.group(1) + if beginning: + complete = True + try: + compile(beginning, "", "exec") + except SyntaxError: + complete = False + if complete: + # found one + example_file_lines += beginning.split('\n') + beginning = "" + else: + beginning += "\n" + return [' ' + l for l in example_file_lines] + + +def get_all_examples(): + """get_all_examples() -> str + + This function grabs (hopefully all) examples from the torch documentation + strings and puts them in one nonsensical module returned as a string. 
+ """ + blacklist = {"_np"} + allexamples = "" + + example_file_lines = [ + "import torch", + "import torch.nn.functional as F", + "import math # type: ignore", # mypy complains about floats where SupportFloat is expected + "import numpy # type: ignore", + "import io # type: ignore", + "import itertools # type: ignore", + "", + # for requires_grad_ example + # NB: We are parsing this file as Python 2, so we must use + # Python 2 type annotation syntax + "def preprocess(inp):", + " # type: (torch.Tensor) -> torch.Tensor", + " return inp", + ] + + for fname in dir(torch): + fn = getattr(torch, fname) + docstr = inspect.getdoc(fn) + if docstr and fname not in blacklist: + e = get_examples_from_docstring(docstr) + if e: + example_file_lines.append("\n\ndef example_torch_{}():".format(fname)) + example_file_lines += e + + for fname in dir(torch.Tensor): + fn = getattr(torch.Tensor, fname) + docstr = inspect.getdoc(fn) + if docstr and fname not in blacklist: + e = get_examples_from_docstring(docstr) + if e: + example_file_lines.append("\n\ndef example_torch_tensor_{}():".format(fname)) + example_file_lines += e + + return "\n".join(example_file_lines) + + +class TestTypeHints(TestCase): + @unittest.skipIf(sys.version_info[0] == 2, "no type hints for Python 2") + @unittest.skipIf(not HAVE_MYPY, "need mypy") + def test_doc_examples(self): + """ + Run documentation examples through mypy. + """ + fn = os.path.join(os.path.dirname(__file__), 'generated_type_hints_smoketest.py') + with open(fn, "w") as f: + print(get_all_examples(), file=f) + + # OK, so here's the deal. mypy treats installed packages + # and local modules differently: if a package is installed, + # mypy will refuse to use modules from that package for type + # checking unless the module explicitly says that it supports + # type checking. (Reference: + # https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports + # ) + # + # Now, PyTorch doesn't support typechecking, and we shouldn't + # claim that it supports typechecking (it doesn't.) However, not + # claiming we support typechecking is bad for this test, which + # wants to use the partial information we get from the bits of + # PyTorch which are typed to check if it typechecks. And + # although mypy will work directly if you are working in source, + # some of our tests involve installing PyTorch and then running + # its tests. + # + # The guidance we got from Michael Sullivan and Joshua Oreman, + # and also independently developed by Thomas Viehmann, + # is that we should create a fake directory and add symlinks for + # the packages that should typecheck. So that is what we do + # here. + # + # If you want to run mypy by hand, and you run from PyTorch + # root directory, it should work fine to skip this step (since + # mypy will preferentially pick up the local files first). The + # temporary directory here is purely needed for CI. For this + # reason, we also still drop the generated file in the test + # source folder, for ease of inspection when there are failures. + with tempfile.TemporaryDirectory() as tmp_dir: + try: + os.symlink( + os.path.dirname(torch.__file__), + os.path.join(tmp_dir, 'torch'), + target_is_directory=True + ) + except OSError: + raise unittest.SkipTest('cannot symlink') + try: + subprocess.run([ + sys.executable, + '-mmypy', + '--follow-imports', 'silent', + '--check-untyped-defs', + os.path.abspath(fn)], + cwd=tmp_dir, + check=True) + except subprocess.CalledProcessError as e: + raise AssertionError("mypy failed. 
Look above this error for mypy's output.") + + +if __name__ == '__main__': + run_tests() diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index ffacdf8e89..f0a9677575 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -182,33 +182,49 @@ def should_generate_python_binding(declaration): return True -def gen_py_variable_methods(out, declarations, template_path): - PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp') - PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h') - +def get_py_variable_methods(declarations): + """ + Get declarations (grouped by name) which should be generated + as methods on Tensor. + """ def should_bind(declaration): return (should_generate_python_binding(declaration) and declaration['mode'] != 'NN' and declaration.get('python_module') != 'nn' and 'Tensor' in declaration['method_of']) - py_variable_methods = group_declarations_by_name(declarations, should_bind) + return group_declarations_by_name(declarations, should_bind) + + +def gen_py_variable_methods(out, declarations, template_path): + PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp') + PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h') + + py_variable_methods = get_py_variable_methods(declarations) env = create_python_bindings(py_variable_methods, True) write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env) write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env) +def get_py_nn_functions(declarations): + """ + Get declarations (grouped by name) which should be generated + as functions in the "nn" module. + """ + def should_bind(declaration): + return (should_generate_python_binding(declaration) and + (declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn')) + + return group_declarations_by_name(declarations, should_bind) + + def gen_py_nn_functions(out, declarations, template_path): PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp') PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h') PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h') - def should_bind(declaration): - return (should_generate_python_binding(declaration) and - (declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn')) - - py_nn_functions = group_declarations_by_name(declarations, should_bind) + py_nn_functions = get_py_nn_functions(declarations) env = create_python_bindings(py_nn_functions, has_self=False, is_module=True) write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env) @@ -216,17 +232,25 @@ def gen_py_nn_functions(out, declarations, template_path): write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env) -def gen_py_torch_functions(out, declarations, template_path): - PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp') - PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h') - +def get_py_torch_functions(declarations): + """ + Get declarations (grouped by name) which should be generated + as functions in the "torch" module. 
+ """ def should_bind(declaration): return (should_generate_python_binding(declaration) and declaration['mode'] != 'NN' and declaration.get('python_module') != 'nn' and 'namespace' in declaration['method_of']) - py_torch_functions = group_declarations_by_name(declarations, should_bind) + return group_declarations_by_name(declarations, should_bind) + + +def gen_py_torch_functions(out, declarations, template_path): + PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp') + PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h') + + py_torch_functions = get_py_torch_functions(declarations) env = create_python_bindings(py_torch_functions, has_self=False) write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env) @@ -800,7 +824,16 @@ def sort_declarations(grouped_decls): def get_python_signature(declaration, include_out): - # Compute the Python function signature for argument parsing + # Compute the Python function signature for argument parsing, + # as specified in torch/csrc/utils/python_arg_parser.h. WARNING: + # this is NOT the same type signature as specified by PEP 484 + # as understood by mypy; our format was independently developed + # and has some quirks to make it more suitable specifically + # for error parsing. + # + # For a translation to mypy-valid type signatures, see + # tools/gen_pyi.py. If you change any logic here, please + # check that file too. py_formal_args = [] output_args = [] type_args = [] diff --git a/tools/autograd/utils.py b/tools/autograd/utils.py index 5b82b4e751..9a799c9e30 100644 --- a/tools/autograd/utils.py +++ b/tools/autograd/utils.py @@ -14,6 +14,12 @@ except ImportError: from tools.shared.module_loader import import_module CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate +# You should use these lines, rather than doing it manually. +# Especially if you see this error! +# +# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load +# loader = Loader(stream) +# TypeError: 'module' object is not callable try: # use faster C loader if available from yaml import CLoader as YamlLoader diff --git a/tools/pyi/__init__.py b/tools/pyi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tools/pyi/__init__.py diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py new file mode 100644 index 0000000000..6905f80f47 --- /dev/null +++ b/tools/pyi/gen_pyi.py @@ -0,0 +1,529 @@ +from __future__ import print_function +import multiprocessing +import sys +import os +import inspect +import collections +import yaml +import types +import re +import argparse + +from ..autograd.utils import YamlLoader, CodeTemplate, write +from ..autograd.gen_python_functions import get_py_torch_functions, get_py_variable_methods +from ..autograd.gen_autograd import load_aten_declarations + +""" +This module implements generation of type stubs for PyTorch, +enabling use of autocomplete in IDEs like PyCharm, which otherwise +don't understand C extension modules. + +At the moment, this module only handles type stubs for torch and +torch.Tensor. It should eventually be expanded to cover all functions +which come are autogenerated. + +Here's our general strategy: + +- We start off with a hand-written __init__.pyi.in file. This + file contains type definitions for everything we cannot automatically + generate, including pure Python definitions directly in __init__.py + (the latter case should be pretty rare). 
+ +- We go through automatically bound functions based on the + type information recorded in Declarations.yaml and + generate type hints for them (generate_type_hints) + +There are a number of type hints which we've special-cased; +read gen_pyi for the gory details. +""" + +# TODO: Consider defining some aliases for our Union[...] types, to make +# the stubs to read on the human eye. + +needed_modules = set() + +FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False" + +# this could be more precise w.r.t list contents etc. How to do Ellipsis? +INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]" + +blacklist = [ + '__init_subclass__', + '__new__', + '__subclasshook__', + 'clamp', + 'clamp_', + 'device', + 'grad', + 'requires_grad', + 'range', + # defined in functional + 'einsum', + # reduction argument; these bindings don't make sense + 'ctc_loss', + 'cosine_embedding_loss', + 'hinge_embedding_loss', + 'kl_div', + 'margin_ranking_loss', + 'triplet_margin_loss', + # Somehow, these are defined in both _C and in functional. Ick! + 'broadcast_tensors', + 'meshgrid', + 'cartesian_prod', + 'norm', + 'chain_matmul', + 'stft', + 'tensordot', + 'norm', + 'split', + # These are handled specially by python_arg_parser.cpp + 'add', + 'add_', + 'add_out', + 'sub', + 'sub_', + 'sub_out', + 'mul', + 'mul_', + 'mul_out', + 'div', + 'div_', + 'div_out', +] + + +def type_to_python(typename, size=None): + """type_to_python(typename: str, size: str) -> str + + Transforms a Declarations.yaml type name into a Python type specification + as used for type hints. + """ + typename = typename.replace(' ', '') # normalize spaces, e.g., 'Generator *' + + # Disambiguate explicitly sized int/tensor lists from implicitly + # sized ones. These permit non-list inputs too. (IntList[] and + # TensorList[] are not real types; this is just for convenience.) + if typename in {'IntList', 'TensorList'} and size is not None: + typename += '[]' + + typename = { + 'Device': 'Union[_device, str, None]', + 'Generator*': 'Generator', + 'IntegerTensor': 'Tensor', + 'Scalar': 'Number', + 'ScalarType': '_dtype', + 'Storage': 'Storage', + 'BoolTensor': 'Tensor', + 'IndexTensor': 'Tensor', + 'SparseTensorRef': 'Tensor', + 'Tensor': 'Tensor', + 'IntList': '_size', + 'IntList[]': 'Union[_int, _size]', + 'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]', + 'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]', + 'bool': 'bool', + 'double': '_float', + 'int64_t': '_int', + 'accreal': 'Number', + 'real': 'Number', + 'void*': '_int', # data_ptr + 'void': 'None', + 'std::string': 'str', + }[typename] + + return typename + + +def arg_to_type_hint(arg): + """arg_to_type_hint(arg) -> str + + This takes one argument in a Declarations and returns a string + representing this argument in a type hint signature. + """ + name = arg['name'] + if name == 'from': # from is a Python keyword... 
+ name += '_' + typename = type_to_python(arg['dynamic_type'], arg.get('size')) + if arg.get('is_nullable'): + typename = 'Optional[' + typename + ']' + if 'default' in arg: + default = arg['default'] + if default == 'nullptr': + default = None + elif default == 'c10::nullopt': + default = None + elif isinstance(default, str) and default.startswith('{') and default.endswith('}'): + if arg['dynamic_type'] == 'Tensor' and default == '{}': + default = None + elif arg['dynamic_type'] == 'IntList': + default = '(' + default[1:-1] + ')' + else: + raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type'])) + default = '={}'.format(default) + else: + default = '' + return name + ': ' + typename + default + + +binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv', + 'matmul', + 'radd', 'rmul', # reverse arithmetic + 'and', 'or', 'xor', # logic + 'iadd', 'iand', 'idiv', 'ilshift', 'imul', + 'ior', 'irshift', 'isub', 'itruediv', 'ixor', # inplace ops + ) +comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le') +unary_ops = ('neg', 'abs', 'invert') +to_py_type_ops = ('bool', 'float', 'long', 'index', 'int', 'nonzero') +all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops + + +def sig_for_ops(opname): + """sig_for_ops(opname : str) -> List[str] + + Returns signatures for operator special functions (__add__ etc.)""" + + # we have to do this by hand, because they are hand-bound in Python + + assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname) + + name = opname[2:-2] + if name in binary_ops: + return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)] + elif name in comparison_ops: + # unsafe override https://github.com/python/mypy/issues/5704 + return ['def {}(self, other: Any) -> Tensor: ... # type: ignore'.format(opname)] + elif name in unary_ops: + return ['def {}(self) -> Tensor: ...'.format(opname)] + elif name in to_py_type_ops: + if name in {'bool', 'float'}: + tname = name + elif name == 'nonzero': + tname = 'bool' + else: + tname = 'int' + if tname in {'float', 'int'}: + tname = 'builtins.' + tname + return ['def {}(self) -> {}: ...'.format(opname, tname)] + else: + raise Exception("unknown op", opname) + + +def generate_type_hints(fname, decls, is_tensor=False): + """generate_type_hints(fname, decls, is_tensor=False) + + Generates type hints for the declarations pertaining to the function + :attr:`fname`. attr:`decls` are the declarations from the parsed + Declarations.yaml. + The :attr:`is_tensor` flag indicates whether we are parsing + members of the Tensor class (true) or functions in the + `torch` namespace (default, false). + + This function currently encodes quite a bit about the semantics of + the translation C++ -> Python. 
+ """ + if fname in blacklist: + return [] + + type_hints = [] + dnames = ([d['name'] for d in decls]) + has_out = fname + '_out' in dnames + + if has_out: + decls = [d for d in decls if d['name'] != fname + '_out'] + + for decl in decls: + render_kw_only_separator = True # whether we add a '*' if we see a keyword only argument + python_args = [] + + has_tensor_options = 'TensorOptions' in [a['dynamic_type'] for a in decl['arguments']] + + for a in decl['arguments']: + if a['dynamic_type'] != 'TensorOptions': + if a.get('kwarg_only', False) and render_kw_only_separator: + python_args.append('*') + render_kw_only_separator = False + python_args.append(arg_to_type_hint(a)) + + if is_tensor: + if 'self: Tensor' in python_args: + python_args.remove('self: Tensor') + python_args = ['self'] + python_args + else: + raise Exception("method without self is unexpected") + + if has_out: + if render_kw_only_separator: + python_args.append('*') + render_kw_only_separator = False + python_args.append('out: Optional[Tensor]=None') + + if has_tensor_options: + if render_kw_only_separator: + python_args.append('*') + render_kw_only_separator = False + python_args += ["dtype: _dtype=None", + "layout: layout=strided", + "device: Union[_device, str, None]=None", + "requires_grad:bool=False"] + + python_args_s = ', '.join(python_args) + python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']] + + if len(python_returns) > 1: + python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']' + else: + python_returns_s = python_returns[0] + + type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) + numargs = len(decl['arguments']) + vararg_pos = int(is_tensor) + have_vararg_version = (numargs > vararg_pos and + decl['arguments'][vararg_pos]['dynamic_type'] in {'IntList', 'TensorList'} and + (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and + (not is_tensor or decl['arguments'][0]['name'] == 'self')) + + type_hints.append(type_hint) + + if have_vararg_version: + # Two things come into play here: PyTorch has the "magic" that if the first and only positional argument + # is an IntList or TensorList, it will be used as a vararg variant. + # The following outputs the vararg variant, the "pass a list variant" is output above. + # The other thing is that in Python, the varargs are annotated with the element type, not the list type. + typelist = decl['arguments'][vararg_pos]['dynamic_type'] + if typelist == 'IntList': + vararg_type = '_int' + else: + vararg_type = 'Tensor' + # replace first argument and eliminate '*' if present + python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] + + ': ' + vararg_type] + python_args[vararg_pos + 2:]) + python_args_s = ', '.join(python_args) + type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) + type_hints.append(type_hint) + + return type_hints + + +def gen_pyi(declarations_path, out): + """gen_pyi() + + This function generates a pyi file for torch. + """ + + # Some of this logic overlaps with generate_python_signature in + # tools/autograd/gen_python_functions.py; however, this + # function is all about generating mypy type signatures, whereas + # the other function generates are custom format for argument + # checking. If you are update this, consider if your change + # also needs to update the other file. 
+ + # Load information from YAML + declarations = load_aten_declarations(declarations_path) + + # Generate type signatures for top-level functions + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + unsorted_function_hints = collections.defaultdict(list) + unsorted_function_hints.update({ + 'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'], + 'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'], + 'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'], + 'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf," + " *, out: Optional[Tensor]=None) -> Tensor: ..."], + 'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."], + 'get_num_threads': ['def get_num_threads() -> _int: ...'], + 'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'], + # These functions are explicitly disabled by + # SKIP_PYTHON_BINDINGS because they are hand bound. + # Correspondingly, we must hand-write their signatures. + 'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)], + 'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],' + ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,' + ' device: Union[_device, str, None]=None, requires_grad:bool=False) -> Tensor: ...'], + 'range': ['def range(start: Number, end: Number,' + ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...' + .format(FACTORY_PARAMS)], + 'arange': ['def arange(start: Number, end: Number, step: Number, *,' + ' out: Optional[Tensor]=None, {}) -> Tensor: ...' + .format(FACTORY_PARAMS), + 'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...' + .format(FACTORY_PARAMS), + 'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...' + .format(FACTORY_PARAMS)], + 'randint': ['def randint(low: _int, high: _int, size: _size, *, {}) -> Tensor: ...' + .format(FACTORY_PARAMS), + 'def randint(high: _int, size: _size, *, {}) -> Tensor: ...' + .format(FACTORY_PARAMS)], + }) + for binop in ['add', 'sub', 'mul', 'div']: + unsorted_function_hints[binop].append( + 'def {}(input: Union[Tensor, Number],' + ' other: Union[Tensor, Number],' + ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop)) + unsorted_function_hints[binop].append( + 'def {}(input: Union[Tensor, Number],' + ' value: Number,' + ' other: Union[Tensor, Number],' + ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop)) + + function_declarations = get_py_torch_functions(declarations) + for name in sorted(function_declarations.keys()): + unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name]) + + # Generate type signatures for deprecated functions + + # TODO: Maybe we shouldn't generate type hints for deprecated + # functions :) However, examples like those addcdiv rely on these. 
+ with open('tools/autograd/deprecated.yaml', 'r') as f: + deprecated = yaml.load(f, Loader=YamlLoader) + for d in deprecated: + name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups() + sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')] + sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig] + unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig))) + + function_hints = [] + for name, hints in sorted(unsorted_function_hints.items()): + if len(hints) > 1: + hints = ['@overload\n' + h for h in hints] + function_hints += hints + + # Generate type signatures for Tensor methods + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + unsorted_tensor_method_hints = collections.defaultdict(list) + unsorted_tensor_method_hints.update({ + 'size': ['def size(self) -> Size: ...', + 'def size(self, _int) -> _int: ...'], + 'stride': ['def stride(self) -> Tuple[_int]: ...', + 'def stride(self, _int) -> _int: ...'], + 'new_empty': ['def new_empty(self, size: {}, {}) -> Tensor: ...'. + format(type_to_python('IntList'), FACTORY_PARAMS)], + 'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'. + format(type_to_python('IntList'), FACTORY_PARAMS)], + 'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'. + format(type_to_python('IntList'), FACTORY_PARAMS)], + 'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'. + format(type_to_python('IntList'), type_to_python('Scalar'), FACTORY_PARAMS)], + 'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)], + # clamp has no default values in the Declarations + 'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf," + " *, out: Optional[Tensor]=None) -> Tensor: ..."], + 'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."], + '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)], + '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])" + " -> None: ...".format(INDICES)], + 'tolist': ['def tolist(self) -> List: ...'], + 'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'], + 'element_size': ['def element_size(self) -> _int: ...'], + 'dim': ['def dim(self) -> _int: ...'], + 'ndimension': ['def ndimension(self) -> _int: ...'], + 'nelement': ['def nelement(self) -> _int: ...'], + 'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'], + 'numpy': ['def numpy(self) -> Any: ...'], + 'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'], + 'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'], + 'copy_': ['def copy_(self, src: Tensor, non_blocking: bool=False) -> Tensor: ...'], + 'storage': ['def storage(self) -> Storage: ...'], + 'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)' + ' -> Union[str, Tensor]: ...'], + 'get_device': ['def get_device(self) -> _int: ...'], + 'is_contiguous': ['def is_contiguous(self) -> bool: ...'], + 'is_cuda': ['def is_cuda(self) -> bool: ...'], + 'is_leaf': ['def is_leaf(self) -> bool: ...'], + 'storage_offset': ['def storage_offset(self) -> _int: ...'], + 'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...', + 'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, ' + 'non_blocking: bool=False, copy: bool=False) -> Tensor: ...', + 'def to(self, other: Tensor, non_blocking: bool=False, copy: 
bool=False) -> Tensor: ...', + ], + 'item': ["def item(self) -> Number: ..."], + }) + for binop in ['add', 'sub', 'mul', 'div']: + for inplace in [True, False]: + out_suffix = ', *, out: Optional[Tensor]=None' + if inplace: + name += '_' + out_suffix = '' + unsorted_tensor_method_hints[name].append( + 'def {}(self, other: Union[Tensor, Number]{})' + ' -> Tensor: ...'.format(name, out_suffix)) + unsorted_tensor_method_hints[name].append( + 'def {}(self, value: Number,' + ' other: Union[Tensor, Number]{})' + ' -> Tensor: ...'.format(name, out_suffix)) + simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short'] + for name in simple_conversions: + unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name)) + + tensor_method_declarations = get_py_variable_methods(declarations) + for name in sorted(tensor_method_declarations.keys()): + unsorted_tensor_method_hints[name] += \ + generate_type_hints(name, tensor_method_declarations[name], is_tensor=True) + + for op in all_ops: + name = '__{}__'.format(op) + unsorted_tensor_method_hints[name] += sig_for_ops(name) + + tensor_method_hints = [] + for name, hints in sorted(unsorted_tensor_method_hints.items()): + if len(hints) > 1: + hints = ['@overload\n' + h for h in hints] + tensor_method_hints += hints + + # TODO: Missing type hints for nn + + # Generate type signatures for legacy classes + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # TODO: These are deprecated, maybe we shouldn't type hint them + legacy_class_hints = [] + for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage', + 'ShortStorage', 'CharStorage', 'ByteStorage'): + legacy_class_hints.append('class {}(Storage): ...'.format(c)) + + for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor', + 'ShortTensor', 'CharTensor', 'ByteTensor'): + legacy_class_hints.append('class {}(Tensor): ...'.format(c)) + + # Generate type signatures for dtype classes + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # TODO: don't explicitly list dtypes here; get it from canonical + # source + dtype_class_hints = ['{}: dtype = ...'.format(n) + for n in + ['float32', 'float', 'float64', 'double', 'float16', 'half', + 'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long', + 'complex32', 'complex64', 'complex128']] + + # Write out the stub + # ~~~~~~~~~~~~~~~~~~ + + env = { + 'function_hints': function_hints, + 'tensor_method_hints': tensor_method_hints, + 'legacy_class_hints': legacy_class_hints, + 'dtype_class_hints': dtype_class_hints, + } + TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in')) + + write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env) + + +def main(): + parser = argparse.ArgumentParser( + description='Generate type stubs for PyTorch') + parser.add_argument('--declarations-path', metavar='DECL', + default='torch/share/ATen/Declarations.yaml', + help='path to Declarations.yaml') + parser.add_argument('--out', metavar='OUT', + default='.', + help='path to output directory') + args = parser.parse_args() + gen_pyi(args.declarations_path, args.out) + + +if __name__ == '__main__': + main() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index dc9d9fd515..5e8ba7185d 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -713,8 +713,26 @@ if (BUILD_PYTHON) endif() endif() + add_custom_target(torch_python_stubs DEPENDS "${TORCH_SRC_DIR}/__init__.pyi") + # For Declarations.yaml dependency + add_dependencies(torch_python_stubs 
ATEN_CPU_FILES_GEN_TARGET) + add_custom_command( + OUTPUT + "${TORCH_SRC_DIR}/__init__.pyi" + COMMAND + ${PYCMD} -mtools.pyi.gen_pyi + --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" + DEPENDS + "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" + "${TORCH_SRC_DIR}/__init__.pyi.in" + WORKING_DIRECTORY + "${TORCH_ROOT}" + ) + add_library(torch_python SHARED ${TORCH_PYTHON_SRCS}) + add_dependencies(torch_python torch_python_stubs) + target_link_libraries(torch_python ${TORCH_PYTHON_LINK_LIBRARIES}) target_compile_definitions(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_DEFINITIONS}) diff --git a/torch/__init__.py b/torch/__init__.py index bb985deb6f..a40ed1afe0 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -179,6 +179,7 @@ def set_default_dtype(d): """ _C._set_default_dtype(d) +# If you edit these imports, please update torch/__init__.py.in as well from .random import set_rng_state, get_rng_state, manual_seed, initial_seed from .serialization import save, load from ._tensor_str import set_printoptions diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in new file mode 100644 index 0000000000..348e3bad42 --- /dev/null +++ b/torch/__init__.pyi.in @@ -0,0 +1,106 @@ +# ${generated_comment} + +from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload +from torch._six import inf + +import builtins + +# These identifiers are reexported from other modules. These modules +# are not mypy-clean yet, so in order to use this stub file usefully +# from mypy you will need to specify --follow-imports=silent. +# Not all is lost: these imports still enable IDEs like PyCharm to offer +# autocomplete. +# +# Note: Why does the syntax here look so strange? Import visibility +# rules in stubs are different from normal Python files! You must use +# 'from ... import ... as ...' syntax to cause an identifier to be +# exposed (or use a wildcard); regular syntax is not exposed. +from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \ + manual_seed as manual_seed, initial_seed as initial_seed +from ._tensor_str import set_printoptions as set_printoptions +from .functional import * +from .serialization import save as save, load as load +from .autograd import no_grad as no_grad, enable_grad as enable_grad, \ + set_grad_enabled as set_grad_enabled + +class dtype: ... + +class layout: ... + +strided : layout = ... + +# See https://github.com/python/mypy/issues/4146 for why these workarounds +# is necessary +_int = builtins.int +_float = builtins.float + +class device: + def __init__(self, device: Union[_int, str, None]=None) -> None: ... + +class Generator: ... + +class Size(tuple): ... + +class Storage: ... + +# See https://github.com/python/mypy/issues/4146 for why these workarounds +# is necessary +_dtype = dtype +_device = device +_size = Union[Size, List[_int], Tuple[_int, ...]] + +# Meta-type for "numeric" things; matches our docs +Number = Union[builtins.int, builtins.float] + +# TODO: One downside of doing it this way, is direct use of +# torch.tensor.Tensor doesn't get type annotations. Nobody +# should really do that, so maybe this is not so bad. +class Tensor: + dtype: _dtype = ... + shape: Size = ... + device: _device = ... + requires_grad: bool = ... + grad: Optional[Tensor] = ... + + ${tensor_method_hints} + + # Manually defined methods from torch/tensor.py + def backward(self, gradient: Optional[Tensor]=None, retain_graph: Optional[bool]=None, create_graph: bool=False) -> None: ... 
+ def register_hook(self, hook: Callable) -> Any: ... + def retain_grad(self) -> None: ... + def is_pinned(self) -> bool: ... + def is_shared(self) -> bool: ... + def share_memory_(self) -> None: ... + # TODO: fill in the types for these, or otherwise figure out some + # way to not have to write these out again... + def argmax(self, dim=None, keepdim=False): ... + def argmin(self, dim=None, keepdim=False): ... + def argsort(self, dim=None, descending=False): ... + def norm(self, p="fro", dim=None, keepdim=False): ... + def stft(self, n_fft, hop_length=None, win_length=None, window=None, + center=True, pad_mode='reflect', normalized=False, onesided=True): ... + def split(self, split_size, dim=0): ... + def index_add(self, dim, index, tensor): ... + def index_copy(self, dim, index, tensor): ... + def index_fill(self, dim, index, value): ... + def scatter(self, dim, index, source): ... + def scatter_add(self, dim, index, source): ... + def masked_scatter(self, mask, tensor): ... + def masked_fill(self, mask, value): ... + def unique(self, sorted=True, return_inverse=False, dim=None): ... + +${function_hints} + +${legacy_class_hints} + +${dtype_class_hints} + +# Pure Python functions defined in torch/__init__.py + +def typename(obj) -> str: ... +def is_tensor(obj) -> bool: ... +def is_storage(obj) -> bool: ... +def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API +def set_default_dtype(d : _dtype) -> None: ... +def manager_path() -> str: ... +def compiled_with_cxx11_abi() -> bool: ... diff --git a/torch/_six.py b/torch/_six.py index f6bdd39c30..f23fe61733 100644 --- a/torch/_six.py +++ b/torch/_six.py @@ -20,6 +20,7 @@ import itertools import sys +import builtins PY2 = sys.version_info[0] == 2 @@ -48,7 +49,7 @@ else: if PY2: FileNotFoundError = IOError else: - FileNotFoundError = FileNotFoundError + FileNotFoundError = builtins.FileNotFoundError if PY2: @@ -71,11 +72,10 @@ def with_metaclass(meta, *bases): # A portable way of referring to the generator version of map # in both Python 2 and Python 3. -# TODO: Move this into an appropriate utility library. if hasattr(itertools, 'imap'): - imap = itertools.imap + imap = itertools.imap # type: ignore else: - imap = map + imap = map # type: ignore if PY3: diff --git a/torch/_utils.py b/torch/_utils.py index c593b7aa48..10b9dca83a 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -309,3 +309,13 @@ def _take_tensors(tensors, size_limit): for buf, _ in buf_dict.values(): if len(buf) > 0: yield buf + + +# annotation decorator to get annotations in a way that is compatible +# with both Python 2 and 3 +def annotate(ret, **kwargs): + def dec(fun): + fun.__annotations__ = dict(kwargs) + fun.__annotations__['return'] = ret + return fun + return dec diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 6c3de1ee43..8ab98804a2 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -36,6 +36,10 @@ static std::unordered_map<std::string, ParameterType> type_map = { // numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar // overloads and binding to the Tensor overload with a number of a different // type will trigger a type error. +// +// If you modify this, you will need to adjust the blacklist in +// tools/pyi/gen_pyi.py (and add hardcoded signatures for these +// functions.) 
static bool should_allow_numbers_as_tensors(const std::string& name) { static std::unordered_set<std::string> allowed = { "add", "add_", "add_out", diff --git a/torch/functional.py b/torch/functional.py index 656b2e2ca9..c47a30c1dc 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -4,8 +4,11 @@ from torch._six import inf from torch._C import _add_docstr from operator import mul from functools import reduce +from collections import Iterable +from torch._utils import annotate from itertools import product import math +from typing import Optional, Tuple, List, Union import warnings __all__ = [ diff --git a/torch/serialization.py b/torch/serialization.py index 9a6ce83a9e..ef7d415569 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -366,7 +366,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): # Map tensors from GPU 1 to GPU 0 >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) # Load tensor from io.BytesIO object - >>> with open('tensor.pt') as f: + >>> with open('tensor.pt', 'rb') as f: buffer = io.BytesIO(f.read()) >>> torch.load(buffer) """ diff --git a/torch/tensor.py b/torch/tensor.py index be936bf38e..ec12e08664 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -12,6 +12,10 @@ from torch._C import _add_docstr # NB: If you subclass Tensor, and want to share the subclassed class # across processes, you must also update torch/multiprocessing/reductions.py # to define a ForkingPickler serialization mode for the class. +# +# NB: If you add a new method to Tensor, you must update +# torch/__init__.py.in to add a type annotation for your method; +# otherwise, it will not show up in autocomplete. class Tensor(torch._C._TensorBase): def __deepcopy__(self, memo): if not self.is_leaf: |
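For reference, the annotate() helper added to torch/_utils.py above (and imported by torch/functional.py in this patch) attaches PEP 484 annotations without using Python-3-only syntax. A minimal sketch of its effect; the decorated function here is hypothetical, not part of the patch:

    # Hypothetical use of torch._utils.annotate. Equivalent to the
    # Python 3 definition:  def clamp01(t: torch.Tensor) -> torch.Tensor
    import torch
    from torch._utils import annotate

    @annotate(torch.Tensor, t=torch.Tensor)
    def clamp01(t):
        return t.clamp(0, 1)

    # annotate() stored the types in __annotations__:
    # {'t': torch.Tensor, 'return': torch.Tensor}
    print(clamp01.__annotations__)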