-rw-r--r--  .gitignore                              |   2
-rwxr-xr-x  .jenkins/pytorch/test.sh                |   4
-rw-r--r--  setup.py                                |   1
-rw-r--r--  test/run_test.py                        |   1
-rw-r--r--  test/test_type_hints.py                 | 167
-rw-r--r--  tools/autograd/gen_python_functions.py  |  65
-rw-r--r--  tools/autograd/utils.py                 |   6
-rw-r--r--  tools/pyi/__init__.py                   |   0
-rw-r--r--  tools/pyi/gen_pyi.py                    | 529
-rw-r--r--  torch/CMakeLists.txt                    |  18
-rw-r--r--  torch/__init__.py                       |   1
-rw-r--r--  torch/__init__.pyi.in                   | 106
-rw-r--r--  torch/_six.py                           |   8
-rw-r--r--  torch/_utils.py                         |  10
-rw-r--r--  torch/csrc/utils/python_arg_parser.cpp  |   4
-rw-r--r--  torch/functional.py                     |   3
-rw-r--r--  torch/serialization.py                  |   2
-rw-r--r--  torch/tensor.py                         |   4
18 files changed, 910 insertions, 21 deletions
diff --git a/.gitignore b/.gitignore
index f6ac0c66ce..7d7add5e70 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,11 +35,13 @@ test/data/gpu_tensors.pt
test/data/legacy_modules.t7
test/data/legacy_serialized.pt
test/data/linear.pt
+test/generated_type_hints_smoketest.py
test/htmlcov
test/cpp_extensions/install/
third_party/build/
tools/shared/_utils_internal.py
torch.egg-info/
+torch/__init__.pyi
torch/csrc/autograd/generated/*
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/generated
diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh
index cd6d72563f..ef180b3407 100755
--- a/.jenkins/pytorch/test.sh
+++ b/.jenkins/pytorch/test.sh
@@ -34,6 +34,10 @@ if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]]; then
# TODO: move this to Docker
pip install -q hypothesis --user
+
+ # mypy will fail to install on Python <3.4. In that case,
+ # we just won't run these tests.
+ pip install mypy --user || true
fi
# DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems
diff --git a/setup.py b/setup.py
index 465ed78a39..9b79d79801 100644
--- a/setup.py
+++ b/setup.py
@@ -731,6 +731,7 @@ if __name__ == '__main__':
entry_points=entry_points,
package_data={
'torch': [
+ '__init__.pyi',
'lib/*.so*',
'lib/*.dylib*',
'lib/*.dll',
diff --git a/test/run_test.py b/test/run_test.py
index 72e7b22061..24c1641798 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -42,6 +42,7 @@ TESTS = [
'thd_distributed',
'torch',
'type_info',
+ 'type_hints',
'utils',
]
diff --git a/test/test_type_hints.py b/test/test_type_hints.py
new file mode 100644
index 0000000000..c2df94d811
--- /dev/null
+++ b/test/test_type_hints.py
@@ -0,0 +1,167 @@
+from __future__ import print_function
+import unittest
+from common_utils import TestCase, run_tests, download_file
+import tempfile
+import torch
+import re
+import os
+import sys
+import subprocess
+import inspect
+
+try:
+ import mypy
+ HAVE_MYPY = True
+except ImportError:
+ HAVE_MYPY = False
+
+
+def get_examples_from_docstring(docstr):
+ """
+ Extracts all runnable python code from the examples
+ in docstrings; returns a list of lines.
+ """
+ # TODO: Figure out if there's a way to use doctest directly to
+ # implement this
+ example_file_lines = []
+ # The detection is a bit hacky because there isn't a nice way of
+ # detecting where multiline commands end. Thus we accumulate partial
+ # input in `beginning` and keep adding lines until it compiles as a
+ # complete Python statement.
+ exampleline_re = re.compile(r"^\s+(?:>>>|\.\.\.) (.*)$")
+ beginning = ""
+ for l in docstr.split('\n'):
+ if beginning:
+ m = exampleline_re.match(l)
+ if m:
+ beginning += m.group(1)
+ else:
+ beginning += l
+ else:
+ m = exampleline_re.match(l)
+ if m:
+ beginning += m.group(1)
+ if beginning:
+ complete = True
+ try:
+ compile(beginning, "", "exec")
+ except SyntaxError:
+ complete = False
+ if complete:
+ # found one
+ example_file_lines += beginning.split('\n')
+ beginning = ""
+ else:
+ beginning += "\n"
+ return [' ' + l for l in example_file_lines]
+
+
+def get_all_examples():
+ """get_all_examples() -> str
+
+ This function grabs (hopefully all) examples from the torch documentation
+ strings and puts them in one nonsensical module returned as a string.
+ """
+ blacklist = {"_np"}
+ allexamples = ""
+
+ example_file_lines = [
+ "import torch",
+ "import torch.nn.functional as F",
+ "import math # type: ignore", # mypy complains about floats where SupportFloat is expected
+ "import numpy # type: ignore",
+ "import io # type: ignore",
+ "import itertools # type: ignore",
+ "",
+ # for requires_grad_ example
+ # NB: We are parsing this file as Python 2, so we must use
+ # Python 2 type annotation syntax
+ "def preprocess(inp):",
+ " # type: (torch.Tensor) -> torch.Tensor",
+ " return inp",
+ ]
+
+ for fname in dir(torch):
+ fn = getattr(torch, fname)
+ docstr = inspect.getdoc(fn)
+ if docstr and fname not in blacklist:
+ e = get_examples_from_docstring(docstr)
+ if e:
+ example_file_lines.append("\n\ndef example_torch_{}():".format(fname))
+ example_file_lines += e
+
+ for fname in dir(torch.Tensor):
+ fn = getattr(torch.Tensor, fname)
+ docstr = inspect.getdoc(fn)
+ if docstr and fname not in blacklist:
+ e = get_examples_from_docstring(docstr)
+ if e:
+ example_file_lines.append("\n\ndef example_torch_tensor_{}():".format(fname))
+ example_file_lines += e
+
+ return "\n".join(example_file_lines)
+
+
+class TestTypeHints(TestCase):
+ @unittest.skipIf(sys.version_info[0] == 2, "no type hints for Python 2")
+ @unittest.skipIf(not HAVE_MYPY, "need mypy")
+ def test_doc_examples(self):
+ """
+ Run documentation examples through mypy.
+ """
+ fn = os.path.join(os.path.dirname(__file__), 'generated_type_hints_smoketest.py')
+ with open(fn, "w") as f:
+ print(get_all_examples(), file=f)
+
+ # OK, so here's the deal. mypy treats installed packages
+ # and local modules differently: if a package is installed,
+ # mypy will refuse to use modules from that package for type
+ # checking unless the module explicitly says that it supports
+ # type checking. (Reference:
+ # https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports
+ # )
+ #
+ # Now, PyTorch doesn't support typechecking, and we shouldn't
+ # claim that it does. However, not claiming it is bad for this
+ # test, which wants to use the partial information we get from
+ # the typed bits of PyTorch to check whether the examples
+ # typecheck. And although mypy works directly if you run it from
+ # a source checkout, some of our tests involve installing PyTorch
+ # and then running its tests.
+ #
+ # The guidance we got from Michael Sullivan and Joshua Oreman,
+ # also arrived at independently by Thomas Viehmann, is to create
+ # a fake directory and add symlinks for the packages that should
+ # typecheck. So that is what we do here.
+ #
+ # If you want to run mypy by hand, and you run from PyTorch
+ # root directory, it should work fine to skip this step (since
+ # mypy will preferentially pick up the local files first). The
+ # temporary directory here is purely needed for CI. For this
+ # reason, we also still drop the generated file in the test
+ # source folder, for ease of inspection when there are failures.
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ try:
+ os.symlink(
+ os.path.dirname(torch.__file__),
+ os.path.join(tmp_dir, 'torch'),
+ target_is_directory=True
+ )
+ except OSError:
+ raise unittest.SkipTest('cannot symlink')
+ try:
+ subprocess.run([
+ sys.executable,
+ '-mmypy',
+ '--follow-imports', 'silent',
+ '--check-untyped-defs',
+ os.path.abspath(fn)],
+ cwd=tmp_dir,
+ check=True)
+ except subprocess.CalledProcessError as e:
+ raise AssertionError("mypy failed. Look above this error for mypy's output.")
+
+
+if __name__ == '__main__':
+ run_tests()
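
To illustrate the extraction above: a minimal sketch of what get_examples_from_docstring returns for a doctest-style snippet (the docstring below is invented for illustration, not taken from PyTorch):

    docstr = '''
    Example::

        >>> x = 2
        >>> x + 1
        3
    '''
    for line in get_examples_from_docstring(docstr):
        print(line)
    # Prints the two statements "x = 2" and "x + 1"; the expected-output
    # line "3" is dropped. Each statement carries the indentation added
    # by the return statement, so get_all_examples() can paste it into a
    # generated "def example_...():" body.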
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index ffacdf8e89..f0a9677575 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -182,33 +182,49 @@ def should_generate_python_binding(declaration):
return True
-def gen_py_variable_methods(out, declarations, template_path):
- PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
- PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
-
+def get_py_variable_methods(declarations):
+ """
+ Get declarations (grouped by name) which should be generated
+ as methods on Tensor.
+ """
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
declaration['mode'] != 'NN' and
declaration.get('python_module') != 'nn' and
'Tensor' in declaration['method_of'])
- py_variable_methods = group_declarations_by_name(declarations, should_bind)
+ return group_declarations_by_name(declarations, should_bind)
+
+
+def gen_py_variable_methods(out, declarations, template_path):
+ PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
+ PY_VARIABLE_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')
+
+ py_variable_methods = get_py_variable_methods(declarations)
env = create_python_bindings(py_variable_methods, True)
write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)
write(out, 'python_variable_methods_dispatch.h', PY_VARIABLE_DISPATCH_H, env)
+def get_py_nn_functions(declarations):
+ """
+ Get declarations (grouped by name) which should be generated
+ as functions in the "nn" module.
+ """
+ def should_bind(declaration):
+ return (should_generate_python_binding(declaration) and
+ (declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))
+
+ return group_declarations_by_name(declarations, should_bind)
+
+
def gen_py_nn_functions(out, declarations, template_path):
PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
PY_NN_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
PY_NN_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')
- def should_bind(declaration):
- return (should_generate_python_binding(declaration) and
- (declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'))
-
- py_nn_functions = group_declarations_by_name(declarations, should_bind)
+ py_nn_functions = get_py_nn_functions(declarations)
env = create_python_bindings(py_nn_functions, has_self=False, is_module=True)
write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)
@@ -216,17 +232,25 @@ def gen_py_nn_functions(out, declarations, template_path):
write(out, 'python_nn_functions_dispatch.h', PY_NN_DISPATCH_H, env)
-def gen_py_torch_functions(out, declarations, template_path):
- PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
- PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
-
+def get_py_torch_functions(declarations):
+ """
+ Get declarations (grouped by name) which should be generated
+ as functions in the "torch" module.
+ """
def should_bind(declaration):
return (should_generate_python_binding(declaration) and
declaration['mode'] != 'NN' and
declaration.get('python_module') != 'nn' and
'namespace' in declaration['method_of'])
- py_torch_functions = group_declarations_by_name(declarations, should_bind)
+ return group_declarations_by_name(declarations, should_bind)
+
+
+def gen_py_torch_functions(out, declarations, template_path):
+ PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
+ PY_TORCH_DISPATCH_H = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')
+
+ py_torch_functions = get_py_torch_functions(declarations)
env = create_python_bindings(py_torch_functions, has_self=False)
write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)
@@ -800,7 +824,16 @@ def sort_declarations(grouped_decls):
def get_python_signature(declaration, include_out):
- # Compute the Python function signature for argument parsing
+ # Compute the Python function signature for argument parsing,
+ # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
+ # this is NOT the same type signature as specified by PEP 484
+ # as understood by mypy; our format was independently developed
+ # and has some quirks to make it more suitable specifically
+ # for error parsing.
+ #
+ # For a translation to mypy-valid type signatures, see
+ # tools/pyi/gen_pyi.py. If you change any logic here, please
+ # check that file too.
py_formal_args = []
output_args = []
type_args = []
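
All three get_py_* helpers above funnel into group_declarations_by_name with different should_bind predicates. Conceptually, the grouping amounts to the following sketch (a sketch of the semantics, not necessarily the actual implementation):

    import collections

    def group_declarations_by_name(declarations, should_bind):
        # Filter by the predicate, then bucket by name so that all
        # overloads of one function end up in the same group.
        groups = collections.defaultdict(list)
        for declaration in declarations:
            if should_bind(declaration):
                groups[declaration['name']].append(declaration)
        return groups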
diff --git a/tools/autograd/utils.py b/tools/autograd/utils.py
index 5b82b4e751..9a799c9e30 100644
--- a/tools/autograd/utils.py
+++ b/tools/autograd/utils.py
@@ -14,6 +14,12 @@ except ImportError:
from tools.shared.module_loader import import_module
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
+# Import YamlLoader from these lines rather than wiring up a YAML
+# loader yourself -- especially if you see this error:
+#
+# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load
+# loader = Loader(stream)
+# TypeError: 'module' object is not callable
try:
# use faster C loader if available
from yaml import CLoader as YamlLoader
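
The imports referred to follow the standard PyYAML fallback pattern; a self-contained sketch of the intended usage (the quoted TypeError is typical of passing a module instead of a loader class):

    import yaml

    try:
        # use the faster C loader if PyYAML was built against libyaml
        from yaml import CLoader as YamlLoader
    except ImportError:
        from yaml import Loader as YamlLoader

    with open('tools/autograd/deprecated.yaml') as f:
        data = yaml.load(f, Loader=YamlLoader)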
diff --git a/tools/pyi/__init__.py b/tools/pyi/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tools/pyi/__init__.py
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
new file mode 100644
index 0000000000..6905f80f47
--- /dev/null
+++ b/tools/pyi/gen_pyi.py
@@ -0,0 +1,529 @@
+from __future__ import print_function
+import multiprocessing
+import sys
+import os
+import inspect
+import collections
+import yaml
+import types
+import re
+import argparse
+
+from ..autograd.utils import YamlLoader, CodeTemplate, write
+from ..autograd.gen_python_functions import get_py_torch_functions, get_py_variable_methods
+from ..autograd.gen_autograd import load_aten_declarations
+
+"""
+This module implements generation of type stubs for PyTorch,
+enabling use of autocomplete in IDEs like PyCharm, which otherwise
+don't understand C extension modules.
+
+At the moment, this module only handles type stubs for torch and
+torch.Tensor. It should eventually be expanded to cover all
+functions which are autogenerated.
+
+Here's our general strategy:
+
+- We start off with a hand-written __init__.pyi.in file. This
+ file contains type definitions for everything we cannot automatically
+ generate, including pure Python definitions directly in __init__.py
+ (the latter case should be pretty rare).
+
+- We go through automatically bound functions based on the
+ type information recorded in Declarations.yaml and
+ generate type hints for them (generate_type_hints)
+
+There are a number of type hints which we've special-cased;
+read gen_pyi for the gory details.
+"""
+
+# TODO: Consider defining some aliases for our Union[...] types, to
+# make the stubs easier to read.
+
+needed_modules = set()
+
+FACTORY_PARAMS = "dtype: Optional[_dtype]=None, device: Union[_device, str, None]=None, requires_grad: bool=False"
+
+# This could be more precise w.r.t. list contents etc. How do we express Ellipsis?
+INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"
+
+blacklist = [
+ '__init_subclass__',
+ '__new__',
+ '__subclasshook__',
+ 'clamp',
+ 'clamp_',
+ 'device',
+ 'grad',
+ 'requires_grad',
+ 'range',
+ # defined in functional
+ 'einsum',
+ # reduction argument; these bindings don't make sense
+ 'ctc_loss',
+ 'cosine_embedding_loss',
+ 'hinge_embedding_loss',
+ 'kl_div',
+ 'margin_ranking_loss',
+ 'triplet_margin_loss',
+ # Somehow, these are defined in both _C and in functional. Ick!
+ 'broadcast_tensors',
+ 'meshgrid',
+ 'cartesian_prod',
+ 'norm',
+ 'chain_matmul',
+ 'stft',
+ 'tensordot',
+ 'norm',
+ 'split',
+ # These are handled specially by python_arg_parser.cpp
+ 'add',
+ 'add_',
+ 'add_out',
+ 'sub',
+ 'sub_',
+ 'sub_out',
+ 'mul',
+ 'mul_',
+ 'mul_out',
+ 'div',
+ 'div_',
+ 'div_out',
+]
+
+
+def type_to_python(typename, size=None):
+ """type_to_python(typename: str, size: str) -> str
+
+ Transforms a Declarations.yaml type name into a Python type specification
+ as used for type hints.
+ """
+ typename = typename.replace(' ', '') # normalize spaces, e.g., 'Generator *'
+
+ # Disambiguate explicitly sized int/tensor lists from implicitly
+ # sized ones. These permit non-list inputs too. (IntList[] and
+ # TensorList[] are not real types; this is just for convenience.)
+ if typename in {'IntList', 'TensorList'} and size is not None:
+ typename += '[]'
+
+ typename = {
+ 'Device': 'Union[_device, str, None]',
+ 'Generator*': 'Generator',
+ 'IntegerTensor': 'Tensor',
+ 'Scalar': 'Number',
+ 'ScalarType': '_dtype',
+ 'Storage': 'Storage',
+ 'BoolTensor': 'Tensor',
+ 'IndexTensor': 'Tensor',
+ 'SparseTensorRef': 'Tensor',
+ 'Tensor': 'Tensor',
+ 'IntList': '_size',
+ 'IntList[]': 'Union[_int, _size]',
+ 'TensorList': 'Union[Tuple[Tensor, ...], List[Tensor]]',
+ 'TensorList[]': 'Union[Tensor, Tuple[Tensor, ...], List[Tensor]]',
+ 'bool': 'bool',
+ 'double': '_float',
+ 'int64_t': '_int',
+ 'accreal': 'Number',
+ 'real': 'Number',
+ 'void*': '_int', # data_ptr
+ 'void': 'None',
+ 'std::string': 'str',
+ }[typename]
+
+ return typename
+
+
+def arg_to_type_hint(arg):
+ """arg_to_type_hint(arg) -> str
+
+ This takes one argument in a Declarations and returns a string
+ representing this argument in a type hint signature.
+ """
+ name = arg['name']
+ if name == 'from': # from is a Python keyword...
+ name += '_'
+ typename = type_to_python(arg['dynamic_type'], arg.get('size'))
+ if arg.get('is_nullable'):
+ typename = 'Optional[' + typename + ']'
+ if 'default' in arg:
+ default = arg['default']
+ if default == 'nullptr':
+ default = None
+ elif default == 'c10::nullopt':
+ default = None
+ elif isinstance(default, str) and default.startswith('{') and default.endswith('}'):
+ if arg['dynamic_type'] == 'Tensor' and default == '{}':
+ default = None
+ elif arg['dynamic_type'] == 'IntList':
+ default = '(' + default[1:-1] + ')'
+ else:
+ raise Exception("Unexpected default constructor argument of type {}".format(arg['dynamic_type']))
+ default = '={}'.format(default)
+ else:
+ default = ''
+ return name + ': ' + typename + default
+
+
+binary_ops = ('add', 'sub', 'mul', 'div', 'pow', 'lshift', 'rshift', 'mod', 'truediv',
+ 'matmul',
+ 'radd', 'rmul', # reverse arithmetic
+ 'and', 'or', 'xor', # logic
+ 'iadd', 'iand', 'idiv', 'ilshift', 'imul',
+ 'ior', 'irshift', 'isub', 'itruediv', 'ixor', # inplace ops
+ )
+comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
+unary_ops = ('neg', 'abs', 'invert')
+to_py_type_ops = ('bool', 'float', 'long', 'index', 'int', 'nonzero')
+all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
+
+
+def sig_for_ops(opname):
+ """sig_for_ops(opname : str) -> List[str]
+
+ Returns signatures for operator special functions (__add__ etc.)"""
+
+ # we have to do this by hand, because they are hand-bound in Python
+
+ assert opname.endswith('__') and opname.startswith('__'), "Unexpected op {}".format(opname)
+
+ name = opname[2:-2]
+ if name in binary_ops:
+ return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
+ elif name in comparison_ops:
+ # unsafe override https://github.com/python/mypy/issues/5704
+ return ['def {}(self, other: Any) -> Tensor: ... # type: ignore'.format(opname)]
+ elif name in unary_ops:
+ return ['def {}(self) -> Tensor: ...'.format(opname)]
+ elif name in to_py_type_ops:
+ if name in {'bool', 'float'}:
+ tname = name
+ elif name == 'nonzero':
+ tname = 'bool'
+ else:
+ tname = 'int'
+ if tname in {'float', 'int'}:
+ tname = 'builtins.' + tname
+ return ['def {}(self) -> {}: ...'.format(opname, tname)]
+ else:
+ raise Exception("unknown op", opname)
+
+
+def generate_type_hints(fname, decls, is_tensor=False):
+ """generate_type_hints(fname, decls, is_tensor=False)
+
+ Generates type hints for the declarations pertaining to the function
+ :attr:`fname`. :attr:`decls` are the declarations from the parsed
+ Declarations.yaml.
+ The :attr:`is_tensor` flag indicates whether we are parsing
+ members of the Tensor class (true) or functions in the
+ `torch` namespace (default, false).
+
+ This function currently encodes quite a bit about the semantics of
+ the translation C++ -> Python.
+ """
+ if fname in blacklist:
+ return []
+
+ type_hints = []
+ dnames = ([d['name'] for d in decls])
+ has_out = fname + '_out' in dnames
+
+ if has_out:
+ decls = [d for d in decls if d['name'] != fname + '_out']
+
+ for decl in decls:
+ render_kw_only_separator = True # whether we add a '*' if we see a keyword-only argument
+ python_args = []
+
+ has_tensor_options = 'TensorOptions' in [a['dynamic_type'] for a in decl['arguments']]
+
+ for a in decl['arguments']:
+ if a['dynamic_type'] != 'TensorOptions':
+ if a.get('kwarg_only', False) and render_kw_only_separator:
+ python_args.append('*')
+ render_kw_only_separator = False
+ python_args.append(arg_to_type_hint(a))
+
+ if is_tensor:
+ if 'self: Tensor' in python_args:
+ python_args.remove('self: Tensor')
+ python_args = ['self'] + python_args
+ else:
+ raise Exception("method without self is unexpected")
+
+ if has_out:
+ if render_kw_only_separator:
+ python_args.append('*')
+ render_kw_only_separator = False
+ python_args.append('out: Optional[Tensor]=None')
+
+ if has_tensor_options:
+ if render_kw_only_separator:
+ python_args.append('*')
+ render_kw_only_separator = False
+ python_args += ["dtype: _dtype=None",
+ "layout: layout=strided",
+ "device: Union[_device, str, None]=None",
+ "requires_grad:bool=False"]
+
+ python_args_s = ', '.join(python_args)
+ python_returns = [type_to_python(r['dynamic_type']) for r in decl['returns']]
+
+ if len(python_returns) > 1:
+ python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
+ else:
+ python_returns_s = python_returns[0]
+
+ type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
+ numargs = len(decl['arguments'])
+ vararg_pos = int(is_tensor)
+ have_vararg_version = (numargs > vararg_pos and
+ decl['arguments'][vararg_pos]['dynamic_type'] in {'IntList', 'TensorList'} and
+ (numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
+ (not is_tensor or decl['arguments'][0]['name'] == 'self'))
+
+ type_hints.append(type_hint)
+
+ if have_vararg_version:
+ # Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
+ # is an IntList or TensorList, it can also be passed as varargs.
+ # The following outputs the vararg variant; the "pass a list" variant is output above.
+ # The other thing is that in Python, varargs are annotated with the element type, not the list type.
+ typelist = decl['arguments'][vararg_pos]['dynamic_type']
+ if typelist == 'IntList':
+ vararg_type = '_int'
+ else:
+ vararg_type = 'Tensor'
+ # replace first argument and eliminate '*' if present
+ python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
+ ': ' + vararg_type] + python_args[vararg_pos + 2:])
+ python_args_s = ', '.join(python_args)
+ type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
+ type_hints.append(type_hint)
+
+ return type_hints
+
+
+def gen_pyi(declarations_path, out):
+ """gen_pyi()
+
+ This function generates a pyi file for torch.
+ """
+
+ # Some of this logic overlaps with generate_python_signature in
+ # tools/autograd/gen_python_functions.py; however, this
+ # function is all about generating mypy type signatures, whereas
+ # the other function generates a custom format for argument
+ # checking. If you update this, consider whether your change
+ # also needs to update the other file.
+
+ # Load information from YAML
+ declarations = load_aten_declarations(declarations_path)
+
+ # Generate type signatures for top-level functions
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ unsorted_function_hints = collections.defaultdict(list)
+ unsorted_function_hints.update({
+ 'set_flush_denormal': ['def set_flush_denormal(mode: bool) -> bool: ...'],
+ 'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
+ 'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
+ 'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
+ " *, out: Optional[Tensor]=None) -> Tensor: ..."],
+ 'as_tensor': ["def as_tensor(data: Any, dtype: _dtype=None, device: Optional[_device]=None) -> Tensor: ..."],
+ 'get_num_threads': ['def get_num_threads() -> _int: ...'],
+ 'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
+ # These functions are explicitly disabled by
+ # SKIP_PYTHON_BINDINGS because they are hand bound.
+ # Correspondingly, we must hand-write their signatures.
+ 'tensor': ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
+ 'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
+ ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
+ ' device: Union[_device, str, None]=None, requires_grad:bool=False) -> Tensor: ...'],
+ 'range': ['def range(start: Number, end: Number,'
+ ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
+ .format(FACTORY_PARAMS)],
+ 'arange': ['def arange(start: Number, end: Number, step: Number, *,'
+ ' out: Optional[Tensor]=None, {}) -> Tensor: ...'
+ .format(FACTORY_PARAMS),
+ 'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
+ .format(FACTORY_PARAMS),
+ 'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
+ .format(FACTORY_PARAMS)],
+ 'randint': ['def randint(low: _int, high: _int, size: _size, *, {}) -> Tensor: ...'
+ .format(FACTORY_PARAMS),
+ 'def randint(high: _int, size: _size, *, {}) -> Tensor: ...'
+ .format(FACTORY_PARAMS)],
+ })
+ for binop in ['add', 'sub', 'mul', 'div']:
+ unsorted_function_hints[binop].append(
+ 'def {}(input: Union[Tensor, Number],'
+ ' other: Union[Tensor, Number],'
+ ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
+ unsorted_function_hints[binop].append(
+ 'def {}(input: Union[Tensor, Number],'
+ ' value: Number,'
+ ' other: Union[Tensor, Number],'
+ ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
+
+ function_declarations = get_py_torch_functions(declarations)
+ for name in sorted(function_declarations.keys()):
+ unsorted_function_hints[name] += generate_type_hints(name, function_declarations[name])
+
+ # Generate type signatures for deprecated functions
+
+ # TODO: Maybe we shouldn't generate type hints for deprecated
+ # functions :) However, examples like those for addcdiv rely on them.
+ with open('tools/autograd/deprecated.yaml', 'r') as f:
+ deprecated = yaml.load(f, Loader=YamlLoader)
+ for d in deprecated:
+ name, sig = re.match(r"^([^\(]+)\(([^\)]*)", d['name']).groups()
+ sig = ['*' if p.strip() == '*' else p.split() for p in sig.split(',')]
+ sig = ['*' if p == '*' else (p[1] + ': ' + type_to_python(p[0])) for p in sig]
+ unsorted_function_hints[name].append("def {}({}) -> Tensor: ...".format(name, ', '.join(sig)))
+
+ function_hints = []
+ for name, hints in sorted(unsorted_function_hints.items()):
+ if len(hints) > 1:
+ hints = ['@overload\n' + h for h in hints]
+ function_hints += hints
+
+ # Generate type signatures for Tensor methods
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ unsorted_tensor_method_hints = collections.defaultdict(list)
+ unsorted_tensor_method_hints.update({
+ 'size': ['def size(self) -> Size: ...',
+ 'def size(self, _int) -> _int: ...'],
+ 'stride': ['def stride(self) -> Tuple[_int]: ...',
+ 'def stride(self, _int) -> _int: ...'],
+ 'new_empty': ['def new_empty(self, size: {}, {}) -> Tensor: ...'.
+ format(type_to_python('IntList'), FACTORY_PARAMS)],
+ 'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'.
+ format(type_to_python('IntList'), FACTORY_PARAMS)],
+ 'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'.
+ format(type_to_python('IntList'), FACTORY_PARAMS)],
+ 'new_full': ['def new_full(self, size: {}, value: {}, {}) -> Tensor: ...'.
+ format(type_to_python('IntList'), type_to_python('Scalar'), FACTORY_PARAMS)],
+ 'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
+ # clamp has no default values in the Declarations
+ 'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
+ " *, out: Optional[Tensor]=None) -> Tensor: ..."],
+ 'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
+ '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
+ '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])"
+ " -> None: ...".format(INDICES)],
+ 'tolist': ['def tolist(self) -> List: ...'],
+ 'requires_grad_': ['def requires_grad_(self, mode: bool=True) -> Tensor: ...'],
+ 'element_size': ['def element_size(self) -> _int: ...'],
+ 'dim': ['def dim(self) -> _int: ...'],
+ 'ndimension': ['def ndimension(self) -> _int: ...'],
+ 'nelement': ['def nelement(self) -> _int: ...'],
+ 'cuda': ['def cuda(self, device: Optional[_device]=None, non_blocking: bool=False) -> Tensor: ...'],
+ 'numpy': ['def numpy(self) -> Any: ...'],
+ 'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
+ 'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
+ 'copy_': ['def copy_(self, src: Tensor, non_blocking: bool=False) -> Tensor: ...'],
+ 'storage': ['def storage(self) -> Storage: ...'],
+ 'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
+ ' -> Union[str, Tensor]: ...'],
+ 'get_device': ['def get_device(self) -> _int: ...'],
+ 'is_contiguous': ['def is_contiguous(self) -> bool: ...'],
+ 'is_cuda': ['def is_cuda(self) -> bool: ...'],
+ 'is_leaf': ['def is_leaf(self) -> bool: ...'],
+ 'storage_offset': ['def storage_offset(self) -> _int: ...'],
+ 'to': ['def to(self, dtype: _dtype, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
+ 'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
+ 'non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
+ 'def to(self, other: Tensor, non_blocking: bool=False, copy: bool=False) -> Tensor: ...',
+ ],
+ 'item': ["def item(self) -> Number: ..."],
+ })
+ for binop in ['add', 'sub', 'mul', 'div']:
+ for inplace in [True, False]:
+ name = binop # the '_' suffix below marks the in-place variant
+ out_suffix = ', *, out: Optional[Tensor]=None'
+ if inplace:
+ name += '_'
+ out_suffix = ''
+ unsorted_tensor_method_hints[name].append(
+ 'def {}(self, other: Union[Tensor, Number]{})'
+ ' -> Tensor: ...'.format(name, out_suffix))
+ unsorted_tensor_method_hints[name].append(
+ 'def {}(self, value: Number,'
+ ' other: Union[Tensor, Number]{})'
+ ' -> Tensor: ...'.format(name, out_suffix))
+ simple_conversions = ['byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long', 'short']
+ for name in simple_conversions:
+ unsorted_tensor_method_hints[name].append('def {}(self) -> Tensor: ...'.format(name))
+
+ tensor_method_declarations = get_py_variable_methods(declarations)
+ for name in sorted(tensor_method_declarations.keys()):
+ unsorted_tensor_method_hints[name] += \
+ generate_type_hints(name, tensor_method_declarations[name], is_tensor=True)
+
+ for op in all_ops:
+ name = '__{}__'.format(op)
+ unsorted_tensor_method_hints[name] += sig_for_ops(name)
+
+ tensor_method_hints = []
+ for name, hints in sorted(unsorted_tensor_method_hints.items()):
+ if len(hints) > 1:
+ hints = ['@overload\n' + h for h in hints]
+ tensor_method_hints += hints
+
+ # TODO: Missing type hints for nn
+
+ # Generate type signatures for legacy classes
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ # TODO: These are deprecated, maybe we shouldn't type hint them
+ legacy_class_hints = []
+ for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
+ 'ShortStorage', 'CharStorage', 'ByteStorage'):
+ legacy_class_hints.append('class {}(Storage): ...'.format(c))
+
+ for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
+ 'ShortTensor', 'CharTensor', 'ByteTensor'):
+ legacy_class_hints.append('class {}(Tensor): ...'.format(c))
+
+ # Generate type signatures for dtype classes
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ # TODO: don't explicitly list dtypes here; get it from canonical
+ # source
+ dtype_class_hints = ['{}: dtype = ...'.format(n)
+ for n in
+ ['float32', 'float', 'float64', 'double', 'float16', 'half',
+ 'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64', 'long',
+ 'complex32', 'complex64', 'complex128']]
+
+ # Write out the stub
+ # ~~~~~~~~~~~~~~~~~~
+
+ env = {
+ 'function_hints': function_hints,
+ 'tensor_method_hints': tensor_method_hints,
+ 'legacy_class_hints': legacy_class_hints,
+ 'dtype_class_hints': dtype_class_hints,
+ }
+ TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))
+
+ write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Generate type stubs for PyTorch')
+ parser.add_argument('--declarations-path', metavar='DECL',
+ default='torch/share/ATen/Declarations.yaml',
+ help='path to Declarations.yaml')
+ parser.add_argument('--out', metavar='OUT',
+ default='.',
+ help='path to output directory')
+ args = parser.parse_args()
+ gen_pyi(args.declarations_path, args.out)
+
+
+if __name__ == '__main__':
+ main()
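
To make the translation concrete, here is a hedged, worked example of what the helpers above produce for an invented Declarations.yaml entry (real entries carry many more fields):

    # type_to_python maps Declarations.yaml type names to annotations:
    #   type_to_python('Scalar')          -> 'Number'
    #   type_to_python('IntList')         -> '_size'
    #   type_to_python('IntList', size=4) -> 'Union[_int, _size]'

    decl = {
        'name': 'ones',
        'arguments': [{'name': 'size', 'dynamic_type': 'IntList'}],
        'returns': [{'dynamic_type': 'Tensor'}],
    }

    # generate_type_hints('ones', [decl]) yields both the list form and,
    # because the first and only positional argument is an IntList, the
    # vararg form described in the comments above:
    #
    #   def ones(size: _size) -> Tensor: ...
    #   def ones(*size: _int) -> Tensor: ...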
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index dc9d9fd515..5e8ba7185d 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -713,8 +713,26 @@ if (BUILD_PYTHON)
endif()
endif()
+ add_custom_target(torch_python_stubs DEPENDS "${TORCH_SRC_DIR}/__init__.pyi")
+ # For Declarations.yaml dependency
+ add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
+ add_custom_command(
+ OUTPUT
+ "${TORCH_SRC_DIR}/__init__.pyi"
+ COMMAND
+ ${PYCMD} -mtools.pyi.gen_pyi
+ --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
+ DEPENDS
+ "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
+ "${TORCH_SRC_DIR}/__init__.pyi.in"
+ WORKING_DIRECTORY
+ "${TORCH_ROOT}"
+ )
+
add_library(torch_python SHARED ${TORCH_PYTHON_SRCS})
+ add_dependencies(torch_python torch_python_stubs)
+
target_link_libraries(torch_python ${TORCH_PYTHON_LINK_LIBRARIES})
target_compile_definitions(torch_python PRIVATE ${TORCH_PYTHON_COMPILE_DEFINITIONS})
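
With this rule in place, the stub is regenerated whenever Declarations.yaml or the __init__.pyi.in template changes. It can also be run by hand from the repository root, mirroring the custom command above (here &lt;build&gt; stands for your CMake binary directory; gen_pyi's own default is torch/share/ATen/Declarations.yaml):

    python -mtools.pyi.gen_pyi --declarations-path <build>/aten/src/ATen/Declarations.yaml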
diff --git a/torch/__init__.py b/torch/__init__.py
index bb985deb6f..a40ed1afe0 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -179,6 +179,7 @@ def set_default_dtype(d):
"""
_C._set_default_dtype(d)
+# If you edit these imports, please update torch/__init__.pyi.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed
from .serialization import save, load
from ._tensor_str import set_printoptions
diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in
new file mode 100644
index 0000000000..348e3bad42
--- /dev/null
+++ b/torch/__init__.pyi.in
@@ -0,0 +1,106 @@
+# ${generated_comment}
+
+from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload
+from torch._six import inf
+
+import builtins
+
+# These identifiers are reexported from other modules. These modules
+# are not mypy-clean yet, so in order to use this stub file usefully
+# from mypy you will need to specify --follow-imports=silent.
+# Not all is lost: these imports still enable IDEs like PyCharm to offer
+# autocomplete.
+#
+# Note: Why does the syntax here look so strange? Import visibility
+# rules in stubs are different from normal Python files! You must use
+# 'from ... import ... as ...' syntax to cause an identifier to be
+# exposed (or use a wildcard); regular syntax is not exposed.
+from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \
+ manual_seed as manual_seed, initial_seed as initial_seed
+from ._tensor_str import set_printoptions as set_printoptions
+from .functional import *
+from .serialization import save as save, load as load
+from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
+ set_grad_enabled as set_grad_enabled
+
+class dtype: ...
+
+class layout: ...
+
+strided : layout = ...
+
+# See https://github.com/python/mypy/issues/4146 for why these
+# workarounds are necessary
+_int = builtins.int
+_float = builtins.float
+
+class device:
+ def __init__(self, device: Union[_int, str, None]=None) -> None: ...
+
+class Generator: ...
+
+class Size(tuple): ...
+
+class Storage: ...
+
+# See https://github.com/python/mypy/issues/4146 for why these
+# workarounds are necessary
+_dtype = dtype
+_device = device
+_size = Union[Size, List[_int], Tuple[_int, ...]]
+
+# Meta-type for "numeric" things; matches our docs
+Number = Union[builtins.int, builtins.float]
+
+# TODO: One downside of doing it this way is that direct use of
+# torch.tensor.Tensor doesn't get type annotations. Nobody
+# should really do that, so maybe this is not so bad.
+class Tensor:
+ dtype: _dtype = ...
+ shape: Size = ...
+ device: _device = ...
+ requires_grad: bool = ...
+ grad: Optional[Tensor] = ...
+
+ ${tensor_method_hints}
+
+ # Manually defined methods from torch/tensor.py
+ def backward(self, gradient: Optional[Tensor]=None, retain_graph: Optional[bool]=None, create_graph: bool=False) -> None: ...
+ def register_hook(self, hook: Callable) -> Any: ...
+ def retain_grad(self) -> None: ...
+ def is_pinned(self) -> bool: ...
+ def is_shared(self) -> bool: ...
+ def share_memory_(self) -> None: ...
+ # TODO: fill in the types for these, or otherwise figure out some
+ # way to not have to write these out again...
+ def argmax(self, dim=None, keepdim=False): ...
+ def argmin(self, dim=None, keepdim=False): ...
+ def argsort(self, dim=None, descending=False): ...
+ def norm(self, p="fro", dim=None, keepdim=False): ...
+ def stft(self, n_fft, hop_length=None, win_length=None, window=None,
+ center=True, pad_mode='reflect', normalized=False, onesided=True): ...
+ def split(self, split_size, dim=0): ...
+ def index_add(self, dim, index, tensor): ...
+ def index_copy(self, dim, index, tensor): ...
+ def index_fill(self, dim, index, value): ...
+ def scatter(self, dim, index, source): ...
+ def scatter_add(self, dim, index, source): ...
+ def masked_scatter(self, mask, tensor): ...
+ def masked_fill(self, mask, value): ...
+ def unique(self, sorted=True, return_inverse=False, dim=None): ...
+
+${function_hints}
+
+${legacy_class_hints}
+
+${dtype_class_hints}
+
+# Pure Python functions defined in torch/__init__.py
+
+def typename(obj) -> str: ...
+def is_tensor(obj) -> bool: ...
+def is_storage(obj) -> bool: ...
+def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
+def set_default_dtype(d : _dtype) -> None: ...
+def manager_path() -> str: ...
+def compiled_with_cxx11_abi() -> bool: ...
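
Once torch/__init__.pyi has been generated from this template, mypy can see through the C extension. A small sketch of what becomes possible (reveal_type is a mypy-only construct; run with --follow-imports=silent as noted above, and the exact revealed-type strings may vary by mypy version):

    import torch

    t = torch.tensor([1.0, 2.0])
    reveal_type(t.size())  # Revealed type is 'torch.Size'
    reveal_type(t.dim())   # Revealed type is 'builtins.int'
    reveal_type(t.grad)    # Revealed type is 'Union[torch.Tensor, None]'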
diff --git a/torch/_six.py b/torch/_six.py
index f6bdd39c30..f23fe61733 100644
--- a/torch/_six.py
+++ b/torch/_six.py
@@ -20,6 +20,7 @@
import itertools
import sys
+import builtins
PY2 = sys.version_info[0] == 2
@@ -48,7 +49,7 @@ else:
if PY2:
FileNotFoundError = IOError
else:
- FileNotFoundError = FileNotFoundError
+ FileNotFoundError = builtins.FileNotFoundError
if PY2:
@@ -71,11 +72,10 @@ def with_metaclass(meta, *bases):
# A portable way of referring to the generator version of map
# in both Python 2 and Python 3.
-# TODO: Move this into an appropriate utility library.
if hasattr(itertools, 'imap'):
- imap = itertools.imap
+ imap = itertools.imap # type: ignore
else:
- imap = map
+ imap = map # type: ignore
if PY3:
diff --git a/torch/_utils.py b/torch/_utils.py
index c593b7aa48..10b9dca83a 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -309,3 +309,13 @@ def _take_tensors(tensors, size_limit):
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf
+
+
+# annotation decorator to get annotations in a way that is compatible
+# with both Python 2 and 3
+def annotate(ret, **kwargs):
+ def dec(fun):
+ fun.__annotations__ = dict(kwargs)
+ fun.__annotations__['return'] = ret
+ return fun
+ return dec
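
A usage sketch of annotate (the halve function is hypothetical): under Python 2, where inline annotations are a syntax error, this decorator produces the same __annotations__ dict that a Python 3 signature would.

    from torch._utils import annotate

    @annotate(float, x=int)
    def halve(x):
        return x / 2.0

    # Equivalent to "def halve(x: int) -> float:" in Python 3 syntax:
    print(halve.__annotations__)
    # {'x': <class 'int'>, 'return': <class 'float'>} (ordering may vary)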
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 6c3de1ee43..8ab98804a2 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -36,6 +36,10 @@ static std::unordered_map<std::string, ParameterType> type_map = {
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads and binding to the Tensor overload with a number of a different
// type will trigger a type error.
+//
+// If you modify this, you will need to adjust the blacklist in
+// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
+// functions.)
static bool should_allow_numbers_as_tensors(const std::string& name) {
static std::unordered_set<std::string> allowed = {
"add", "add_", "add_out",
diff --git a/torch/functional.py b/torch/functional.py
index 656b2e2ca9..c47a30c1dc 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -4,8 +4,11 @@ from torch._six import inf
from torch._C import _add_docstr
from operator import mul
from functools import reduce
+from collections import Iterable
+from torch._utils import annotate
from itertools import product
import math
+from typing import Optional, Tuple, List, Union
import warnings
__all__ = [
diff --git a/torch/serialization.py b/torch/serialization.py
index 9a6ce83a9e..ef7d415569 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -366,7 +366,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
- >>> with open('tensor.pt') as f:
+ >>> with open('tensor.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
"""
diff --git a/torch/tensor.py b/torch/tensor.py
index be936bf38e..ec12e08664 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -12,6 +12,10 @@ from torch._C import _add_docstr
# NB: If you subclass Tensor, and want to share the subclassed class
# across processes, you must also update torch/multiprocessing/reductions.py
# to define a ForkingPickler serialization mode for the class.
+#
+# NB: If you add a new method to Tensor, you must update
+# torch/__init__.pyi.in to add a type annotation for your method;
+# otherwise, it will not show up in autocomplete.
class Tensor(torch._C._TensorBase):
def __deepcopy__(self, memo):
if not self.is_leaf: