diff options
author | 0phoff <0phoff@users.noreply.github.com> | 2018-07-31 19:30:20 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-07-31 19:39:06 -0700 |
commit | 294c06538416f5882e9f329695dcae7ff77c218b (patch) | |
tree | fcb3e83a2673d736a14475ebdd32ac4fbfa953ed /torch/optim | |
parent | aae37324cc66b80644631bcc8cadcfe741cab625 (diff) | |
download | pytorch-294c06538416f5882e9f329695dcae7ff77c218b.tar.gz pytorch-294c06538416f5882e9f329695dcae7ff77c218b.tar.bz2 pytorch-294c06538416f5882e9f329695dcae7ff77c218b.zip |
Changed serialization mechanism of LambdaLR scheduler (#9927)
Summary:
I opened an issue explaining some of my frustrations with the current state of schedulers.
While most points that I raised in [that issue](https://github.com/pytorch/pytorch/issues/8741#issuecomment-404449697) need to be discussed more thoroughly before being implemented, there are some that are not so difficult to fix.
This PR changes the way the LambdaLR scheduler gets serialized:
> The lr_lambda functions are only saved if the are callable objects (which can be stateful).
> There is no point in saving functions/lambdas as you need their definition before unpickling and they are stateless.
This has the big advantage that the scheduler is serializable, even if you use lambda functions or locally defined functions (aka a function in a function).
Does this functionality need any unit tests?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9927
Differential Revision: D9055505
Pulled By: soumith
fbshipit-source-id: 6c1cec588beedd098ec7d2bce6a9add27f29e48f
Diffstat (limited to 'torch/optim')
-rw-r--r-- | torch/optim/lr_scheduler.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index ad7f780719..96cfaff868 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,3 +1,4 @@ +import types import math import torch from torch._six import inf @@ -86,6 +87,37 @@ class LambdaLR(_LRScheduler): self.last_epoch = last_epoch super(LambdaLR, self).__init__(optimizer, last_epoch) + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} + state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict['lr_lambdas'][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Arguments: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + lr_lambdas = state_dict.pop('lr_lambdas') + self.__dict__.update(state_dict) + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + def get_lr(self): return [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] |