summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/source/optim.rst4
-rw-r--r--test/test_optim.py45
-rw-r--r--torch/optim/lr_scheduler.py38
3 files changed, 69 insertions, 18 deletions
diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index 2125d043d1..f44f51a8b8 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -130,7 +130,7 @@ How to adjust Learning Rate
---------------------------
:mod:`torch.optim.lr_scheduler` provides several methods to adjust the learning
-rate based on the number of epoches. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`
+rate based on the number of epochs. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`
allows dynamic learning rate reducing based on some validation measurements.
.. autoclass:: torch.optim.lr_scheduler.LambdaLR
@@ -141,5 +141,7 @@ allows dynamic learning rate reducing based on some validation measurements.
:members:
.. autoclass:: torch.optim.lr_scheduler.ExponentialLR
:members:
+.. autoclass:: torch.optim.lr_scheduler.CosineAnnealingLR
+ :members:
.. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau
:members:
diff --git a/test/test_optim.py b/test/test_optim.py
index 65e295ae9d..4afa3ad3e2 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1,3 +1,4 @@
+import math
import unittest
import functools
from copy import deepcopy
@@ -8,7 +9,7 @@ import torch.nn.functional as F
from torch.optim import SGD
from torch.autograd import Variable
from torch import sparse
-from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from common import TestCase, run_tests
@@ -460,10 +461,10 @@ class TestLRScheduler(TestCase):
# lr = 0.05 if epoch < 3
# lr = 0.005 if 30 <= epoch < 6
# lr = 0.0005 if epoch >= 9
+ epochs = 10
single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
- targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+ targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
- epochs = 10
self._test(scheduler, targets, epochs)
def test_multi_step_lr(self):
@@ -471,106 +472,116 @@ class TestLRScheduler(TestCase):
# lr = 0.005 if 2 <= epoch < 5
# lr = 0.0005 if epoch < 9
# lr = 0.00005 if epoch >= 9
+ epochs = 10
single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
- targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+ targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
- epochs = 10
self._test(scheduler, targets, epochs)
def test_exp_lr(self):
- single_targets = [0.05 * (0.9 ** x) for x in range(10)]
- targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+ epochs = 10
+ single_targets = [0.05 * (0.9 ** x) for x in range(epochs)]
+ targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
scheduler = ExponentialLR(self.opt, gamma=0.9)
+ self._test(scheduler, targets, epochs)
+
+ def test_cos_anneal_lr(self):
epochs = 10
+ eta_min = 1e-10
+ single_targets = [eta_min + (0.05 - eta_min) *
+ (1 + math.cos(x / epochs * math.pi)) / 2
+ for x in range(epochs)]
+ targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+ scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
self._test(scheduler, targets, epochs)
def test_reduce_lr_on_plateau1(self):
+ epochs = 10
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 20]
metrics = [10 - i * 0.0167 for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min',
threshold=0.01, patience=5, cooldown=5)
- epochs = 10
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_reduce_lr_on_plateau2(self):
+ epochs = 22
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
metrics = [10 - i * 0.0165 for i in range(22)]
scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
mode='min', threshold=0.1)
- epochs = 22
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_reduce_lr_on_plateau3(self):
+ epochs = 22
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
metrics = [-0.8] * 2 + [-0.234] * 20
scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5,
threshold_mode='abs')
- epochs = 22
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_reduce_lr_on_plateau4(self):
+ epochs = 20
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 20]
metrics = [1.5 * (1.025 ** i) for i in range(20)] # 1.025 > 1.1**0.25
scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3,
threshold_mode='rel', threshold=0.1)
- epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_reduce_lr_on_plateau5(self):
+ epochs = 20
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
metrics = [1.5 * (1.005 ** i) for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel',
threshold=0.1, patience=5, cooldown=5)
- epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_reduce_lr_on_plateau6(self):
+ epochs = 20
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 20]
metrics = [1.5 * (0.85 ** i) for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
threshold=0.1)
- epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_reduce_lr_on_plateau7(self):
+ epochs = 20
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
metrics = [1] * 7 + [0.6] + [0.5] * 12
scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
threshold=0.1, patience=5, cooldown=5)
- epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_reduce_lr_on_plateau8(self):
+ epochs = 20
for param_group in self.opt.param_groups:
param_group['lr'] = 0.5
targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
metrics = [1.5 * (1.005 ** i) for i in range(20)]
scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3],
threshold=0.1, patience=5, cooldown=5)
- epochs = 20
self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
def test_lambda_lr(self):
+ epochs = 10
self.opt.param_groups[0]['lr'] = 0.05
self.opt.param_groups[1]['lr'] = 0.4
- targets = [[0.05 * (0.9 ** x) for x in range(10)], [0.4 * (0.8 ** x) for x in range(10)]]
+ targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]]
scheduler = LambdaLR(self.opt,
lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
- epochs = 10
self._test(scheduler, targets, epochs)
def _test(self, scheduler, targets, epochs=10):
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index daa0cb2d21..0df5b53283 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1,3 +1,4 @@
+import math
from bisect import bisect_right
from .optimizer import Optimizer
@@ -160,6 +161,43 @@ class ExponentialLR(_LRScheduler):
for base_lr in self.base_lrs]
+class CosineAnnealingLR(_LRScheduler):
+ """Set the learning rate of each parameter group using a cosine annealing
+ schedule, where :math:`\eta_{max}` is set to the initial lr and
+ :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+ .. math::
+
+ \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
+ \cos(\frac{T_{cur}}{T_{max}}\pi))
+
+ When last_epoch=-1, sets initial lr as lr.
+
+ It has been proposed in
+ `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+ implements the cosine annealing part of SGDR, and not the restarts.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ T_max (int): Maximum number of iterations.
+ eta_min (float): Minimum learning rate. Default: 0.
+ last_epoch (int): The index of last epoch. Default: -1.
+
+ .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+ https://arxiv.org/abs/1608.03983
+ """
+
+ def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+ self.T_max = T_max
+ self.eta_min = eta_min
+ super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ return [self.eta_min + (base_lr - self.eta_min) *
+ (1 + math.cos(self.last_epoch / self.T_max * math.pi)) / 2
+ for base_lr in self.base_lrs]
+
+
class ReduceLROnPlateau(object):
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate by a factor