-rw-r--r--  docs/source/optim.rst         2
-rw-r--r--  test/test_optim.py          178
-rw-r--r--  torch/optim/lr_scheduler.py 214
3 files changed, 393 insertions(+), 1 deletion(-)
diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index 9ba6c395b9..db8af0f384 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -145,3 +145,5 @@ allows dynamic learning rate reducing based on some validation measurements.
:members:
.. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau
:members:
+.. autoclass:: torch.optim.lr_scheduler.CyclicLR
+ :members:
diff --git a/test/test_optim.py b/test/test_optim.py
index 82fee18619..c6c51e4b1e 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -11,7 +11,8 @@ from torch.optim import SGD
from torch.autograd import Variable
from torch import sparse
from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
- ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler
+ ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
+ CyclicLR
from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests
# load_tests from common_utils is used to automatically filter tests for
@@ -790,6 +791,165 @@ class TestLRScheduler(TestCase):
schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
+ def test_cycle_lr_invalid_mode(self):
+ with self.assertRaises(ValueError):
+ scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS")
+
+ def test_cycle_lr_triangular_mode_one_lr(self):
+ lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
+ momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
+ lr_targets = [lr_target, lr_target]
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
+ cycle_momentum=True, base_momentum=1, max_momentum=5,
+ mode='triangular')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+ def test_cycle_lr_triangular_mode_one_lr_no_momentum(self):
+ lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
+ lr_targets = [lr_target, lr_target]
+ momentum_target = [self.opt.defaults['momentum']] * len(lr_target)
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
+ cycle_momentum=False, mode='triangular')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+ def test_cycle_lr_triangular2_mode_one_lr(self):
+ lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5,
+ 1, 1.25, 1.50, 1.75, 2.00, 1.75]
+ momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0,
+ 3.5, 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25]
+ lr_targets = [lr_target, lr_target]
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
+ cycle_momentum=True, base_momentum=1, max_momentum=5,
+ mode='triangular2')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+ def test_cycle_lr_exp_range_mode_one_lr(self):
+ base_lr, max_lr = 1, 5
+ diff_lr = max_lr - base_lr
+ gamma = 0.9
+ xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
+ lr_target = list(map(lambda x: base_lr + x[1] * diff_lr * gamma**x[0], enumerate(xs)))
+ momentum_target = list(map(lambda x: max_lr - x[1] * diff_lr * gamma**x[0], enumerate(xs)))
+ lr_targets = [lr_target, lr_target]
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = CyclicLR(self.opt, base_lr=base_lr,
+ max_lr=max_lr, step_size_up=4,
+ cycle_momentum=True, base_momentum=base_lr, max_momentum=max_lr,
+ mode='exp_range', gamma=gamma)
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+ def test_cycle_lr_triangular_mode(self):
+ lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
+ lr_target_2 = list(map(lambda x: x + 1, lr_target_1))
+ lr_targets = [lr_target_1, lr_target_2]
+ momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
+ momentum_target_2 = list(map(lambda x: x + 1, momentum_target_1))
+ momentum_targets = [momentum_target_1, momentum_target_2]
+ scheduler = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4,
+ cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6],
+ mode='triangular')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
+
+ def test_cycle_lr_triangular2_mode(self):
+ lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1,
+ 1.25, 1.50, 1.75, 2.00, 1.75]
+ lr_target_2 = list(map(lambda x: x + 2, lr_target_1))
+ lr_targets = [lr_target_1, lr_target_2]
+ momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5,
+ 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25]
+ momentum_target_2 = list(map(lambda x: x + 2, momentum_target_1))
+ momentum_targets = [momentum_target_1, momentum_target_2]
+ scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4,
+ cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7],
+ mode='triangular2')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
+
+ def test_cycle_lr_exp_range_mode(self):
+ base_lr_1, max_lr_1 = 1, 5
+ base_lr_2, max_lr_2 = 5, 12
+
+ diff_lr_1 = max_lr_1 - base_lr_1
+ diff_lr_2 = max_lr_2 - base_lr_2
+
+ gamma = 0.9
+ xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
+ lr_target_1 = list(map(lambda x: base_lr_1 + x[1] * diff_lr_1 * gamma**x[0], enumerate(xs)))
+ lr_target_2 = list(map(lambda x: base_lr_2 + x[1] * diff_lr_2 * gamma**x[0], enumerate(xs)))
+ lr_targets = [lr_target_1, lr_target_2]
+ momentum_target_1 = list(map(lambda x: max_lr_1 - x[1] * diff_lr_1 * gamma**x[0], enumerate(xs)))
+ momentum_target_2 = list(map(lambda x: max_lr_2 - x[1] * diff_lr_2 * gamma**x[0], enumerate(xs)))
+ momentum_targets = [momentum_target_1, momentum_target_2]
+ scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2],
+ max_lr=[max_lr_1, max_lr_2], step_size_up=4,
+ cycle_momentum=True, base_momentum=[base_lr_1, base_lr_2],
+ max_momentum=[max_lr_1, max_lr_2],
+ mode='exp_range', gamma=gamma)
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
+
+ def test_cycle_lr_triangular_mode_step_size_up_down(self):
+ lr_target = [1.0, 2.0, 3.0, 4.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0]
+ lr_targets = [lr_target, lr_target]
+ momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0]
+ momentum_targets = [momentum_target, momentum_target]
+
+ scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5,
+ step_size_up=4,
+ step_size_down=6,
+ cycle_momentum=True,
+ base_momentum=1, max_momentum=5,
+ mode='triangular')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+ def test_cycle_lr_triangular2_mode_step_size_up_down(self):
+ lr_base_target = ([
+ 1.0, 3.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0, 2.0, 3.0, 8.0 / 3,
+ 7.0 / 3, 6.0 / 3, 5.0 / 3, 4.0 / 3, 1.0, 3.0 / 2, 2.0, 11.0 / 6, 10.0 / 6, 9.0 / 6,
+ 8.0 / 6, 7.0 / 6
+ ])
+ momentum_base_target = ([
+ 5.0, 3.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0, 4.0, 3.0, 10.0 / 3,
+ 11.0 / 3, 4.0, 13.0 / 3, 14.0 / 3, 5.0, 4.5, 4.0, 25.0 / 6, 13.0 / 3, 4.5, 14.0 / 3,
+ 29.0 / 6
+ ])
+ deltas = [2 * i for i in range(0, 2)]
+ base_lrs = [1 + delta for delta in deltas]
+ max_lrs = [5 + delta for delta in deltas]
+ lr_targets = [[x + delta for x in lr_base_target] for delta in deltas]
+ momentum_targets = [[x + delta for x in momentum_base_target] for delta in deltas]
+ scheduler = CyclicLR(
+ self.opt,
+ base_lr=base_lrs,
+ max_lr=max_lrs,
+ step_size_up=2,
+ step_size_down=6,
+ cycle_momentum=True,
+ base_momentum=base_lrs,
+ max_momentum=max_lrs,
+ mode='triangular2')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_base_target))
+
+ def test_cycle_lr_exp_range_mode_step_size_up_down(self):
+ base_lr, max_lr = 1, 5
+ diff_lr = max_lr - base_lr
+ gamma = 0.9
+ xs = ([
+ 0.0, 0.5, 1.0, 5.0 / 6, 4.0 / 6, 3.0 / 6, 2.0 / 6, 1.0 / 6, 0.0, 0.5, 1.0, 5.0 / 6,
+ 4.0 / 6
+ ])
+ lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
+ lr_targets = [lr_target, lr_target]
+ momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = CyclicLR(self.opt, base_lr=base_lr, max_lr=max_lr,
+ step_size_up=2, step_size_down=6,
+ cycle_momentum=True, base_momentum=base_lr,
+ max_momentum=max_lr,
+ mode='exp_range', gamma=gamma)
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
def test_lambda_lr(self):
epochs = 10
self.opt.param_groups[0]['lr'] = 0.05
@@ -905,5 +1065,21 @@ class TestLRScheduler(TestCase):
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
epoch, target[epoch], param_group['lr']), delta=1e-5)
+ def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False):
+ for batch_num in range(batch_iterations):
+ scheduler.step(batch_num)
+ if verbose:
+ print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'],
+ self.opt.param_groups[0]['momentum']))
+ for param_group, lr_target, momentum_target in zip(self.opt.param_groups, lr_targets, momentum_targets):
+ self.assertAlmostEqual(
+ lr_target[batch_num], param_group['lr'],
+ msg='LR is wrong in batch_num {}: expected {}, got {}'.format(
+ batch_num, lr_target[batch_num], param_group['lr']), delta=1e-5)
+ self.assertAlmostEqual(
+ momentum_target[batch_num], param_group['momentum'],
+ msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format(
+ batch_num, momentum_target[batch_num], param_group['momentum']), delta=1e-5)
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 200e2c6ecf..36507940e9 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -4,6 +4,7 @@ import torch
from torch._six import inf
from collections import Counter
from functools import partial
+
from .optimizer import Optimizer
@@ -427,3 +428,216 @@ class ReduceLROnPlateau(object):
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
+
+
+class CyclicLR(_LRScheduler):
+ """Sets the learning rate of each parameter group according to
+ the cyclical learning rate policy (CLR). The policy cycles the learning
+ rate between two boundaries with a constant frequency, as detailed in
+ the paper `Cyclical Learning Rates for Training Neural Networks`_.
+ The distance between the two boundaries can be scaled on a per-iteration
+ or per-cycle basis.
+
+ Cyclical learning rate policy changes the learning rate after every batch.
+ `step` should be called after a batch has been used for training.
+
+ This class has three built-in policies, as put forth in the paper:
+ "triangular":
+ A basic triangular cycle w/ no amplitude scaling.
+ "triangular2":
+ A basic triangular cycle that scales initial amplitude by half each cycle.
+ "exp_range":
+ A cycle that scales initial amplitude by gamma**(cycle iterations) at each
+ cycle iteration.
+
+ This implementation was adapted from the GitHub repo `bckenstler/CLR`_.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ base_lr (float or list): Initial learning rate which is the
+ lower boundary in the cycle for each parameter group.
+ max_lr (float or list): Upper learning rate boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (max_lr - base_lr).
+ The lr at any cycle is the sum of base_lr
+ and some scaling of the amplitude; therefore
+ max_lr may not actually be reached depending on
+ scaling function.
+ step_size_up (int): Number of training iterations in the
+ increasing half of a cycle. Default: 2000
+ step_size_down (int): Number of training iterations in the
+ decreasing half of a cycle. If step_size_down is None,
+ it is set to step_size_up. Default: None
+ mode (str): One of {triangular, triangular2, exp_range}.
+ Values correspond to policies detailed above.
+ If scale_fn is not None, this argument is ignored.
+ Default: 'triangular'
+ gamma (float): Constant in 'exp_range' scaling function:
+ gamma**(cycle iterations)
+ Default: 1.0
+ scale_fn (function): Custom scaling policy defined by a single
+ argument lambda function, where
+ 0 <= scale_fn(x) <= 1 for all x >= 0.
+ If specified, then 'mode' is ignored.
+ Default: None
+ scale_mode (str): {'cycle', 'iterations'}.
+ Defines whether scale_fn is evaluated on
+ cycle number or cycle iterations (training
+ iterations since start of cycle).
+ Default: 'cycle'
+ cycle_momentum (bool): If ``True``, momentum is cycled inversely
+ to learning rate between 'base_momentum' and 'max_momentum'.
+ Default: True
+ base_momentum (float or list): Initial momentum which is the
+ lower boundary in the cycle for each parameter group.
+ Default: 0.8
+ max_momentum (float or list): Upper momentum boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (max_momentum - base_momentum).
+ The momentum at any cycle is the difference of max_momentum
+ and some scaling of the amplitude; therefore
+ base_momentum may not actually be reached depending on
+ scaling function. Default: 0.9
+ last_epoch (int): The index of the last batch. This parameter is used when
+ resuming a training job. Since `step()` should be invoked after each
+ batch instead of after each epoch, this number represents the total
+ number of *batches* computed, not the total number of epochs computed.
+ When last_epoch=-1, the schedule is started from the beginning.
+ Default: -1
+
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> for epoch in range(10):
+ >>> for batch in data_loader:
+ >>> train_batch(...)
+ >>> scheduler.step()
+
+
+ .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
+ .. _bckenstler/CLR: https://github.com/bckenstler/CLR
+ """
+
+ def __init__(self,
+ optimizer,
+ base_lr,
+ max_lr,
+ step_size_up=2000,
+ step_size_down=None,
+ mode='triangular',
+ gamma=1.,
+ scale_fn=None,
+ scale_mode='cycle',
+ cycle_momentum=True,
+ base_momentum=0.8,
+ max_momentum=0.9,
+ last_epoch=-1):
+
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError('{} is not an Optimizer'.format(
+ type(optimizer).__name__))
+ self.optimizer = optimizer
+
+ base_lrs = self._format_param('base_lr', optimizer, base_lr)
+ if last_epoch == -1:
+ for lr, group in zip(base_lrs, optimizer.param_groups):
+ group['lr'] = lr
+
+ self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
+
+ step_size_up = float(step_size_up)
+ step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
+ self.total_size = step_size_up + step_size_down
+ self.step_ratio = step_size_up / self.total_size
+
+ if mode not in ['triangular', 'triangular2', 'exp_range'] \
+ and scale_fn is None:
+ raise ValueError('mode is invalid and scale_fn is None')
+
+ self.mode = mode
+ self.gamma = gamma
+
+ if scale_fn is None:
+ if self.mode == 'triangular':
+ self.scale_fn = self._triangular_scale_fn
+ self.scale_mode = 'cycle'
+ elif self.mode == 'triangular2':
+ self.scale_fn = self._triangular2_scale_fn
+ self.scale_mode = 'cycle'
+ elif self.mode == 'exp_range':
+ self.scale_fn = self._exp_range_scale_fn
+ self.scale_mode = 'iterations'
+ else:
+ self.scale_fn = scale_fn
+ self.scale_mode = scale_mode
+
+ self.cycle_momentum = cycle_momentum
+ if cycle_momentum:
+ if 'momentum' not in optimizer.defaults:
+ raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+
+ base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+ if last_epoch == -1:
+ for momentum, group in zip(base_momentums, optimizer.param_groups):
+ group['momentum'] = momentum
+ self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
+ self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+
+ super(CyclicLR, self).__init__(optimizer, last_epoch)
+
+ def _format_param(self, name, optimizer, param):
+ """Return correctly formatted lr/momentum for each param group."""
+ if isinstance(param, (list, tuple)):
+ if len(param) != len(optimizer.param_groups):
+ raise ValueError("expected {} values for {}, got {}".format(
+ len(optimizer.param_groups), name, len(param)))
+ return param
+ else:
+ return [param] * len(optimizer.param_groups)
+
+ def _triangular_scale_fn(self, x):
+ return 1.
+
+ def _triangular2_scale_fn(self, x):
+ return 1 / (2. ** (x - 1))
+
+ def _exp_range_scale_fn(self, x):
+ return self.gamma**(x)
+
+ def get_lr(self):
+ """Calculates the learning rate at batch index. This function treats
+ `self.last_epoch` as the last batch index.
+
+ If `self.cycle_momentum` is ``True``, this function has a side effect of
+ updating the optimizer's momentum.
+ """
+ cycle = math.floor(1 + self.last_epoch / self.total_size)
+ x = 1. + self.last_epoch / self.total_size - cycle
+ if x <= self.step_ratio:
+ scale_factor = x / self.step_ratio
+ else:
+ scale_factor = (x - 1) / (self.step_ratio - 1)
+
+ lrs = []
+ for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
+ base_height = (max_lr - base_lr) * scale_factor
+ if self.scale_mode == 'cycle':
+ lr = base_lr + base_height * self.scale_fn(cycle)
+ else:
+ lr = base_lr + base_height * self.scale_fn(self.last_epoch)
+ lrs.append(lr)
+
+ if self.cycle_momentum:
+ momentums = []
+ for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
+ base_height = (max_momentum - base_momentum) * scale_factor
+ if self.scale_mode == 'cycle':
+ momentum = max_momentum - base_height * self.scale_fn(cycle)
+ else:
+ momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
+ momentums.append(momentum)
+ for param_group, momentum in zip(self.optimizer.param_groups, momentums):
+ param_group['momentum'] = momentum
+
+ return lrs
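
A minimal usage sketch of the CyclicLR scheduler added by this patch (illustrative only, not part of the diff). The model, loss, and data below are placeholders; the scheduler arguments follow the signature introduced in torch/optim/lr_scheduler.py above.

    # Hypothetical training loop; assumes this patch is applied so that
    # CyclicLR is importable from torch.optim.lr_scheduler.
    import torch
    import torch.nn as nn
    from torch.optim.lr_scheduler import CyclicLR

    model = nn.Linear(10, 2)                     # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Cycle the LR between 0.01 and 0.1, rising for 2000 iterations and
    # falling for 2000 (step_size_down defaults to step_size_up); momentum
    # is cycled inversely between 0.8 and 0.9.
    scheduler = CyclicLR(optimizer, base_lr=0.01, max_lr=0.1,
                         step_size_up=2000, mode='triangular2',
                         cycle_momentum=True, base_momentum=0.8, max_momentum=0.9)

    criterion = nn.MSELoss()
    for epoch in range(2):
        for _ in range(10):                      # stand-in for a DataLoader loop
            inputs = torch.randn(4, 10)
            targets = torch.randn(4, 2)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            scheduler.step()                     # step once per batch, not per epoch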
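
The get_lr() arithmetic can also be checked in isolation: for the default 'triangular' mode the scale function is a constant 1.0, so the LR traces a pure triangle wave between base_lr and max_lr. The following standalone sketch (not part of the diff) reproduces the lr_target sequence from test_cycle_lr_triangular_mode_one_lr above.

    import math

    # base_lr=1, max_lr=5, step_size_up=4, as in the test above.
    base_lr, max_lr = 1.0, 5.0
    step_size_up = step_size_down = 4.0
    total_size = step_size_up + step_size_down          # iterations per full cycle
    step_ratio = step_size_up / total_size              # fraction of the cycle spent rising

    lrs = []
    for last_epoch in range(11):                        # batch index, as in the patch
        cycle = math.floor(1 + last_epoch / total_size)     # 1-based cycle count
        x = 1. + last_epoch / total_size - cycle             # position within the cycle, in [0, 1)
        if x <= step_ratio:
            scale_factor = x / step_ratio               # rising half: 0 -> 1
        else:
            scale_factor = (x - 1) / (step_ratio - 1)   # falling half: 1 -> 0
        lrs.append(base_lr + (max_lr - base_lr) * scale_factor)

    print(lrs)  # [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0]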