diff options
-rw-r--r-- | .flake8 | 2 | ||||
-rw-r--r-- | .travis.aten.yml | 2 | ||||
-rw-r--r-- | .travis.yml | 2 | ||||
-rw-r--r-- | test/test_jit.py | 87 | ||||
-rw-r--r-- | torch/_jit_internal.py | 18 | ||||
-rw-r--r-- | torch/jit/quantized.py | 4 |
6 files changed, 60 insertions, 55 deletions
@@ -1,4 +1,4 @@ [flake8] max-line-length = 120 ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504 -exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install +exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include diff --git a/.travis.aten.yml b/.travis.aten.yml index 0e9d8022aa..2425845496 100644 --- a/.travis.aten.yml +++ b/.travis.aten.yml @@ -27,5 +27,5 @@ matrix: include: env: LINT_CHECK python: "2.7" - install: pip install flake8 + install: pip install flake8-mypy script: flake8 diff --git a/.travis.yml b/.travis.yml index 8d1da417ce..6aa70770fc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,7 +29,7 @@ matrix: python: "3.7" dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069) sudo: required # required for Python 3.7 (travis-ci/travis-ci#9069) - install: pip install flake8 + install: pip install flake8-mypy script: flake8 - name: "MyPy typecheck" python: "3.6" diff --git a/test/test_jit.py b/test/test_jit.py index 7c7824ecdd..9879eebcd7 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5,11 +5,13 @@ import torch.nn as nn import torch.nn.functional as F import torch.nn.parallel as dp import torch.optim as optim +import torch.cuda import torch.jit.quantized from contextlib import contextmanager from itertools import product, chain import torch.jit.frontend from torch.autograd import Variable, Function +from torch.nn import Module from torch.autograd.function import traceable from torch.testing import assert_allclose from torch.onnx import OperatorExportTypes @@ -44,9 +46,11 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \ ListType, StringType, DictType from copy import deepcopy import random -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Tuple from torch.jit.frontend import NotSupportedError from torch.jit import BatchTensor +from torch import Tensor +from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # For testing truediv in python 2 from test_module.future_div import div_int_future, div_float_future @@ -96,7 +100,7 @@ if WINDOWS: finally: os.unlink(f.name) else: - @contextmanager + @contextmanager # noqa: T484 def TemporaryFileName(): with tempfile.NamedTemporaryFile() as f: yield f.name @@ -2262,7 +2266,7 @@ class TestJit(JitTestCase): with self.assertRaisesRegex(RuntimeError, "Expected a default value"): @torch.jit.script - def hints_bad_types(x, a=10, b=0.5): + def hints_bad_types(x, a=10, b=0.5): # noqa: T484 # type: (Tensor, float, int) -> Tensor return x + a + b @@ -3113,7 +3117,7 @@ class TestScript(JitTestCase): def sum_list(a): # type: (int) -> int sum = 0 - for i in a: + for i in a: # noqa: T484 sum += i return sum @@ -4727,23 +4731,23 @@ a") x = 1 else: x = torch.jit._unwrap_optional(x) - return x + return x # noqa: T484 with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): @torch.jit.script def or_error(x, y): - # type: (Optional[int], Optional[int]) -> int + # type: (Optional[int], Optional[int]) -> None if x is None or y is None: - print(x + y) + print(x + y) # noqa: T484 with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): @torch.jit.script def and_error(x, y): - # type: (Optional[int], Optional[int]) -> int + # type: (Optional[int], Optional[int]) -> None if x is None and y is None: pass else: - print(x + y) + print(x + y) # noqa: T484 with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): @torch.jit.script @@ -4751,7 +4755,7 @@ a") # type: (Optional[int]) -> None x_none = x is not None if x_none: - print(x + 1) + print(x + 1) # noqa: T484 with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"): @torch.jit.script @@ -4759,7 +4763,7 @@ a") # type: (Optional[int], Optional[int]) -> None x_none = x is not None if y is not None and x_none: - print(x + y) + print(x + y) # noqa: T484 def test_while_write_outer_then_read(self): def func(a, b): @@ -5057,10 +5061,11 @@ a") self.checkScript(multiple_returns, [a], optimize=True) with self.assertRaisesRegex(RuntimeError, "but is actually of type None"): - @torch.jit.script + torch.jit.CompilationUnit(''' def no_return_bad_annotation(a): # type: (Tensor) -> Tensor a + 1 + ''') def test_error(self): @torch.jit.script @@ -5654,8 +5659,6 @@ a") hiddens = hx if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): - from typing import Tuple - class ScriptWrapper(torch.jit.ScriptModule): def __init__(self, cell): super(ScriptWrapper, self).__init__() @@ -6650,7 +6653,7 @@ a") with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"): def foo(): # type: () -> Tensor - return ((3, 4),) + return ((3, 4),) # noqa: T484 @torch.jit.script def bar(): @@ -6769,7 +6772,7 @@ a") if x: y = [1] else: - y = [None] + y = [None] # noqa: T484 return y[0] @torch.jit.script @@ -6815,18 +6818,18 @@ a") print(int_fn((1, 1, 1))) with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"): - @torch.jit.script + @torch.jit.script # noqa: T484 def fn(x): - # type: (BroadcastingListx[int]) -> List[int] + # type: (BroadcastingListx[int]) -> List[int] # noqa: T484 return x - # TODO: the type comment in this seems to trip up flake8 for some reason - # even though we have a noqa comment. Figure out why + # using CU so that flake8 error on int[2] is not raised (noqa not working) with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"): - @torch.jit.script - def nested(x, y): - # type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484 - return x + cu = torch.jit.CompilationUnit(''' + def nested(x, y): + # type: (int, Tuple[int, int[2]]) -> List[int] + return x # noqa: T484 + ''') def test_ntuple_builtins(self): from torch.nn.modules.utils import _single, _pair, _triple, _quadruple @@ -8349,7 +8352,7 @@ a") with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'): def somefunc(): # type: () -> Tuple[Tuple[Tensor, Tensor]] - return torch.zeros(3, 4), torch.zeros(4, 5) + return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484 @torch.jit.script def wrong_return_type(): @@ -9029,7 +9032,7 @@ a") def test(x): # type: (Optional[int]) -> int x = torch.jit._unwrap_optional(x) - x = x + x + x = x + x # noqa: T484 return x self.checkScript(test, (3,)) @@ -9082,14 +9085,14 @@ a") @torch.jit.script def return_tup(x): # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor] - return x, x + return x, x # noqa: T484 def test_annotated_script_fn_arg_mismatch(self): with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"): @torch.jit.script def tuple_arg(x): # type: (Tuple[Tensor, Tensor]) -> Tensor - return x + 1 + return x + 1 # noqa: T484 def test_script_non_tensor_args_outputs(self): @torch.jit.script @@ -13122,11 +13125,11 @@ class TestAsync(JitTestCase): self.assertEqual(y, y_hat) def test_async_script_capture(self): - class Module(torch.jit.ScriptModule): + class Mod(torch.jit.ScriptModule): __constants__ = ['const'] def __init__(self): - super(Module, self).__init__(False) + super(Mod, self).__init__(False) self.const = 42 self.param = nn.Parameter(torch.randn(2, 2)) @@ -13144,7 +13147,7 @@ class TestAsync(JitTestCase): x1 = torch.rand(3, 4) x2 = torch.rand(5, 6) - m = Module() + m = Mod() y, y_hat = m.wait_script(x1, x2) self.assertEqual(y, y_hat) @@ -13244,9 +13247,9 @@ class TestAsync(JitTestCase): def forward(self, x): return (torch.neg(x), x) - class Module(torch.jit.ScriptModule): + class Mod(torch.jit.ScriptModule): def __init__(self): - super(Module, self).__init__(False) + super(Mod, self).__init__(False) x = torch.rand(3, 3) self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) @@ -13266,10 +13269,10 @@ class TestAsync(JitTestCase): # return a nested structure of tensors return (tensor_list, tensor_tuple, tensor_tuple[1]) - class Tuple(nn.Module): + class TupleCl(nn.Module): def __init__(self): - super(Tuple, self).__init__() - self.module = Module() + super(TupleCl, self).__init__() + self.module = Mod() def forward(self, x): z = torch.neg(x) @@ -13278,7 +13281,7 @@ class TestAsync(JitTestCase): return tuple(list) x = torch.rand(3, 3) - module = torch.jit.trace(Tuple(), (x), _force_outplace=True) + module = torch.jit.trace(TupleCl(), (x), _force_outplace=True) # Make sure we have forks self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2) @@ -13632,16 +13635,16 @@ class TestClassType(JitTestCase): @torch.jit.script class FooTest: def __init__(self, x): - # type: (int) + # type: (int) -> None self.foo = x def incFooTest(self, y): - # type: (int) + # type: (int) -> None self.foo = self.foo + y @torch.jit.script def fn(x): - # type: (int) + # type: (int) -> int foo = FooTest(x) foo.incFooTest(2) return foo.foo @@ -13689,7 +13692,7 @@ class TestClassType(JitTestCase): @torch.jit.script class FooTest: def __init__(self, x): - # type: (bool) + # type: (bool) -> None self.foo = x @torch.jit.script @@ -13718,7 +13721,7 @@ class TestClassType(JitTestCase): @torch.jit.script def fn(foo): - # type: (FooTest) + # type: (FooTest) -> Tensor return foo.attr @torch.jit.script diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 28053c5b37..3667cfe89a 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -9,24 +9,24 @@ import inspect from torch._six import builtins # Tracks standalone weak script functions -compiled_weak_fns = weakref.WeakKeyDictionary() +compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484 # Tracks which methods should be converted to strong methods -weak_script_methods = weakref.WeakKeyDictionary() +weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484 # Converted modules and their corresponding WeakScriptModuleProxy objects -weak_modules = weakref.WeakKeyDictionary() +weak_modules = weakref.WeakKeyDictionary() # noqa: T484 # Types that have been declared as weak modules -weak_types = weakref.WeakKeyDictionary() +weak_types = weakref.WeakKeyDictionary() # noqa: T484 # Wrapper functions that can call either of 2 functions depending on a boolean # argument -boolean_dispatched = weakref.WeakKeyDictionary() +boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484 # Python Op functions that should be ignored by the compiler. These will be replaced # with an operator that always throws an error -ignored_fns = weakref.WeakSet() +ignored_fns = weakref.WeakSet() # noqa: T484 COMPILATION_PENDING = object() COMPILED = object() @@ -223,9 +223,9 @@ except ImportError: def __getitem__(self, types): return DictInstance(types) - Tuple = TupleCls() - List = ListCls() - Dict = DictCls() + Tuple = TupleCls() # noqa: T484 + List = ListCls() # noqa: T484 + Dict = DictCls() # noqa: T484 def is_tuple(ann): return isinstance(ann, TupleInstance) diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py index ec4e89a8b5..b5775b96b6 100644 --- a/torch/jit/quantized.py +++ b/torch/jit/quantized.py @@ -1,7 +1,9 @@ import torch import copy import numbers -from typing import Tuple +from typing import Tuple, Optional +from torch import Tensor +from torch.jit import ScriptModule from torch.nn.utils.rnn import PackedSequence from torch.nn import _VF |