author | Kai Arulkumaran <Kaixhin@users.noreply.github.com> | 2017-12-18 07:43:08 +0000 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-12-18 02:43:08 -0500 |
commit | e9ef20eab5e5cf361bdc7a425c7f8b873baad9d3 (patch) | |
tree | 8e4564b78f2e5880d40d21db453b59415808233f /torch/optim | |
parent | 847c56aeb5857fc4d3f5df88b9e8f937939bb8cc (diff) | |
Add Cosine Annealing LR Scheduler (#3311)
* Add Cosine Annealing LR Scheduler
* Update eta_min in tests to prevent numerical mistakes
* Use non-zero min_eta in test_cos_anneal_lr
Diffstat (limited to 'torch/optim')
-rw-r--r-- | torch/optim/lr_scheduler.py | 38 |
1 file changed, 38 insertions, 0 deletions
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index daa0cb2d21..0df5b53283 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1,3 +1,4 @@
+import math
 from bisect import bisect_right
 
 from .optimizer import Optimizer
@@ -160,6 +161,43 @@ class ExponentialLR(_LRScheduler):
                 for base_lr in self.base_lrs]
 
 
+class CosineAnnealingLR(_LRScheduler):
+    """Set the learning rate of each parameter group using a cosine annealing
+    schedule, where :math:`\eta_{max}` is set to the initial lr and
+    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+    .. math::
+
+        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
+        \cos(\frac{T_{cur}}{T_{max}}\pi))
+
+    When last_epoch=-1, sets initial lr as lr.
+
+    It has been proposed in
+    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+    implements the cosine annealing part of SGDR, and not the restarts.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        T_max (int): Maximum number of iterations.
+        eta_min (float): Minimum learning rate. Default: 0.
+        last_epoch (int): The index of last epoch. Default: -1.
+
+    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+        https://arxiv.org/abs/1608.03983
+    """
+
+    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+        self.T_max = T_max
+        self.eta_min = eta_min
+        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [self.eta_min + (base_lr - self.eta_min) *
+                (1 + math.cos(self.last_epoch / self.T_max * math.pi)) / 2
+                for base_lr in self.base_lrs]
+
+
 class ReduceLROnPlateau(object):
     """Reduce learning rate when a metric has stopped improving.
     Models often benefit from reducing the learning rate by a factor
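For readers who want to try the new scheduler, a minimal usage sketch follows (not part of the commit). The toy `torch.nn.Linear` model, the SGD learning rate, and the `T_max`/`eta_min` values are placeholder assumptions for illustration; the loop follows the convention of this version's `_LRScheduler`, which expects `scheduler.step()` once per epoch, and compares the resulting learning rate against the closed-form expression from the docstring.

```python
import math

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

# Illustrative values only; the model and hyper-parameters are placeholders.
eta_max, eta_min, T_max = 0.1, 0.001, 50
model = torch.nn.Linear(10, 2)
optimizer = SGD(model.parameters(), lr=eta_max)   # initial lr becomes eta_max
scheduler = CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)

for epoch in range(T_max + 1):
    # The _LRScheduler of this vintage is stepped at the start of each epoch.
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    # Closed-form value from the docstring formula, for comparison.
    expected = eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2
    print('epoch {:3d}  lr {:.6f}  expected {:.6f}'.format(epoch, lr, expected))
    # ... forward/backward pass and optimizer.step() would go here ...
```

Since (1 + cos(pi * T_cur / T_max)) / 2 falls smoothly from 1 at epoch 0 to 0 at epoch T_max, the learning rate anneals from the initial lr down to eta_min; as the docstring notes, the warm restarts from the SGDR paper are not implemented here.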