import torch import unittest import io import tempfile import os import sys import zipfile import warnings import gzip import copy import pickle import shutil import pathlib from torch._utils_internal import get_file_path_2 from torch._utils import _rebuild_tensor from torch.serialization import check_module_version_greater_or_equal from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \ TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName from torch.testing._internal.common_device_type import instantiate_device_type_tests # These tests were all copied from `test/test_torch.py` at some point, so see # the actual blame, see this revision # https://github.com/pytorch/pytorch/blame/9a2691f2fc948b9792686085b493c61793c2de30/test/test_torch.py if TEST_DILL: import dill HAS_DILL_AT_LEAST_0_3_1 = check_module_version_greater_or_equal(dill, (0, 3, 1)) else: HAS_DILL_AT_LEAST_0_3_1 = False can_retrieve_source = True with warnings.catch_warnings(record=True) as warns: with tempfile.NamedTemporaryFile() as checkpoint: x = torch.save(torch.nn.Module(), checkpoint) for warn in warns: if "Couldn't retrieve source code" in warn.message.args[0]: can_retrieve_source = False break class FilelikeMock(object): def __init__(self, data, has_fileno=True, has_readinto=False): if has_readinto: self.readinto = self.readinto_opt if has_fileno: # Python 2's StringIO.StringIO has no fileno attribute. # This is used to test that. self.fileno = self.fileno_opt self.calls = set() self.bytesio = io.BytesIO(data) def trace(fn, name): def result(*args, **kwargs): self.calls.add(name) return fn(*args, **kwargs) return result for attr in ['read', 'readline', 'seek', 'tell', 'write', 'flush']: traced_fn = trace(getattr(self.bytesio, attr), attr) setattr(self, attr, traced_fn) def fileno_opt(self): raise io.UnsupportedOperation('Not a real file') def readinto_opt(self, view): self.calls.add('readinto') return self.bytesio.readinto(view) def was_called(self, name): return name in self.calls class SerializationMixin(object): def _test_serialization_data(self): a = [torch.randn(5, 5).float() for i in range(2)] b = [a[i % 2] for i in range(4)] # 0-3 b += [a[0].storage()] # 4 b += [a[0].reshape(-1)[1:4].storage()] # 5 b += [torch.arange(1, 11).int()] # 6 t1 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) t2 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) b += [(t1.storage(), t1.storage(), t2.storage())] # 7 b += [a[0].reshape(-1)[0:2].storage()] # 8 return b def _test_serialization_assert(self, b, c): self.assertEqual(b, c, atol=0, rtol=0) self.assertTrue(isinstance(c[0], torch.FloatTensor)) self.assertTrue(isinstance(c[1], torch.FloatTensor)) self.assertTrue(isinstance(c[2], torch.FloatTensor)) self.assertTrue(isinstance(c[3], torch.FloatTensor)) self.assertTrue(isinstance(c[4], torch.FloatStorage)) c[0].fill_(10) self.assertEqual(c[0], c[2], atol=0, rtol=0) self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), atol=0, rtol=0) c[1].fill_(20) self.assertEqual(c[1], c[3], atol=0, rtol=0) # I have to do it in this roundabout fashion, because there's no # way to slice storages for i in range(4): self.assertEqual(c[4][i + 1], c[5][i]) # check that serializing the same storage view object unpickles # it as one object not two (and vice versa) views = c[7] self.assertEqual(views[0]._cdata, views[1]._cdata) self.assertEqual(views[0], views[2]) self.assertNotEqual(views[0]._cdata, views[2]._cdata) rootview = c[8] self.assertEqual(rootview.data_ptr(), c[0].data_ptr()) def test_serialization_zipfile_utils(self): data = { 'a': b'12039810948234589', 'b': b'1239081209484958', 'c/d': b'94589480984058' } def test(name_or_buffer): with torch.serialization._open_zipfile_writer(name_or_buffer) as zip_file: for key in data: zip_file.write_record(key, data[key], len(data[key])) if hasattr(name_or_buffer, 'seek'): name_or_buffer.seek(0) with torch.serialization._open_zipfile_reader(name_or_buffer) as zip_file: for key in data: actual = zip_file.get_record(key) expected = data[key] self.assertEqual(expected, actual) with tempfile.NamedTemporaryFile() as f: test(f) with TemporaryFileName() as fname: test(fname) test(io.BytesIO()) def test_serialization(self): # Test serialization with a real file b = self._test_serialization_data() with tempfile.NamedTemporaryFile() as f: torch.save(b, f) f.seek(0) c = torch.load(f) self._test_serialization_assert(b, c) with TemporaryFileName() as fname: torch.save(b, fname) c = torch.load(fname) self._test_serialization_assert(b, c) # test non-ascii encoding of bytes arrays/strings # The following bytes are produced by serializing # [b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc', torch.zeros(1, dtype=torch.float), 2] # in Python 2.7.12 and PyTorch 0.4.1, where the first element contains # bytes of some utf-8 characters (i.e., `utf8_str.encode('utf-8')`). serialized = ( b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.' b'\x80\x02}q\x01(U\x10protocol_versionq\x02M\xe9\x03U\n' b'type_sizesq\x03}q\x04(U\x03intq\x05K\x04U\x05shortq\x06K\x02U' b'\x04longq\x07K\x04uU\rlittle_endianq\x08\x88u.\x80\x02]q' b'\x01(U\x0e\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85' b'\xc5\xbcq\x02ctorch._utils\n_rebuild_tensor_v2\nq\x03((U' b'\x07storageq\x04ctorch\nFloatStorage\nq\x05U\x0845640624q' b'\x06U\x03cpuq\x07\x8a\x01\x01NtQK\x00K\x01\x85K\x01\x85' b'\x89NtRq\x08K\x02e.\x80\x02]q\x01U\x0845640624q\x02a.\x01\x00' b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' ) buf = io.BytesIO(serialized) utf8_bytes = b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc' utf8_str = utf8_bytes.decode('utf-8') loaded_utf8 = torch.load(buf, encoding='utf-8') self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2]) buf.seek(0) loaded_bytes = torch.load(buf, encoding='bytes') self.assertEqual(loaded_bytes, [utf8_bytes, torch.zeros(1, dtype=torch.float), 2]) def test_serialization_filelike(self): # Test serialization (load and save) with a filelike object b = self._test_serialization_data() with BytesIOContext() as f: torch.save(b, f) f.seek(0) c = torch.load(f) self._test_serialization_assert(b, c) def test_serialization_fake_zip(self): data = [ ord('P'), ord('K'), 5, 6 ] for i in range(0, 100): data.append(0) t = torch.tensor(data, dtype=torch.uint8) with tempfile.NamedTemporaryFile() as f: torch.save(t, f) # If this check is False for all Python versions (i.e. the fix # has been backported), this test and torch.serialization._is_zipfile # can be deleted self.assertTrue(zipfile.is_zipfile(f)) self.assertFalse(torch.serialization._is_zipfile(f)) f.seek(0) self.assertEqual(torch.load(f), t) def test_serialization_gzip(self): # Test serialization with gzip file b = self._test_serialization_data() f1 = tempfile.NamedTemporaryFile(delete=False) f2 = tempfile.NamedTemporaryFile(delete=False) torch.save(b, f1) with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: shutil.copyfileobj(f_in, f_out) with gzip.open(f2.name, 'rb') as f: c = torch.load(f) self._test_serialization_assert(b, c) @unittest.skipIf( not TEST_DILL or HAS_DILL_AT_LEAST_0_3_1, '"dill" not found or is correct version' ) def test_serialization_dill_version_not_supported(self): x = torch.randn(5, 5) with tempfile.NamedTemporaryFile() as f: with self.assertRaisesRegex(ValueError, 'supports dill >='): torch.save(x, f, pickle_module=dill) f.seek(0) with self.assertRaisesRegex(ValueError, 'supports dill >='): x2 = torch.load(f, pickle_module=dill, encoding='utf-8') @unittest.skipIf( not TEST_DILL or not HAS_DILL_AT_LEAST_0_3_1, '"dill" not found or not correct version' ) def test_serialization_dill(self): x = torch.randn(5, 5) with tempfile.NamedTemporaryFile() as f: torch.save(x, f, pickle_module=dill) f.seek(0) x2 = torch.load(f, pickle_module=dill, encoding='utf-8') self.assertIsInstance(x2, type(x)) self.assertEqual(x, x2) f.seek(0) x3 = torch.load(f, pickle_module=dill) self.assertIsInstance(x3, type(x)) self.assertEqual(x, x3) def test_serialization_offset_gzip(self): a = torch.randn(5, 5) i = 41 f1 = tempfile.NamedTemporaryFile(delete=False) f2 = tempfile.NamedTemporaryFile(delete=False) with open(f1.name, 'wb') as f: pickle.dump(i, f) torch.save(a, f) with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: shutil.copyfileobj(f_in, f_out) with gzip.open(f2.name, 'rb') as f: j = pickle.load(f) b = torch.load(f) self.assertTrue(torch.equal(a, b)) self.assertEqual(i, j) def test_serialization_sparse(self): x = torch.zeros(3, 3) x[1][1] = 1 x = x.to_sparse() with tempfile.NamedTemporaryFile() as f: torch.save({"tensor": x}, f) f.seek(0) y = torch.load(f) self.assertEqual(x, y["tensor"]) def test_serialization_sparse_invalid(self): x = torch.zeros(3, 3) x[1][1] = 1 x = x.to_sparse() class TensorSerializationSpoofer(object): def __init__(self, tensor): self.tensor = tensor def __reduce_ex__(self, proto): invalid_indices = self.tensor._indices().clone() invalid_indices[0][0] = 3 return ( torch._utils._rebuild_sparse_tensor, ( self.tensor.layout, ( invalid_indices, self.tensor._values(), self.tensor.size()))) with tempfile.NamedTemporaryFile() as f: torch.save({"spoofed": TensorSerializationSpoofer(x)}, f) f.seek(0) with self.assertRaisesRegex( RuntimeError, "size is inconsistent with indices"): y = torch.load(f) def test_serialize_device(self): device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0'] device_obj = [torch.device(d) for d in device_str] for device in device_obj: device_copied = copy.deepcopy(device) self.assertEqual(device, device_copied) def test_serialization_backwards_compat(self): a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)] b = [a[i % 2] for i in range(4)] b += [a[0].storage()] b += [a[0].reshape(-1)[1:4].clone().storage()] path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt') c = torch.load(path) self.assertEqual(b, c, atol=0, rtol=0) self.assertTrue(isinstance(c[0], torch.FloatTensor)) self.assertTrue(isinstance(c[1], torch.FloatTensor)) self.assertTrue(isinstance(c[2], torch.FloatTensor)) self.assertTrue(isinstance(c[3], torch.FloatTensor)) self.assertTrue(isinstance(c[4], torch.FloatStorage)) c[0].fill_(10) self.assertEqual(c[0], c[2], atol=0, rtol=0) self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), atol=0, rtol=0) c[1].fill_(20) self.assertEqual(c[1], c[3], atol=0, rtol=0) # test some old tensor serialization mechanism class OldTensorBase(object): def __init__(self, new_tensor): self.new_tensor = new_tensor def __getstate__(self): return (self.new_tensor.storage(), self.new_tensor.storage_offset(), tuple(self.new_tensor.size()), self.new_tensor.stride()) class OldTensorV1(OldTensorBase): def __reduce__(self): return (torch.Tensor, (), self.__getstate__()) class OldTensorV2(OldTensorBase): def __reduce__(self): return (_rebuild_tensor, self.__getstate__()) x = torch.randn(30).as_strided([2, 3], [9, 3], 2) for old_cls in [OldTensorV1, OldTensorV2]: with tempfile.NamedTemporaryFile() as f: old_x = old_cls(x) torch.save(old_x, f) f.seek(0) load_x = torch.load(f) self.assertEqual(x.storage(), load_x.storage()) self.assertEqual(x.storage_offset(), load_x.storage_offset()) self.assertEqual(x.size(), load_x.size()) self.assertEqual(x.stride(), load_x.stride()) def test_serialization_save_warnings(self): with warnings.catch_warnings(record=True) as warns: with tempfile.NamedTemporaryFile() as checkpoint: x = torch.save(torch.nn.Linear(2, 3), checkpoint) self.assertEquals(len(warns), 0) def test_serialization_map_location(self): test_file_path = download_file('https://download.pytorch.org/test_data/gpu_tensors.pt') def map_location(storage, loc): return storage def load_bytes(): with open(test_file_path, 'rb') as f: return io.BytesIO(f.read()) fileobject_lambdas = [lambda: test_file_path, load_bytes] cpu_map_locations = [ map_location, {'cuda:0': 'cpu'}, 'cpu', torch.device('cpu'), ] gpu_0_map_locations = [ {'cuda:0': 'cuda:0'}, 'cuda', 'cuda:0', torch.device('cuda'), torch.device('cuda', 0) ] gpu_last_map_locations = [ 'cuda:{}'.format(torch.cuda.device_count() - 1), ] def check_map_locations(map_locations, tensor_class, intended_device): for fileobject_lambda in fileobject_lambdas: for map_location in map_locations: tensor = torch.load(fileobject_lambda(), map_location=map_location) self.assertEqual(tensor.device, intended_device) self.assertIsInstance(tensor, tensor_class) self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]])) check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu')) if torch.cuda.is_available(): check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0)) check_map_locations( gpu_last_map_locations, torch.cuda.FloatTensor, torch.device('cuda', torch.cuda.device_count() - 1) ) @unittest.skipIf(torch.cuda.is_available(), "Testing torch.load on CPU-only machine") def test_load_nonexistent_device(self): # Setup: create a serialized file object with a 'cuda:0' restore location # The following was generated by saving a torch.randn(2, device='cuda') tensor. serialized = (b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9' b'\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq' b'\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n' b'\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq' b'\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00' b'longq\x07K\x04uu.\x80\x02ctorch._utils\n_rebuild_tensor_v2' b'\nq\x00((X\x07\x00\x00\x00storageq\x01ctorch\nFloatStorage' b'\nq\x02X\x0e\x00\x00\x0094919395964320q\x03X\x06\x00\x00' b'\x00cuda:0q\x04K\x02Ntq\x05QK\x00K\x02\x85q\x06K\x01\x85q' b'\x07\x89Ntq\x08Rq\t.\x80\x02]q\x00X\x0e\x00\x00\x00' b'94919395964320q\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\xbb' b'\x1f\x82\xbe\xea\x81\xd1>') buf = io.BytesIO(serialized) error_msg = r'Attempting to deserialize object on a CUDA device' with self.assertRaisesRegex(RuntimeError, error_msg): _ = torch.load(buf) @unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681") def test_serialization_filelike_api_requirements(self): filemock = FilelikeMock(b'', has_readinto=False) tensor = torch.randn(3, 5) torch.save(tensor, filemock) expected_superset = {'write', 'flush'} self.assertTrue(expected_superset.issuperset(filemock.calls)) # Reset between save and load filemock.seek(0) filemock.calls.clear() _ = torch.load(filemock) expected_superset = {'read', 'readline', 'seek', 'tell'} self.assertTrue(expected_superset.issuperset(filemock.calls)) def _test_serialization_filelike(self, tensor, mock, desc): f = mock(b'') torch.save(tensor, f) f.seek(0) data = mock(f.read()) msg = 'filelike serialization with {}' b = torch.load(data) self.assertTrue(torch.equal(tensor, b), msg.format(desc)) @unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681") def test_serialization_filelike_missing_attrs(self): # Test edge cases where filelike objects are missing attributes. # The Python io docs suggests that these attributes should really exist # and throw io.UnsupportedOperation, but that isn't always the case. mocks = [ ('no readinto', lambda x: FilelikeMock(x)), ('has readinto', lambda x: FilelikeMock(x, has_readinto=True)), ('no fileno', lambda x: FilelikeMock(x, has_fileno=False)), ] to_serialize = torch.randn(3, 10) for desc, mock in mocks: self._test_serialization_filelike(to_serialize, mock, desc) @unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681") def test_serialization_filelike_stress(self): a = torch.randn(11 * (2 ** 9) + 1, 5 * (2 ** 9)) # This one should call python read multiple times self._test_serialization_filelike(a, lambda x: FilelikeMock(x, has_readinto=False), 'read() stress test') self._test_serialization_filelike(a, lambda x: FilelikeMock(x, has_readinto=True), 'readinto() stress test') def test_serialization_filelike_uses_readinto(self): # For maximum effiency, when reading a file-like object, # ensure the C API calls readinto instead of read. a = torch.randn(5, 4) f = io.BytesIO() torch.save(a, f) f.seek(0) data = FilelikeMock(f.read(), has_readinto=True) b = torch.load(data) self.assertTrue(data.was_called('readinto')) def test_serialization_storage_slice(self): # Generated using: # # t = torch.zeros(2); # s1 = t.storage()[:1] # s2 = t.storage()[1:] # torch.save((s1, s2), 'foo.ser') # # with PyTorch 0.3.1 serialized = (b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03' b'.\x80\x02}q\x00(X\n\x00\x00\x00type_sizesq\x01}q\x02(X\x03' b'\x00\x00\x00intq\x03K\x04X\x05\x00\x00\x00shortq\x04K\x02X' b'\x04\x00\x00\x00longq\x05K\x04uX\x10\x00\x00\x00protocol_versionq' b'\x06M\xe9\x03X\r\x00\x00\x00little_endianq\x07\x88u.\x80\x02' b'(X\x07\x00\x00\x00storageq\x00ctorch\nFloatStorage\nq\x01X\x0e' b'\x00\x00\x0094279043900432q\x02X\x03\x00\x00\x00cpuq\x03K\x02' b'X\x0e\x00\x00\x0094279029750368q\x04K\x00K\x01\x87q\x05tq\x06' b'Q(h\x00h\x01X\x0e\x00\x00\x0094279043900432q\x07h\x03K\x02X' b'\x0e\x00\x00\x0094279029750432q\x08K\x01K\x01\x87q\ttq\nQ' b'\x86q\x0b.\x80\x02]q\x00X\x0e\x00\x00\x0094279043900432q' b'\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' b'\x00\x00\x00\x00') buf = io.BytesIO(serialized) (s1, s2) = torch.load(buf) self.assertEqual(s1[0], 0) self.assertEqual(s2[0], 0) self.assertEqual(s1.data_ptr() + 4, s2.data_ptr()) def test_load_unicode_error_msg(self): # This Pickle contains a Python 2 module with Unicode data and the # loading should fail if the user explicitly specifies ascii encoding! path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt') self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii')) def test_load_python2_unicode_module(self): # This Pickle contains some Unicode data! path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt') with warnings.catch_warnings(record=True) as w: self.assertIsNotNone(torch.load(path)) def test_load_error_msg(self): expected_err_msg = (".*You can only torch.load from a file that is seekable. " + "Please pre-load the data into a buffer like io.BytesIO and " + "try to load from it instead.") resource = FilelikeMock(data=b"data") delattr(resource, "tell") delattr(resource, "seek") with self.assertRaisesRegex(AttributeError, expected_err_msg): torch.load(resource) class serialization_method(object): def __init__(self, use_zip): self.use_zip = use_zip self.torch_save = torch.save def __enter__(self, *args, **kwargs): def wrapper(*args, **kwargs): if '_use_new_zipfile_serialization' in kwargs: raise RuntimeError("Cannot set method manually") kwargs['_use_new_zipfile_serialization'] = self.use_zip return self.torch_save(*args, **kwargs) torch.save = wrapper def __exit__(self, *args, **kwargs): torch.save = self.torch_save class TestBothSerialization(TestCase): @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") def test_serialization_new_format_old_format_compat(self, device): x = [torch.ones(200, 200, device=device) for i in range(30)] def test(f_new, f_old): torch.save(x, f_new, _use_new_zipfile_serialization=True) f_new.seek(0) x_new_load = torch.load(f_new) self.assertEqual(x, x_new_load) torch.save(x, f_old, _use_new_zipfile_serialization=False) f_old.seek(0) x_old_load = torch.load(f_old) self.assertEqual(x_old_load, x_new_load) with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old: test(f_new, f_old) class TestOldSerialization(TestCase, SerializationMixin): # unique_key is necessary because on Python 2.7, if a warning passed to # the warning module is the same, it is not raised again. def _test_serialization_container(self, unique_key, filecontext_lambda): tmpmodule_name = 'tmpmodule{}'.format(unique_key) def import_module(name, filename): import importlib.util spec = importlib.util.spec_from_file_location(name, filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) sys.modules[module.__name__] = module return module with filecontext_lambda() as checkpoint: fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing', '_internal', 'data', 'network1.py') module = import_module(tmpmodule_name, fname) torch.save(module.Net(), checkpoint) # First check that the checkpoint can be loaded without warnings checkpoint.seek(0) with warnings.catch_warnings(record=True) as w: loaded = torch.load(checkpoint) self.assertTrue(isinstance(loaded, module.Net)) if can_retrieve_source: self.assertEquals(len(w), 0) # Replace the module with different source fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing', '_internal', 'data', 'network2.py') module = import_module(tmpmodule_name, fname) checkpoint.seek(0) with warnings.catch_warnings(record=True) as w: loaded = torch.load(checkpoint) self.assertTrue(isinstance(loaded, module.Net)) if can_retrieve_source: self.assertEquals(len(w), 1) self.assertTrue(w[0].category, 'SourceChangeWarning') def test_serialization_container(self): self._test_serialization_container('file', tempfile.NamedTemporaryFile) def test_serialization_container_filelike(self): self._test_serialization_container('filelike', BytesIOContext) def test_serialization_offset(self): a = torch.randn(5, 5) b = torch.randn(1024, 1024, 512, dtype=torch.float32) m = torch.nn.Conv2d(1, 1, (1, 3)) i, j = 41, 43 with tempfile.NamedTemporaryFile() as f: pickle.dump(i, f) torch.save(a, f) pickle.dump(j, f) torch.save(b, f) torch.save(m, f) self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024) f.seek(0) i_loaded = pickle.load(f) a_loaded = torch.load(f) j_loaded = pickle.load(f) b_loaded = torch.load(f) m_loaded = torch.load(f) self.assertTrue(torch.equal(a, a_loaded)) self.assertTrue(torch.equal(b, b_loaded)) self.assertTrue(m.kernel_size == m_loaded.kernel_size) self.assertEqual(i, i_loaded) self.assertEqual(j, j_loaded) def test_serialization_offset_filelike(self): a = torch.randn(5, 5) b = torch.randn(1024, 1024, 512, dtype=torch.float32) i, j = 41, 43 with BytesIOContext() as f: pickle.dump(i, f) torch.save(a, f) pickle.dump(j, f) torch.save(b, f) self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024) f.seek(0) i_loaded = pickle.load(f) a_loaded = torch.load(f) j_loaded = pickle.load(f) b_loaded = torch.load(f) self.assertTrue(torch.equal(a, a_loaded)) self.assertTrue(torch.equal(b, b_loaded)) self.assertEqual(i, i_loaded) self.assertEqual(j, j_loaded) def run(self, *args, **kwargs): with serialization_method(use_zip=False): return super(TestOldSerialization, self).run(*args, **kwargs) class TestWrapperSubclass(torch.Tensor): elem: torch.Tensor __slots__ = ['elem', 'other'] @staticmethod def __new__(cls, elem, *args, **kwargs): # The wrapping tensor (TestSubclass) is just a meta tensor, so it # doesn't hold any memory (meta tensor is generally the preferred type # of tensor you want to make a subclass from)... r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad) # ...the real tensor is held as an element on the tensor. r.elem = elem return r class TestGetStateSubclass(torch.Tensor): elem: torch.Tensor __slots__ = ['elem'] @staticmethod def __new__(cls, elem, *args, **kwargs): # The wrapping tensor (TestSubclass) is just a meta tensor, so it # doesn't hold any memory (meta tensor is generally the preferred type # of tensor you want to make a subclass from)... r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad) # ...the real tensor is held as an element on the tensor. r.elem = elem return r def __getstate__(self): return ("foo", getattr(self, "elem", None), self.__dict__) def __setstate__(self, state): marker, self.elem, self.__dict__ = state if not marker == "foo": raise RuntimeError("Invalid state for TestGetStateSubclass") self.reloaded = True class TestSerialization(TestCase, SerializationMixin): def test_serialization_zipfile(self): data = self._test_serialization_data() def test(name_or_buffer): torch.save(data, name_or_buffer) if hasattr(name_or_buffer, 'seek'): name_or_buffer.seek(0) result = torch.load(name_or_buffer) self.assertEqual(result, data) with tempfile.NamedTemporaryFile() as f: test(f) with TemporaryFileName() as fname: test(fname) test(io.BytesIO()) def test_serialization_zipfile_actually_jit(self): with tempfile.NamedTemporaryFile() as f: torch.jit.save(torch.jit.script(torch.nn.Linear(3, 4)), f) f.seek(0) torch.load(f) # Ensure large zip64 serialization works properly def test_serialization_2gb_file(self): big_model = torch.nn.Conv2d(20000, 3200, kernel_size=3) with BytesIOContext() as f: torch.save(big_model, f) f.seek(0) state = torch.load(f) def test_pathlike_serialization(self): model = torch.nn.Conv2d(20, 3200, kernel_size=3) with TemporaryFileName() as fname: path = pathlib.Path(fname) torch.save(model, path) torch.load(path) def test_meta_serialization(self): big_model = torch.nn.Conv2d(20000, 320000, kernel_size=3, device='meta') with BytesIOContext() as f: torch.save(big_model, f) f.seek(0) state = torch.load(f) self.assertEqual(state.weight.size(), big_model.weight.size()) def test_tensor_subclass_wrapper_serialization(self): wrapped_tensor = torch.rand(2) my_tensor = TestWrapperSubclass(wrapped_tensor) foo_val = "bar" my_tensor.foo = foo_val self.assertEqual(my_tensor.foo, foo_val) with BytesIOContext() as f: torch.save(my_tensor, f) f.seek(0) new_tensor = torch.load(f) self.assertIsInstance(new_tensor, TestWrapperSubclass) self.assertEqual(new_tensor.elem, my_tensor.elem) self.assertEqual(new_tensor.foo, foo_val) def test_tensor_subclass_getstate_overwrite(self): wrapped_tensor = torch.rand(2) my_tensor = TestGetStateSubclass(wrapped_tensor) foo_val = "bar" my_tensor.foo = foo_val self.assertEqual(my_tensor.foo, foo_val) with BytesIOContext() as f: torch.save(my_tensor, f) f.seek(0) new_tensor = torch.load(f) self.assertIsInstance(new_tensor, TestGetStateSubclass) self.assertEqual(new_tensor.elem, my_tensor.elem) self.assertEqual(new_tensor.foo, foo_val) self.assertTrue(new_tensor.reloaded) def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super(TestSerialization, self).run(*args, **kwargs) instantiate_device_type_tests(TestBothSerialization, globals()) if __name__ == '__main__': run_tests()