diff options
-rw-r--r-- | test/test_nn.py | 24 | ||||
-rw-r--r-- | torch/nn/modules/module.py | 14 |
2 files changed, 34 insertions, 4 deletions
diff --git a/test/test_nn.py b/test/test_nn.py index bd4b18b78a..90f296c52f 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3989,7 +3989,9 @@ class TestNN(NNTestCase): 'block.conv1.bias': torch.arange(1, 4), 'bn.running_mean': torch.randn(2), }) - net.load_state_dict(state_dict) + incompatible_keys = net.load_state_dict(state_dict) + self.assertEqual(len(incompatible_keys.missing_keys), 0) + self.assertEqual(len(incompatible_keys.unexpected_keys), 0) self.assertEqual(net.linear1.weight.data, state_dict['linear1.weight']) self.assertEqual(net.block.conv1.bias.data, state_dict['block.conv1.bias']) self.assertEqual(net.bn.running_mean, state_dict['bn.running_mean']) @@ -3997,18 +3999,38 @@ class TestNN(NNTestCase): state_dict = net.state_dict() state_dict.update({'extra': torch.ones(5)}) self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) + incompatible_keys = net.load_state_dict(state_dict, strict=False) + self.assertEqual(len(incompatible_keys.missing_keys), 0) + self.assertEqual(len(incompatible_keys.unexpected_keys), 1) + self.assertIn('extra', incompatible_keys.unexpected_keys) state_dict = net.state_dict() state_dict.update({'extra.param': torch.ones(5)}) self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) + incompatible_keys = net.load_state_dict(state_dict, strict=False) + self.assertEqual(len(incompatible_keys.missing_keys), 0) + self.assertEqual(len(incompatible_keys.unexpected_keys), 1) + self.assertIn('extra.param', incompatible_keys.unexpected_keys) state_dict = net.state_dict() del state_dict['linear1.weight'] self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) + incompatible_keys = net.load_state_dict(state_dict, strict=False) + self.assertEqual(len(incompatible_keys.missing_keys), 1) + self.assertEqual(len(incompatible_keys.unexpected_keys), 0) + self.assertIn('linear1.weight', incompatible_keys.missing_keys) + state_dict.update({'extra.param': torch.ones(5)}) + self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) + incompatible_keys = net.load_state_dict(state_dict, strict=False) + self.assertEqual(len(incompatible_keys.missing_keys), 1) + self.assertEqual(len(incompatible_keys.unexpected_keys), 1) + self.assertIn('linear1.weight', incompatible_keys.missing_keys) + self.assertIn('extra.param', incompatible_keys.unexpected_keys) state_dict = net.state_dict() state_dict.update({'bn.running_mean': torch.rand(14, 4)}) # wrong size self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict)) + self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict, strict=False)) state_dict = net.state_dict() old_state_dict = deepcopy(state_dict) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 079ec55244..11c6a081d7 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1,4 +1,4 @@ -from collections import OrderedDict +from collections import OrderedDict, namedtuple import functools import itertools @@ -8,6 +8,9 @@ from ..parameter import Parameter import torch.utils.hooks as hooks +_IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys']) + + def _addindent(s_, numSpaces): s = s_.split('\n') # don't do anything for single-line stuff @@ -734,6 +737,11 @@ class Module(object): strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys """ missing_keys = [] unexpected_keys = [] @@ -748,7 +756,7 @@ class Module(object): def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') @@ -756,7 +764,6 @@ class Module(object): load(self) if strict: - error_msg = '' if len(unexpected_keys) > 0: error_msgs.insert( 0, 'Unexpected key(s) in state_dict: {}. '.format( @@ -769,6 +776,7 @@ class Module(object): if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) def _named_members(self, get_members_fn, prefix='', recurse=True): r"""Helper method for yielding various names + members of modules.""" |