diff options
author | David Riazati <davidriazati@fb.com> | 2019-03-29 19:06:06 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-29 19:10:12 -0700 |
commit | 24db1667daeb84a9eb351ec6b6264de22213a910 (patch) | |
tree | 59af453861a2d2ec37eeb20224facc075a0c1341 /test | |
parent | e13101e0691b0eabc1900f482a615ea7f14e7a72 (diff) | |
download | pytorch-24db1667daeb84a9eb351ec6b6264de22213a910.tar.gz pytorch-24db1667daeb84a9eb351ec6b6264de22213a910.tar.bz2 pytorch-24db1667daeb84a9eb351ec6b6264de22213a910.zip |
Attribute serialization improvements (#18188)
Summary:
* adds attributes to `ScriptModule.__getattr__` so they can be accessed in Python after re-importing
* full support for all the possible values for an `int64_t`
* this necessitated a bunch more `pushWhatever` functions, so re-introduced a templated version to cut down on duplicate code
* tests to validate references / value sharing works
* adds `torch.jit.Unpickler` which people can use to de-serialize the pickle files into Python / have a quick reference on how to do this without PyTorch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18188
Differential Revision: D14527490
Pulled By: driazati
fbshipit-source-id: efd15579cc04aa2e28c4b2c9490d82d849dee559
Diffstat (limited to 'test')
-rw-r--r-- | test/test_jit.py | 70 |
1 files changed, 67 insertions, 3 deletions
diff --git a/test/test_jit.py b/test/test_jit.py index 4599c96d1d..3f33683ee2 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -16,7 +16,7 @@ from torch.nn import Module from torch.autograd.function import traceable from torch.testing import assert_allclose from torch.onnx import OperatorExportTypes -from torch._six import inf, PY2, builtins +from torch._six import inf, PY2, builtins, StringIO from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \ freeze_rng_state, set_rng_seed, slowTest @@ -37,7 +37,10 @@ import warnings import math import types import pickle +import pickletools import copy +import zipfile + from common_methods_invocations import method_tests as autograd_method_tests from common_methods_invocations import create_input, unpack_variables, \ @@ -10488,8 +10491,6 @@ a") @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") def test_attribute_unpickling(self): - import zipfile - class M(torch.jit.ScriptModule): def __init__(self): super(M, self).__init__() @@ -10557,6 +10558,69 @@ a") imported_m = self.getExportImportCopy(m) self.assertEqual(m(), imported_m()) + def test_serialization_big_ints(self): + class M(torch.jit.ScriptModule): + def __init__(self): + super(M, self).__init__() + self.int32_max = torch.jit.Attribute(2**31 - 1, int) + self.int32_min = torch.jit.Attribute(-2**31, int) + self.uint32_max = torch.jit.Attribute(2**32, int) + + self.int64_max = torch.jit.Attribute(2**63 - 1, int) + self.int64_min = torch.jit.Attribute(-2**63, int) + + self.tensor = torch.nn.Parameter(torch.ones(2, 2)) + + @torch.jit.script_method + def forward(self, x): + # type: (int) -> (int) + return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min) + + m = M() + imported = self.getExportImportCopy(m) + self.assertEqual(m(10), imported(10)) + + self.assertEqual(m.int32_max, imported.int32_max) + self.assertEqual(m.int32_min, imported.int32_min) + self.assertEqual(m.uint32_max, imported.uint32_max) + self.assertEqual(m.int64_max, imported.int64_max) + self.assertEqual(m.int64_min, imported.int64_min) + + def test_serialization_sharing(self): + class M(torch.jit.ScriptModule): + def __init__(self): + super(M, self).__init__() + self.list = torch.jit.Attribute([], List[str]) + + @torch.jit.script_method + def forward(self, key): + # type: (str) -> List[str] + self.list.append(key) + self.list.append(key) + self.list.append(key) + return self.list + + # the text of the string should only appear once in the pickling + m = M() + s1 = "a long string" + s2 = "a different, even longer string" + self.assertEqual(m(s1), [s1] * 3) + self.assertEqual(m(s2), [s1] * 3 + [s2] * 3) + with TemporaryFileName() as fname: + m.save(fname) + archive_name = os.path.basename(os.path.normpath(fname)) + archive = zipfile.ZipFile(fname, 'r') + pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl')) + + out = StringIO() + pickletools.dis(pickled_data, out=out) + disassembled = out.getvalue() + + FileCheck().check_count(s1, 1, exactly=True) \ + .check_count("BINGET", 2, exactly=True) \ + .check_count(s2, 1, exactly=True) \ + .check_count("BINGET", 2, exactly=True).run(out.getvalue()) + def test_optional_tuple(self): def fn(x=None): # type: (Optional[Tuple[int, int]]) -> Tuple[int, int] |