diff options
author | Thomas Viehmann <tv@beamnet.de> | 2019-01-29 11:19:51 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-29 12:14:17 -0800 |
commit | 6a6983ed7f14f8335a5b5614928713cb79658281 (patch) | |
tree | dd24a3d02312c12a62db4dd14dcb59efc077d244 /torch/__init__.pyi.in | |
parent | 3b337e7892a7b54758f273477aa092053f02b410 (diff) | |
download | pytorch-6a6983ed7f14f8335a5b5614928713cb79658281.tar.gz pytorch-6a6983ed7f14f8335a5b5614928713cb79658281.tar.bz2 pytorch-6a6983ed7f14f8335a5b5614928713cb79658281.zip |
create type hint stub files for module torch (#12500)
Summary:
We have:
- This is an initial stab at creating a type stub `torch/__init__.pyi` .
- This is only tested on Python 3, since that's the only Python version mypy
works on.
- So far, we only aim at doing this for torch functions and torch.Tensor.
- Quite a few methods and functions have to be typed manually. These are
done in `torch/__init__.pyi.in`
For me, PyCharm (the non-paid one) didn't seem to indicate errors in the .pyi when opening and seemed to be able to get the type hint for the few functions I tried, but I don't use PyCharm for my usual PyTorch activities, so I didn't extensively try this out.
An example of a generated PYI is at [this gist](https://gist.github.com/ezyang/bf9b6a5fa8827c52152858169bcb61b1).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12500
Differential Revision: D13695553
Pulled By: ezyang
fbshipit-source-id: 4566c71913ede4e4c23ebc4a72c17151f94e8e21
Diffstat (limited to 'torch/__init__.pyi.in')
-rw-r--r-- | torch/__init__.pyi.in | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in new file mode 100644 index 0000000000..348e3bad42 --- /dev/null +++ b/torch/__init__.pyi.in @@ -0,0 +1,106 @@ +# ${generated_comment} + +from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload +from torch._six import inf + +import builtins + +# These identifiers are reexported from other modules. These modules +# are not mypy-clean yet, so in order to use this stub file usefully +# from mypy you will need to specify --follow-imports=silent. +# Not all is lost: these imports still enable IDEs like PyCharm to offer +# autocomplete. +# +# Note: Why does the syntax here look so strange? Import visibility +# rules in stubs are different from normal Python files! You must use +# 'from ... import ... as ...' syntax to cause an identifier to be +# exposed (or use a wildcard); regular syntax is not exposed. +from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \ + manual_seed as manual_seed, initial_seed as initial_seed +from ._tensor_str import set_printoptions as set_printoptions +from .functional import * +from .serialization import save as save, load as load +from .autograd import no_grad as no_grad, enable_grad as enable_grad, \ + set_grad_enabled as set_grad_enabled + +class dtype: ... + +class layout: ... + +strided : layout = ... + +# See https://github.com/python/mypy/issues/4146 for why these workarounds +# is necessary +_int = builtins.int +_float = builtins.float + +class device: + def __init__(self, device: Union[_int, str, None]=None) -> None: ... + +class Generator: ... + +class Size(tuple): ... + +class Storage: ... + +# See https://github.com/python/mypy/issues/4146 for why these workarounds +# is necessary +_dtype = dtype +_device = device +_size = Union[Size, List[_int], Tuple[_int, ...]] + +# Meta-type for "numeric" things; matches our docs +Number = Union[builtins.int, builtins.float] + +# TODO: One downside of doing it this way, is direct use of +# torch.tensor.Tensor doesn't get type annotations. Nobody +# should really do that, so maybe this is not so bad. +class Tensor: + dtype: _dtype = ... + shape: Size = ... + device: _device = ... + requires_grad: bool = ... + grad: Optional[Tensor] = ... + + ${tensor_method_hints} + + # Manually defined methods from torch/tensor.py + def backward(self, gradient: Optional[Tensor]=None, retain_graph: Optional[bool]=None, create_graph: bool=False) -> None: ... + def register_hook(self, hook: Callable) -> Any: ... + def retain_grad(self) -> None: ... + def is_pinned(self) -> bool: ... + def is_shared(self) -> bool: ... + def share_memory_(self) -> None: ... + # TODO: fill in the types for these, or otherwise figure out some + # way to not have to write these out again... + def argmax(self, dim=None, keepdim=False): ... + def argmin(self, dim=None, keepdim=False): ... + def argsort(self, dim=None, descending=False): ... + def norm(self, p="fro", dim=None, keepdim=False): ... + def stft(self, n_fft, hop_length=None, win_length=None, window=None, + center=True, pad_mode='reflect', normalized=False, onesided=True): ... + def split(self, split_size, dim=0): ... + def index_add(self, dim, index, tensor): ... + def index_copy(self, dim, index, tensor): ... + def index_fill(self, dim, index, value): ... + def scatter(self, dim, index, source): ... + def scatter_add(self, dim, index, source): ... + def masked_scatter(self, mask, tensor): ... + def masked_fill(self, mask, value): ... + def unique(self, sorted=True, return_inverse=False, dim=None): ... + +${function_hints} + +${legacy_class_hints} + +${dtype_class_hints} + +# Pure Python functions defined in torch/__init__.py + +def typename(obj) -> str: ... +def is_tensor(obj) -> bool: ... +def is_storage(obj) -> bool: ... +def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API +def set_default_dtype(d : _dtype) -> None: ... +def manager_path() -> str: ... +def compiled_with_cxx11_abi() -> bool: ... |