summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.flake82
-rw-r--r--.travis.aten.yml2
-rw-r--r--.travis.yml2
-rw-r--r--test/test_jit.py87
-rw-r--r--torch/_jit_internal.py18
-rw-r--r--torch/jit/quantized.py4
6 files changed, 60 insertions, 55 deletions
diff --git a/.flake8 b/.flake8
index 8510253cce..9048c94e6a 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
[flake8]
max-line-length = 120
ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
-exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install
+exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
diff --git a/.travis.aten.yml b/.travis.aten.yml
index 0e9d8022aa..2425845496 100644
--- a/.travis.aten.yml
+++ b/.travis.aten.yml
@@ -27,5 +27,5 @@ matrix:
include:
env: LINT_CHECK
python: "2.7"
- install: pip install flake8
+ install: pip install flake8-mypy
script: flake8
diff --git a/.travis.yml b/.travis.yml
index 8d1da417ce..6aa70770fc 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -29,7 +29,7 @@ matrix:
python: "3.7"
dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069)
sudo: required # required for Python 3.7 (travis-ci/travis-ci#9069)
- install: pip install flake8
+ install: pip install flake8-mypy
script: flake8
- name: "MyPy typecheck"
python: "3.6"
diff --git a/test/test_jit.py b/test/test_jit.py
index 7c7824ecdd..9879eebcd7 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5,11 +5,13 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as dp
import torch.optim as optim
+import torch.cuda
import torch.jit.quantized
from contextlib import contextmanager
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
+from torch.nn import Module
from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
@@ -44,9 +46,11 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \
ListType, StringType, DictType
from copy import deepcopy
import random
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
from torch.jit.frontend import NotSupportedError
from torch.jit import BatchTensor
+from torch import Tensor
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3
# For testing truediv in python 2
from test_module.future_div import div_int_future, div_float_future
@@ -96,7 +100,7 @@ if WINDOWS:
finally:
os.unlink(f.name)
else:
- @contextmanager
+ @contextmanager # noqa: T484
def TemporaryFileName():
with tempfile.NamedTemporaryFile() as f:
yield f.name
@@ -2262,7 +2266,7 @@ class TestJit(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
@torch.jit.script
- def hints_bad_types(x, a=10, b=0.5):
+ def hints_bad_types(x, a=10, b=0.5): # noqa: T484
# type: (Tensor, float, int) -> Tensor
return x + a + b
@@ -3113,7 +3117,7 @@ class TestScript(JitTestCase):
def sum_list(a):
# type: (int) -> int
sum = 0
- for i in a:
+ for i in a: # noqa: T484
sum += i
return sum
@@ -4727,23 +4731,23 @@ a")
x = 1
else:
x = torch.jit._unwrap_optional(x)
- return x
+ return x # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
def or_error(x, y):
- # type: (Optional[int], Optional[int]) -> int
+ # type: (Optional[int], Optional[int]) -> None
if x is None or y is None:
- print(x + y)
+ print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
def and_error(x, y):
- # type: (Optional[int], Optional[int]) -> int
+ # type: (Optional[int], Optional[int]) -> None
if x is None and y is None:
pass
else:
- print(x + y)
+ print(x + y) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
@@ -4751,7 +4755,7 @@ a")
# type: (Optional[int]) -> None
x_none = x is not None
if x_none:
- print(x + 1)
+ print(x + 1) # noqa: T484
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
@torch.jit.script
@@ -4759,7 +4763,7 @@ a")
# type: (Optional[int], Optional[int]) -> None
x_none = x is not None
if y is not None and x_none:
- print(x + y)
+ print(x + y) # noqa: T484
def test_while_write_outer_then_read(self):
def func(a, b):
@@ -5057,10 +5061,11 @@ a")
self.checkScript(multiple_returns, [a], optimize=True)
with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
- @torch.jit.script
+ torch.jit.CompilationUnit('''
def no_return_bad_annotation(a):
# type: (Tensor) -> Tensor
a + 1
+ ''')
def test_error(self):
@torch.jit.script
@@ -5654,8 +5659,6 @@ a")
hiddens = hx
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
- from typing import Tuple
-
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super(ScriptWrapper, self).__init__()
@@ -6650,7 +6653,7 @@ a")
with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
def foo():
# type: () -> Tensor
- return ((3, 4),)
+ return ((3, 4),) # noqa: T484
@torch.jit.script
def bar():
@@ -6769,7 +6772,7 @@ a")
if x:
y = [1]
else:
- y = [None]
+ y = [None] # noqa: T484
return y[0]
@torch.jit.script
@@ -6815,18 +6818,18 @@ a")
print(int_fn((1, 1, 1)))
with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
- @torch.jit.script
+ @torch.jit.script # noqa: T484
def fn(x):
- # type: (BroadcastingListx[int]) -> List[int]
+ # type: (BroadcastingListx[int]) -> List[int] # noqa: T484
return x
- # TODO: the type comment in this seems to trip up flake8 for some reason
- # even though we have a noqa comment. Figure out why
+ # using CU so that flake8 error on int[2] is not raised (noqa not working)
with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
- @torch.jit.script
- def nested(x, y):
- # type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484
- return x
+ cu = torch.jit.CompilationUnit('''
+ def nested(x, y):
+ # type: (int, Tuple[int, int[2]]) -> List[int]
+ return x # noqa: T484
+ ''')
def test_ntuple_builtins(self):
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
@@ -8349,7 +8352,7 @@ a")
with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
def somefunc():
# type: () -> Tuple[Tuple[Tensor, Tensor]]
- return torch.zeros(3, 4), torch.zeros(4, 5)
+ return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484
@torch.jit.script
def wrong_return_type():
@@ -9029,7 +9032,7 @@ a")
def test(x):
# type: (Optional[int]) -> int
x = torch.jit._unwrap_optional(x)
- x = x + x
+ x = x + x # noqa: T484
return x
self.checkScript(test, (3,))
@@ -9082,14 +9085,14 @@ a")
@torch.jit.script
def return_tup(x):
# type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
- return x, x
+ return x, x # noqa: T484
def test_annotated_script_fn_arg_mismatch(self):
with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
@torch.jit.script
def tuple_arg(x):
# type: (Tuple[Tensor, Tensor]) -> Tensor
- return x + 1
+ return x + 1 # noqa: T484
def test_script_non_tensor_args_outputs(self):
@torch.jit.script
@@ -13122,11 +13125,11 @@ class TestAsync(JitTestCase):
self.assertEqual(y, y_hat)
def test_async_script_capture(self):
- class Module(torch.jit.ScriptModule):
+ class Mod(torch.jit.ScriptModule):
__constants__ = ['const']
def __init__(self):
- super(Module, self).__init__(False)
+ super(Mod, self).__init__(False)
self.const = 42
self.param = nn.Parameter(torch.randn(2, 2))
@@ -13144,7 +13147,7 @@ class TestAsync(JitTestCase):
x1 = torch.rand(3, 4)
x2 = torch.rand(5, 6)
- m = Module()
+ m = Mod()
y, y_hat = m.wait_script(x1, x2)
self.assertEqual(y, y_hat)
@@ -13244,9 +13247,9 @@ class TestAsync(JitTestCase):
def forward(self, x):
return (torch.neg(x), x)
- class Module(torch.jit.ScriptModule):
+ class Mod(torch.jit.ScriptModule):
def __init__(self):
- super(Module, self).__init__(False)
+ super(Mod, self).__init__(False)
x = torch.rand(3, 3)
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@@ -13266,10 +13269,10 @@ class TestAsync(JitTestCase):
# return a nested structure of tensors
return (tensor_list, tensor_tuple, tensor_tuple[1])
- class Tuple(nn.Module):
+ class TupleCl(nn.Module):
def __init__(self):
- super(Tuple, self).__init__()
- self.module = Module()
+ super(TupleCl, self).__init__()
+ self.module = Mod()
def forward(self, x):
z = torch.neg(x)
@@ -13278,7 +13281,7 @@ class TestAsync(JitTestCase):
return tuple(list)
x = torch.rand(3, 3)
- module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+ module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
# Make sure we have forks
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@@ -13632,16 +13635,16 @@ class TestClassType(JitTestCase):
@torch.jit.script
class FooTest:
def __init__(self, x):
- # type: (int)
+ # type: (int) -> None
self.foo = x
def incFooTest(self, y):
- # type: (int)
+ # type: (int) -> None
self.foo = self.foo + y
@torch.jit.script
def fn(x):
- # type: (int)
+ # type: (int) -> int
foo = FooTest(x)
foo.incFooTest(2)
return foo.foo
@@ -13689,7 +13692,7 @@ class TestClassType(JitTestCase):
@torch.jit.script
class FooTest:
def __init__(self, x):
- # type: (bool)
+ # type: (bool) -> None
self.foo = x
@torch.jit.script
@@ -13718,7 +13721,7 @@ class TestClassType(JitTestCase):
@torch.jit.script
def fn(foo):
- # type: (FooTest)
+ # type: (FooTest) -> Tensor
return foo.attr
@torch.jit.script
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 28053c5b37..3667cfe89a 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -9,24 +9,24 @@ import inspect
from torch._six import builtins
# Tracks standalone weak script functions
-compiled_weak_fns = weakref.WeakKeyDictionary()
+compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484
# Tracks which methods should be converted to strong methods
-weak_script_methods = weakref.WeakKeyDictionary()
+weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484
# Converted modules and their corresponding WeakScriptModuleProxy objects
-weak_modules = weakref.WeakKeyDictionary()
+weak_modules = weakref.WeakKeyDictionary() # noqa: T484
# Types that have been declared as weak modules
-weak_types = weakref.WeakKeyDictionary()
+weak_types = weakref.WeakKeyDictionary() # noqa: T484
# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
-boolean_dispatched = weakref.WeakKeyDictionary()
+boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484
# Python Op functions that should be ignored by the compiler. These will be replaced
# with an operator that always throws an error
-ignored_fns = weakref.WeakSet()
+ignored_fns = weakref.WeakSet() # noqa: T484
COMPILATION_PENDING = object()
COMPILED = object()
@@ -223,9 +223,9 @@ except ImportError:
def __getitem__(self, types):
return DictInstance(types)
- Tuple = TupleCls()
- List = ListCls()
- Dict = DictCls()
+ Tuple = TupleCls() # noqa: T484
+ List = ListCls() # noqa: T484
+ Dict = DictCls() # noqa: T484
def is_tuple(ann):
return isinstance(ann, TupleInstance)
diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py
index ec4e89a8b5..b5775b96b6 100644
--- a/torch/jit/quantized.py
+++ b/torch/jit/quantized.py
@@ -1,7 +1,9 @@
import torch
import copy
import numbers
-from typing import Tuple
+from typing import Tuple, Optional
+from torch import Tensor
+from torch.jit import ScriptModule
from torch.nn.utils.rnn import PackedSequence
from torch.nn import _VF