author    Thomas Viehmann <tv@beamnet.de>    2019-01-29 11:19:51 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-01-29 12:14:17 -0800
commit    6a6983ed7f14f8335a5b5614928713cb79658281 (patch)
tree      dd24a3d02312c12a62db4dd14dcb59efc077d244 /torch/__init__.pyi.in
parent    3b337e7892a7b54758f273477aa092053f02b410 (diff)
download  pytorch-6a6983ed7f14f8335a5b5614928713cb79658281.tar.gz
          pytorch-6a6983ed7f14f8335a5b5614928713cb79658281.tar.bz2
          pytorch-6a6983ed7f14f8335a5b5614928713cb79658281.zip
create type hint stub files for module torch (#12500)
Summary: We have:

- This is an initial stab at creating a type stub `torch/__init__.pyi`.
- This is only tested on Python 3, since that's the only Python version mypy works on.
- So far, we only aim at doing this for torch functions and torch.Tensor.
- Quite a few methods and functions have to be typed manually; these are done in `torch/__init__.pyi.in`.

For me, PyCharm (the non-paid one) didn't seem to indicate errors in the .pyi when opening it, and it seemed able to get the type hints for the few functions I tried, but I don't use PyCharm for my usual PyTorch activities, so I didn't try this out extensively. An example of a generated .pyi is at [this gist](https://gist.github.com/ezyang/bf9b6a5fa8827c52152858169bcb61b1).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12500

Differential Revision: D13695553

Pulled By: ezyang

fbshipit-source-id: 4566c71913ede4e4c23ebc4a72c17151f94e8e21
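As a rough sketch of how the stub is meant to be consumed (the file name example.py and the function below are hypothetical, not part of this commit), a script could be checked with mypy using the flag recommended in the stub's own comments:

    # example.py (hypothetical) -- exercises the manually typed Tensor methods
    import torch

    def step(t: torch.Tensor) -> None:
        t.backward(retain_graph=True)   # OK: matches the manual hint for backward()
        t.retain_grad()                 # OK: declared as () -> None
        reveal_type(t.grad)             # mypy should report Optional[Tensor]

    # Checked with (Python 3 only, per the summary above):
    #   mypy --follow-imports=silent example.py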
Diffstat (limited to 'torch/__init__.pyi.in')
-rw-r--r--    torch/__init__.pyi.in    106
1 file changed, 106 insertions, 0 deletions
diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in
new file mode 100644
index 0000000000..348e3bad42
--- /dev/null
+++ b/torch/__init__.pyi.in
@@ -0,0 +1,106 @@
+# ${generated_comment}
+
+from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload
+from torch._six import inf
+
+import builtins
+
+# These identifiers are reexported from other modules. These modules
+# are not mypy-clean yet, so in order to use this stub file usefully
+# from mypy you will need to specify --follow-imports=silent.
+# Not all is lost: these imports still enable IDEs like PyCharm to offer
+# autocomplete.
+#
+# Note: Why does the syntax here look so strange? Import visibility
+# rules in stubs are different from normal Python files! You must use
+# 'from ... import ... as ...' syntax (or a wildcard import) to cause
+# an identifier to be re-exported; a plain import does not expose it.
+from .random import set_rng_state as set_rng_state, get_rng_state as get_rng_state, \
+ manual_seed as manual_seed, initial_seed as initial_seed
+from ._tensor_str import set_printoptions as set_printoptions
+from .functional import *
+from .serialization import save as save, load as load
+from .autograd import no_grad as no_grad, enable_grad as enable_grad, \
+ set_grad_enabled as set_grad_enabled
+
+class dtype: ...
+
+class layout: ...
+
+strided: layout = ...
+
+# See https://github.com/python/mypy/issues/4146 for why these
+# workarounds are necessary
+_int = builtins.int
+_float = builtins.float
+
+class device:
+ def __init__(self, device: Union[_int, str, None] = None) -> None: ...
+
+class Generator: ...
+
+class Size(tuple): ...
+
+class Storage: ...
+
+# See https://github.com/python/mypy/issues/4146 for why these
+# workarounds are necessary
+_dtype = dtype
+_device = device
+_size = Union[Size, List[_int], Tuple[_int, ...]]
+
+# Meta-type for "numeric" things; matches our docs
+Number = Union[builtins.int, builtins.float]
+
+# TODO: One downside of doing it this way is that direct use of
+# torch.tensor.Tensor doesn't get type annotations. Nobody
+# should really do that, so maybe this is not so bad.
+class Tensor:
+ dtype: _dtype = ...
+ shape: Size = ...
+ device: _device = ...
+ requires_grad: bool = ...
+ grad: Optional[Tensor] = ...
+
+ ${tensor_method_hints}
+
+ # Manually defined methods from torch/tensor.py
+ def backward(self, gradient: Optional[Tensor] = None, retain_graph: Optional[bool] = None, create_graph: bool = False) -> None: ...
+ def register_hook(self, hook: Callable) -> Any: ...
+ def retain_grad(self) -> None: ...
+ def is_pinned(self) -> bool: ...
+ def is_shared(self) -> bool: ...
+ def share_memory_(self) -> None: ...
+ # TODO: fill in the types for these, or otherwise figure out some
+ # way to not have to write these out again...
+ def argmax(self, dim=None, keepdim=False): ...
+ def argmin(self, dim=None, keepdim=False): ...
+ def argsort(self, dim=None, descending=False): ...
+ def norm(self, p="fro", dim=None, keepdim=False): ...
+ def stft(self, n_fft, hop_length=None, win_length=None, window=None,
+ center=True, pad_mode='reflect', normalized=False, onesided=True): ...
+ def split(self, split_size, dim=0): ...
+ def index_add(self, dim, index, tensor): ...
+ def index_copy(self, dim, index, tensor): ...
+ def index_fill(self, dim, index, value): ...
+ def scatter(self, dim, index, source): ...
+ def scatter_add(self, dim, index, source): ...
+ def masked_scatter(self, mask, tensor): ...
+ def masked_fill(self, mask, value): ...
+ def unique(self, sorted=True, return_inverse=False, dim=None): ...
+
+${function_hints}
+
+${legacy_class_hints}
+
+${dtype_class_hints}
+
+# Pure Python functions defined in torch/__init__.py
+
+def typename(obj) -> str: ...
+def is_tensor(obj) -> bool: ...
+def is_storage(obj) -> bool: ...
+def set_default_tensor_type(type) -> None: ... # ick, what a bad legacy API
+def set_default_dtype(d: _dtype) -> None: ...
+def manager_path() -> str: ...
+def compiled_with_cxx11_abi() -> bool: ...
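The re-export rule described in the comment at the top of the stub can be shown in isolation. A minimal sketch, assuming a made-up package mypkg whose modules and names exist only for illustration:

    # mypkg/__init__.pyi (hypothetical stub)
    from ._impl import private_helper        # plain import: NOT re-exported
    from ._impl import helper as helper      # 'as' form: mypkg.helper is visible
    from ._other import *                    # wildcard: public names re-exported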
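Likewise, a sketch of why the _int/_float/_dtype/_device aliases exist, on my reading of the mypy issue cited twice in the stub: names declared in a class body shadow module-level names in annotations. The class below is a trimmed illustration, not the real stub:

    class Tensor:
        dtype: _dtype = ...        # annotated as plain 'dtype', the name would
                                   # resolve to this attribute, not the class
        def int(self) -> 'Tensor': ...   # real method; inside the class body
                                         # it shadows builtins.int ...
        def dim(self) -> _int: ...       # ... so later hints use _int instead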