author    0phoff <0phoff@users.noreply.github.com>  2018-07-31 19:30:20 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-07-31 19:39:06 -0700
commit    294c06538416f5882e9f329695dcae7ff77c218b (patch)
tree      fcb3e83a2673d736a14475ebdd32ac4fbfa953ed /torch/optim
parent    aae37324cc66b80644631bcc8cadcfe741cab625 (diff)
Changed serialization mechanism of LambdaLR scheduler (#9927)
Summary:
I opened an issue explaining some of my frustrations with the current state of schedulers. While most points that I raised in [that issue](https://github.com/pytorch/pytorch/issues/8741#issuecomment-404449697) need to be discussed more thoroughly before being implemented, there are some that are not so difficult to fix.

This PR changes the way the LambdaLR scheduler gets serialized:
> The lr_lambda functions are only saved if they are callable objects (which can be stateful).
> There is no point in saving plain functions/lambdas: you need their definition before unpickling anyway, and they are stateless.

This has the big advantage that the scheduler is serializable, even if you use lambda functions or locally defined functions (i.e. a function defined inside another function). Does this functionality need any unit tests?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/9927
Differential Revision: D9055505
Pulled By: soumith
fbshipit-source-id: 6c1cec588beedd098ec7d2bce6a9add27f29e48f
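
To illustrate the new behaviour, here is a minimal sketch (not part of this PR). The WarmupLambda class, the toy model, the two parameter groups and the learning-rate values are made up for the example; the point is that a stateful callable has its __dict__ captured by state_dict(), while a plain lambda is skipped and stored as None:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR


class WarmupLambda(object):
    """A stateful callable: its attributes end up in the scheduler state."""

    def __init__(self, warmup_epochs):
        self.warmup_epochs = warmup_epochs

    def __call__(self, epoch):
        return min(1.0, float(epoch + 1) / self.warmup_epochs)


model = torch.nn.Linear(2, 2)
# Two param groups so that two lr_lambdas can be used, one per group.
optimizer = SGD([{'params': [model.weight], 'lr': 0.1},
                 {'params': [model.bias], 'lr': 0.1}])
scheduler = LambdaLR(optimizer, lr_lambda=[WarmupLambda(5), lambda epoch: 0.95 ** epoch])

state = scheduler.state_dict()
# state['lr_lambdas'] == [{'warmup_epochs': 5}, None]
# The callable object's __dict__ is saved; the plain lambda is left out,
# so the state dict can be pickled even though lambdas cannot.

# To restore, rebuild the scheduler with the same lambdas, then load the state.
new_scheduler = LambdaLR(optimizer, lr_lambda=[WarmupLambda(5), lambda epoch: 0.95 ** epoch])
new_scheduler.load_state_dict(state)
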
Diffstat (limited to 'torch/optim')
-rw-r--r--  torch/optim/lr_scheduler.py  32
1 file changed, 32 insertions(+), 0 deletions(-)
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index ad7f780719..96cfaff868 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1,3 +1,4 @@
+import types
 import math
 import torch
 from torch._six import inf
@@ -86,6 +87,37 @@ class LambdaLR(_LRScheduler):
         self.last_epoch = last_epoch
         super(LambdaLR, self).__init__(optimizer, last_epoch)
 
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        The learning rate lambda functions will only be saved if they are callable objects
+        and not if they are functions or lambdas.
+        """
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
+        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
+
+        for idx, fn in enumerate(self.lr_lambdas):
+            if not isinstance(fn, types.FunctionType):
+                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
+
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        """Loads the scheduler's state.
+
+        Arguments:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        lr_lambdas = state_dict.pop('lr_lambdas')
+        self.__dict__.update(state_dict)
+
+        for idx, fn in enumerate(lr_lambdas):
+            if fn is not None:
+                self.lr_lambdas[idx].__dict__.update(fn)
+
     def get_lr(self):
         return [base_lr * lmbda(self.last_epoch)
                 for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
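
Because plain lambdas are skipped rather than pickled, the scheduler state can be round-tripped through torch.save/torch.load even when the lr_lambda is defined locally. A minimal sketch of such a round trip, assuming the methods added above; the build_scheduler helper, the file name and the decay factor are only for illustration:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR


def build_scheduler():
    model = torch.nn.Linear(4, 4)
    optimizer = SGD(model.parameters(), lr=0.05)
    # A locally defined lambda: it is never pickled, so saving still works.
    return LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9 ** epoch)


scheduler = build_scheduler()
for _ in range(3):
    scheduler.step()

torch.save(scheduler.state_dict(), 'scheduler.pt')    # no PicklingError for the lambda

restored = build_scheduler()                          # recreate the same lambda in code
restored.load_state_dict(torch.load('scheduler.pt'))  # last_epoch etc. are restored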