diff options
-rw-r--r-- | docs/source/optim.rst | 4 | ||||
-rw-r--r-- | test/test_optim.py | 45 | ||||
-rw-r--r-- | torch/optim/lr_scheduler.py | 38 |
3 files changed, 69 insertions, 18 deletions
diff --git a/docs/source/optim.rst b/docs/source/optim.rst index 2125d043d1..f44f51a8b8 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -130,7 +130,7 @@ How to adjust Learning Rate --------------------------- :mod:`torch.optim.lr_scheduler` provides several methods to adjust the learning -rate based on the number of epoches. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` +rate based on the number of epochs. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` allows dynamic learning rate reducing based on some validation measurements. .. autoclass:: torch.optim.lr_scheduler.LambdaLR @@ -141,5 +141,7 @@ allows dynamic learning rate reducing based on some validation measurements. :members: .. autoclass:: torch.optim.lr_scheduler.ExponentialLR :members: +.. autoclass:: torch.optim.lr_scheduler.CosineAnnealingLR + :members: .. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau :members: diff --git a/test/test_optim.py b/test/test_optim.py index 65e295ae9d..4afa3ad3e2 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1,3 +1,4 @@ +import math import unittest import functools from copy import deepcopy @@ -8,7 +9,7 @@ import torch.nn.functional as F from torch.optim import SGD from torch.autograd import Variable from torch import sparse -from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau +from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau from common import TestCase, run_tests @@ -460,10 +461,10 @@ class TestLRScheduler(TestCase): # lr = 0.05 if epoch < 3 # lr = 0.005 if 30 <= epoch < 6 # lr = 0.0005 if epoch >= 9 + epochs = 10 single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 - targets = [single_targets, list(map(lambda x: x * 10, single_targets))] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] scheduler = StepLR(self.opt, gamma=0.1, step_size=3) - epochs = 10 self._test(scheduler, targets, epochs) def test_multi_step_lr(self): @@ -471,106 +472,116 @@ class TestLRScheduler(TestCase): # lr = 0.005 if 2 <= epoch < 5 # lr = 0.0005 if epoch < 9 # lr = 0.00005 if epoch >= 9 + epochs = 10 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 - targets = [single_targets, list(map(lambda x: x * 10, single_targets))] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) - epochs = 10 self._test(scheduler, targets, epochs) def test_exp_lr(self): - single_targets = [0.05 * (0.9 ** x) for x in range(10)] - targets = [single_targets, list(map(lambda x: x * 10, single_targets))] + epochs = 10 + single_targets = [0.05 * (0.9 ** x) for x in range(epochs)] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] scheduler = ExponentialLR(self.opt, gamma=0.9) + self._test(scheduler, targets, epochs) + + def test_cos_anneal_lr(self): epochs = 10 + eta_min = 1e-10 + single_targets = [eta_min + (0.05 - eta_min) * + (1 + math.cos(x / epochs * math.pi)) / 2 + for x in range(epochs)] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) self._test(scheduler, targets, epochs) def test_reduce_lr_on_plateau1(self): + epochs = 10 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * 20] metrics = [10 - i * 0.0167 for i in range(20)] scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min', threshold=0.01, patience=5, cooldown=5) - epochs = 10 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau2(self): + epochs = 22 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2] metrics = [10 - i * 0.0165 for i in range(22)] scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', mode='min', threshold=0.1) - epochs = 22 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau3(self): + epochs = 22 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4] metrics = [-0.8] * 2 + [-0.234] * 20 scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5, threshold_mode='abs') - epochs = 22 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau4(self): + epochs = 20 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * 20] metrics = [1.5 * (1.025 ** i) for i in range(20)] # 1.025 > 1.1**0.25 scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3, threshold_mode='rel', threshold=0.1) - epochs = 20 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau5(self): + epochs = 20 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] metrics = [1.5 * (1.005 ** i) for i in range(20)] scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', threshold=0.1, patience=5, cooldown=5) - epochs = 20 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau6(self): + epochs = 20 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * 20] metrics = [1.5 * (0.85 ** i) for i in range(20)] scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', threshold=0.1) - epochs = 20 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau7(self): + epochs = 20 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] metrics = [1] * 7 + [0.6] + [0.5] * 12 scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', threshold=0.1, patience=5, cooldown=5) - epochs = 20 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau8(self): + epochs = 20 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14] metrics = [1.5 * (1.005 ** i) for i in range(20)] scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3], threshold=0.1, patience=5, cooldown=5) - epochs = 20 self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_lambda_lr(self): + epochs = 10 self.opt.param_groups[0]['lr'] = 0.05 self.opt.param_groups[1]['lr'] = 0.4 - targets = [[0.05 * (0.9 ** x) for x in range(10)], [0.4 * (0.8 ** x) for x in range(10)]] + targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]] scheduler = LambdaLR(self.opt, lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2]) - epochs = 10 self._test(scheduler, targets, epochs) def _test(self, scheduler, targets, epochs=10): diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index daa0cb2d21..0df5b53283 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,3 +1,4 @@ +import math from bisect import bisect_right from .optimizer import Optimizer @@ -160,6 +161,43 @@ class ExponentialLR(_LRScheduler): for base_lr in self.base_lrs] +class CosineAnnealingLR(_LRScheduler): + """Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + + \cos(\frac{T_{cur}}{T_{max}}\pi)) + + When last_epoch=-1, sets initial lr as lr. + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1): + self.T_max = T_max + self.eta_min = eta_min + super(CosineAnnealingLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + return [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos(self.last_epoch / self.T_max * math.pi)) / 2 + for base_lr in self.base_lrs] + + class ReduceLROnPlateau(object): """Reduce learning rate when a metric has stopped improving. Models often benefit from reducing the learning rate by a factor |