summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorEdward Yang <ezyang@fb.com>2019-03-21 09:06:30 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-21 09:10:30 -0700
commitd1497debf2aeb27235afb3d73a970388ac8100ab (patch)
tree32bcd7887a3d3304f6b25a8a35edcb2b5fc0951b /test
parentba81074c4088f9b9445ca4d3f2a9463afab4845c (diff)
downloadpytorch-d1497debf2aeb27235afb3d73a970388ac8100ab.tar.gz
pytorch-d1497debf2aeb27235afb3d73a970388ac8100ab.tar.bz2
pytorch-d1497debf2aeb27235afb3d73a970388ac8100ab.zip
Fix B903 lint: save memory for data classes with slots/namedtuple (#18184)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18184 ghimport-source-id: 2ce860b07c58d06dc10cd7e5b97d4ef7c709a50d Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18184 Fix B903 lint: save memory for data classes with slots/namedtuple** * #18181 Fix B902 lint error: invalid first argument. * #18178 Fix B006 lint errors: using mutable structure in default argument. * #18177 Fix lstrip bug revealed by B005 lint Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D14530872 fbshipit-source-id: e26cecab3a8545e7638454c28e654e7b82a3c08a
Diffstat (limited to 'test')
-rw-r--r--test/common_methods_invocations.py5
-rw-r--r--test/test_jit.py8
2 files changed, 6 insertions, 7 deletions
diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index 756bba9256..abb7856fde 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -2,6 +2,7 @@ import torch
from torch._six import inf, nan, istuple
from functools import reduce, wraps
from operator import mul, itemgetter
+import collections
from torch.autograd import Variable, Function, detect_anomaly
from torch.testing import make_non_contiguous
from common_utils import (skipIfNoLapack,
@@ -72,9 +73,7 @@ def prod_zeros(dim_size, dim_select):
return result
-class non_differentiable(object):
- def __init__(self, tensor):
- self.tensor = tensor
+non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
class dont_convert(tuple):
diff --git a/test/test_jit.py b/test/test_jit.py
index c4ab4e8202..0a194bc515 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -13942,7 +13942,7 @@ class TestClassType(JitTestCase):
self.assertEqual(fn(input), input)
def test_get_attr(self):
- @torch.jit.script
+ @torch.jit.script # noqa: B903
class FooTest:
def __init__(self, x):
self.foo = x
@@ -14005,7 +14005,7 @@ class TestClassType(JitTestCase):
def test_type_annotations(self):
with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
- @torch.jit.script
+ @torch.jit.script # noqa: B903
class FooTest:
def __init__(self, x):
# type: (bool) -> None
@@ -14026,7 +14026,7 @@ class TestClassType(JitTestCase):
self.attr = x
def test_class_type_as_param(self):
- @torch.jit.script
+ @torch.jit.script # noqa: B903
class FooTest:
def __init__(self, x):
self.attr = x
@@ -14094,7 +14094,7 @@ class TestClassType(JitTestCase):
self.assertEqual(input, output)
def test_save_load_with_classes_nested(self):
- @torch.jit.script
+ @torch.jit.script # noqa: B903
class FooNestedTest:
def __init__(self, y):
self.y = y