summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_nn.py24
-rw-r--r--torch/nn/modules/module.py14
2 files changed, 34 insertions, 4 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index bd4b18b78a..90f296c52f 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3989,7 +3989,9 @@ class TestNN(NNTestCase):
'block.conv1.bias': torch.arange(1, 4),
'bn.running_mean': torch.randn(2),
})
- net.load_state_dict(state_dict)
+ incompatible_keys = net.load_state_dict(state_dict)
+ self.assertEqual(len(incompatible_keys.missing_keys), 0)
+ self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
self.assertEqual(net.linear1.weight.data, state_dict['linear1.weight'])
self.assertEqual(net.block.conv1.bias.data, state_dict['block.conv1.bias'])
self.assertEqual(net.bn.running_mean, state_dict['bn.running_mean'])
@@ -3997,18 +3999,38 @@ class TestNN(NNTestCase):
state_dict = net.state_dict()
state_dict.update({'extra': torch.ones(5)})
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
+ incompatible_keys = net.load_state_dict(state_dict, strict=False)
+ self.assertEqual(len(incompatible_keys.missing_keys), 0)
+ self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
+ self.assertIn('extra', incompatible_keys.unexpected_keys)
state_dict = net.state_dict()
state_dict.update({'extra.param': torch.ones(5)})
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
+ incompatible_keys = net.load_state_dict(state_dict, strict=False)
+ self.assertEqual(len(incompatible_keys.missing_keys), 0)
+ self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
+ self.assertIn('extra.param', incompatible_keys.unexpected_keys)
state_dict = net.state_dict()
del state_dict['linear1.weight']
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
+ incompatible_keys = net.load_state_dict(state_dict, strict=False)
+ self.assertEqual(len(incompatible_keys.missing_keys), 1)
+ self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
+ self.assertIn('linear1.weight', incompatible_keys.missing_keys)
+ state_dict.update({'extra.param': torch.ones(5)})
+ self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
+ incompatible_keys = net.load_state_dict(state_dict, strict=False)
+ self.assertEqual(len(incompatible_keys.missing_keys), 1)
+ self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
+ self.assertIn('linear1.weight', incompatible_keys.missing_keys)
+ self.assertIn('extra.param', incompatible_keys.unexpected_keys)
state_dict = net.state_dict()
state_dict.update({'bn.running_mean': torch.rand(14, 4)}) # wrong size
self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
+ self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict, strict=False))
state_dict = net.state_dict()
old_state_dict = deepcopy(state_dict)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 079ec55244..11c6a081d7 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -1,4 +1,4 @@
-from collections import OrderedDict
+from collections import OrderedDict, namedtuple
import functools
import itertools
@@ -8,6 +8,9 @@ from ..parameter import Parameter
import torch.utils.hooks as hooks
+_IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])
+
+
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
@@ -734,6 +737,11 @@ class Module(object):
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+
+ Returns:
+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+ * **missing_keys** is a list of str containing the missing keys
+ * **unexpected_keys** is a list of str containing the unexpected keys
"""
missing_keys = []
unexpected_keys = []
@@ -748,7 +756,7 @@ class Module(object):
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
- state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
@@ -756,7 +764,6 @@ class Module(object):
load(self)
if strict:
- error_msg = ''
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
@@ -769,6 +776,7 @@ class Module(object):
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
def _named_members(self, get_members_fn, prefix='', recurse=True):
r"""Helper method for yielding various names + members of modules."""