summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorDavid Riazati <davidriazati@fb.com>2019-03-29 19:06:06 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-29 19:10:12 -0700
commit24db1667daeb84a9eb351ec6b6264de22213a910 (patch)
tree59af453861a2d2ec37eeb20224facc075a0c1341 /test
parente13101e0691b0eabc1900f482a615ea7f14e7a72 (diff)
downloadpytorch-24db1667daeb84a9eb351ec6b6264de22213a910.tar.gz
pytorch-24db1667daeb84a9eb351ec6b6264de22213a910.tar.bz2
pytorch-24db1667daeb84a9eb351ec6b6264de22213a910.zip
Attribute serialization improvements (#18188)
Summary: * adds attributes to `ScriptModule.__getattr__` so they can be accessed in Python after re-importing * full support for all the possible values for an `int64_t` * this necessitated a bunch more `pushWhatever` functions, so re-introduced a templated version to cut down on duplicate code * tests to validate references / value sharing works * adds `torch.jit.Unpickler` which people can use to de-serialize the pickle files into Python / have a quick reference on how to do this without PyTorch Pull Request resolved: https://github.com/pytorch/pytorch/pull/18188 Differential Revision: D14527490 Pulled By: driazati fbshipit-source-id: efd15579cc04aa2e28c4b2c9490d82d849dee559
Diffstat (limited to 'test')
-rw-r--r--test/test_jit.py70
1 files changed, 67 insertions, 3 deletions
diff --git a/test/test_jit.py b/test/test_jit.py
index 4599c96d1d..3f33683ee2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -16,7 +16,7 @@ from torch.nn import Module
from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
-from torch._six import inf, PY2, builtins
+from torch._six import inf, PY2, builtins, StringIO
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state, set_rng_seed, slowTest
@@ -37,7 +37,10 @@ import warnings
import math
import types
import pickle
+import pickletools
import copy
+import zipfile
+
from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
@@ -10488,8 +10491,6 @@ a")
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
def test_attribute_unpickling(self):
- import zipfile
-
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
@@ -10557,6 +10558,69 @@ a")
imported_m = self.getExportImportCopy(m)
self.assertEqual(m(), imported_m())
+ def test_serialization_big_ints(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ self.int32_max = torch.jit.Attribute(2**31 - 1, int)
+ self.int32_min = torch.jit.Attribute(-2**31, int)
+ self.uint32_max = torch.jit.Attribute(2**32, int)
+
+ self.int64_max = torch.jit.Attribute(2**63 - 1, int)
+ self.int64_min = torch.jit.Attribute(-2**63, int)
+
+ self.tensor = torch.nn.Parameter(torch.ones(2, 2))
+
+ @torch.jit.script_method
+ def forward(self, x):
+ # type: (int) -> (int)
+ return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min)
+
+ m = M()
+ imported = self.getExportImportCopy(m)
+ self.assertEqual(m(10), imported(10))
+
+ self.assertEqual(m.int32_max, imported.int32_max)
+ self.assertEqual(m.int32_min, imported.int32_min)
+ self.assertEqual(m.uint32_max, imported.uint32_max)
+ self.assertEqual(m.int64_max, imported.int64_max)
+ self.assertEqual(m.int64_min, imported.int64_min)
+
+ def test_serialization_sharing(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ self.list = torch.jit.Attribute([], List[str])
+
+ @torch.jit.script_method
+ def forward(self, key):
+ # type: (str) -> List[str]
+ self.list.append(key)
+ self.list.append(key)
+ self.list.append(key)
+ return self.list
+
+ # the text of the string should only appear once in the pickling
+ m = M()
+ s1 = "a long string"
+ s2 = "a different, even longer string"
+ self.assertEqual(m(s1), [s1] * 3)
+ self.assertEqual(m(s2), [s1] * 3 + [s2] * 3)
+ with TemporaryFileName() as fname:
+ m.save(fname)
+ archive_name = os.path.basename(os.path.normpath(fname))
+ archive = zipfile.ZipFile(fname, 'r')
+ pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
+
+ out = StringIO()
+ pickletools.dis(pickled_data, out=out)
+ disassembled = out.getvalue()
+
+ FileCheck().check_count(s1, 1, exactly=True) \
+ .check_count("BINGET", 2, exactly=True) \
+ .check_count(s2, 1, exactly=True) \
+ .check_count("BINGET", 2, exactly=True).run(out.getvalue())
+
def test_optional_tuple(self):
def fn(x=None):
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]