diff options
-rw-r--r-- | docs/source/optim.rst | 2 | ||||
-rw-r--r-- | test/test_optim.py | 178 | ||||
-rw-r--r-- | torch/optim/lr_scheduler.py | 214 |
3 files changed, 393 insertions, 1 deletions
diff --git a/docs/source/optim.rst b/docs/source/optim.rst index 9ba6c395b9..db8af0f384 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -145,3 +145,5 @@ allows dynamic learning rate reducing based on some validation measurements. :members: .. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau :members: +.. autoclass:: torch.optim.lr_scheduler.CyclicLR + :members: diff --git a/test/test_optim.py b/test/test_optim.py index 82fee18619..c6c51e4b1e 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -11,7 +11,8 @@ from torch.optim import SGD from torch.autograd import Variable from torch import sparse from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \ - ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler + ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \ + CyclicLR from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests # load_tests from common_utils is used to automatically filter tests for @@ -790,6 +791,165 @@ class TestLRScheduler(TestCase): schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) + def test_cycle_lr_invalid_mode(self): + with self.assertRaises(ValueError): + scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS") + + def test_cycle_lr_triangular_mode_one_lr(self): + lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] + momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] + lr_targets = [lr_target, lr_target] + momentum_targets = [momentum_target, momentum_target] + scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, + cycle_momentum=True, base_momentum=1, max_momentum=5, + mode='triangular') + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) + + def test_cycle_lr_triangular_mode_one_lr_no_momentum(self): + lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] + lr_targets = [lr_target, lr_target] + momentum_target = [self.opt.defaults['momentum']] * len(lr_target) + momentum_targets = [momentum_target, momentum_target] + scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, + cycle_momentum=False, mode='triangular') + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) + + def test_cycle_lr_triangular2_mode_one_lr(self): + lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, + 1, 1.25, 1.50, 1.75, 2.00, 1.75] + momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, + 3.5, 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25] + lr_targets = [lr_target, lr_target] + momentum_targets = [momentum_target, momentum_target] + scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, + cycle_momentum=True, base_momentum=1, max_momentum=5, + mode='triangular2') + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) + + def test_cycle_lr_exp_range_mode_one_lr(self): + base_lr, max_lr = 1, 5 + diff_lr = max_lr - base_lr + gamma = 0.9 + xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] + lr_target = list(map(lambda x: base_lr + x[1] * diff_lr * gamma**x[0], enumerate(xs))) + momentum_target = list(map(lambda x: max_lr - x[1] * diff_lr * gamma**x[0], enumerate(xs))) + lr_targets = [lr_target, lr_target] + momentum_targets = [momentum_target, momentum_target] + scheduler = CyclicLR(self.opt, base_lr=base_lr, + max_lr=max_lr, step_size_up=4, + cycle_momentum=True, base_momentum=base_lr, max_momentum=max_lr, + mode='exp_range', gamma=gamma) + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) + + def test_cycle_lr_triangular_mode(self): + lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] + lr_target_2 = list(map(lambda x: x + 1, lr_target_1)) + lr_targets = [lr_target_1, lr_target_2] + momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] + momentum_target_2 = list(map(lambda x: x + 1, momentum_target_1)) + momentum_targets = [momentum_target_1, momentum_target_2] + scheduler = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4, + cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6], + mode='triangular') + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) + + def test_cycle_lr_triangular2_mode(self): + lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1, + 1.25, 1.50, 1.75, 2.00, 1.75] + lr_target_2 = list(map(lambda x: x + 2, lr_target_1)) + lr_targets = [lr_target_1, lr_target_2] + momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5, + 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25] + momentum_target_2 = list(map(lambda x: x + 2, momentum_target_1)) + momentum_targets = [momentum_target_1, momentum_target_2] + scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4, + cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7], + mode='triangular2') + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) + + def test_cycle_lr_exp_range_mode(self): + base_lr_1, max_lr_1 = 1, 5 + base_lr_2, max_lr_2 = 5, 12 + + diff_lr_1 = max_lr_1 - base_lr_1 + diff_lr_2 = max_lr_2 - base_lr_2 + + gamma = 0.9 + xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] + lr_target_1 = list(map(lambda x: base_lr_1 + x[1] * diff_lr_1 * gamma**x[0], enumerate(xs))) + lr_target_2 = list(map(lambda x: base_lr_2 + x[1] * diff_lr_2 * gamma**x[0], enumerate(xs))) + lr_targets = [lr_target_1, lr_target_2] + momentum_target_1 = list(map(lambda x: max_lr_1 - x[1] * diff_lr_1 * gamma**x[0], enumerate(xs))) + momentum_target_2 = list(map(lambda x: max_lr_2 - x[1] * diff_lr_2 * gamma**x[0], enumerate(xs))) + momentum_targets = [momentum_target_1, momentum_target_2] + scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2], + max_lr=[max_lr_1, max_lr_2], step_size_up=4, + cycle_momentum=True, base_momentum=[base_lr_1, base_lr_2], + max_momentum=[max_lr_1, max_lr_2], + mode='exp_range', gamma=gamma) + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) + + def test_cycle_lr_triangular_mode_step_size_up_down(self): + lr_target = [1.0, 2.0, 3.0, 4.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0] + lr_targets = [lr_target, lr_target] + momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0] + momentum_targets = [momentum_target, momentum_target] + + scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, + step_size_up=4, + step_size_down=6, + cycle_momentum=True, + base_momentum=1, max_momentum=5, + mode='triangular') + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) + + def test_cycle_lr_triangular2_mode_step_size_up_down(self): + lr_base_target = ([ + 1.0, 3.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0, 2.0, 3.0, 8.0 / 3, + 7.0 / 3, 6.0 / 3, 5.0 / 3, 4.0 / 3, 1.0, 3.0 / 2, 2.0, 11.0 / 6, 10.0 / 6, 9.0 / 6, + 8.0 / 6, 7.0 / 6 + ]) + momentum_base_target = ([ + 5.0, 3.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0, 4.0, 3.0, 10.0 / 3, + 11.0 / 3, 4.0, 13.0 / 3, 14.0 / 3, 5.0, 4.5, 4.0, 25.0 / 6, 13.0 / 3, 4.5, 14.0 / 3, + 29.0 / 6 + ]) + deltas = [2 * i for i in range(0, 2)] + base_lrs = [1 + delta for delta in deltas] + max_lrs = [5 + delta for delta in deltas] + lr_targets = [[x + delta for x in lr_base_target] for delta in deltas] + momentum_targets = [[x + delta for x in momentum_base_target] for delta in deltas] + scheduler = CyclicLR( + self.opt, + base_lr=base_lrs, + max_lr=max_lrs, + step_size_up=2, + step_size_down=6, + cycle_momentum=True, + base_momentum=base_lrs, + max_momentum=max_lrs, + mode='triangular2') + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_base_target)) + + def test_cycle_lr_exp_range_mode_step_size_up_down(self): + base_lr, max_lr = 1, 5 + diff_lr = max_lr - base_lr + gamma = 0.9 + xs = ([ + 0.0, 0.5, 1.0, 5.0 / 6, 4.0 / 6, 3.0 / 6, 2.0 / 6, 1.0 / 6, 0.0, 0.5, 1.0, 5.0 / 6, + 4.0 / 6 + ]) + lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] + lr_targets = [lr_target, lr_target] + momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] + momentum_targets = [momentum_target, momentum_target] + scheduler = CyclicLR(self.opt, base_lr=base_lr, max_lr=max_lr, + step_size_up=2, step_size_down=6, + cycle_momentum=True, base_momentum=base_lr, + max_momentum=max_lr, + mode='exp_range', gamma=gamma) + self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) + def test_lambda_lr(self): epochs = 10 self.opt.param_groups[0]['lr'] = 0.05 @@ -905,5 +1065,21 @@ class TestLRScheduler(TestCase): msg='LR is wrong in epoch {}: expected {}, got {}'.format( epoch, target[epoch], param_group['lr']), delta=1e-5) + def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False): + for batch_num in range(batch_iterations): + scheduler.step(batch_num) + if verbose: + print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'], + self.opt.param_groups[0]['momentum'])) + for param_group, lr_target, momentum_target in zip(self.opt.param_groups, lr_targets, momentum_targets): + self.assertAlmostEqual( + lr_target[batch_num], param_group['lr'], + msg='LR is wrong in batch_num {}: expected {}, got {}'.format( + batch_num, lr_target[batch_num], param_group['lr']), delta=1e-5) + self.assertAlmostEqual( + momentum_target[batch_num], param_group['momentum'], + msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format( + batch_num, momentum_target[batch_num], param_group['momentum']), delta=1e-5) + if __name__ == '__main__': run_tests() diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 200e2c6ecf..36507940e9 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -4,6 +4,7 @@ import torch from torch._six import inf from collections import Counter from functools import partial + from .optimizer import Optimizer @@ -427,3 +428,216 @@ class ReduceLROnPlateau(object): def load_state_dict(self, state_dict): self.__dict__.update(state_dict) self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) + + +class CyclicLR(_LRScheduler): + """Sets the learning rate of each parameter group according to + cyclical learning rate policy (CLR). The policy cycles the learning + rate between two boundaries with a constant frequency, as detailed in + the paper `Cyclical Learning Rates for Training Neural Networks`_. + The distance between the two boundaries can be scaled on a per-iteration + or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This class has three built-in policies, as put forth in the paper: + "triangular": + A basic triangular cycle w/ no amplitude scaling. + "triangular2": + A basic triangular cycle that scales initial amplitude by half each cycle. + "exp_range": + A cycle that scales initial amplitude by gamma**(cycle iterations) at each + cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_lr - base_lr). + The lr at any cycle is the sum of base_lr + and some scaling of the amplitude; therefore + max_lr may not actually be reached depending on + scaling function. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: + gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single + argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. + If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. + Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training + iterations since start of cycle). + Default: 'cycle' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Initial momentum which is the + lower boundary in the cycle for each parameter group. + Default: 0.8 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + The momentum at any cycle is the difference of max_momentum + and some scaling of the amplitude; therefore + base_momentum may not actually be reached depending on + scaling function. Default: 0.9 + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + + Example: + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.CyclicLR(optimizer) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. _bckenstler/CLR: https://github.com/bckenstler/CLR + """ + + def __init__(self, + optimizer, + base_lr, + max_lr, + step_size_up=2000, + step_size_down=None, + mode='triangular', + gamma=1., + scale_fn=None, + scale_mode='cycle', + cycle_momentum=True, + base_momentum=0.8, + max_momentum=0.9, + last_epoch=-1): + + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + base_lrs = self._format_param('base_lr', optimizer, base_lr) + if last_epoch == -1: + for lr, group in zip(base_lrs, optimizer.param_groups): + group['lr'] = lr + + self.max_lrs = self._format_param('max_lr', optimizer, max_lr) + + step_size_up = float(step_size_up) + step_size_down = float(step_size_down) if step_size_down is not None else step_size_up + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ['triangular', 'triangular2', 'exp_range'] \ + and scale_fn is None: + raise ValueError('mode is invalid and scale_fn is None') + + self.mode = mode + self.gamma = gamma + + if scale_fn is None: + if self.mode == 'triangular': + self.scale_fn = self._triangular_scale_fn + self.scale_mode = 'cycle' + elif self.mode == 'triangular2': + self.scale_fn = self._triangular2_scale_fn + self.scale_mode = 'cycle' + elif self.mode == 'exp_range': + self.scale_fn = self._exp_range_scale_fn + self.scale_mode = 'iterations' + else: + self.scale_fn = scale_fn + self.scale_mode = scale_mode + + self.cycle_momentum = cycle_momentum + if cycle_momentum: + if 'momentum' not in optimizer.defaults: + raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') + + base_momentums = self._format_param('base_momentum', optimizer, base_momentum) + if last_epoch == -1: + for momentum, group in zip(base_momentums, optimizer.param_groups): + group['momentum'] = momentum + self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups)) + self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) + + super(CyclicLR, self).__init__(optimizer, last_epoch) + + def _format_param(self, name, optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError("expected {} values for {}, got {}".format( + len(optimizer.param_groups), name, len(param))) + return param + else: + return [param] * len(optimizer.param_groups) + + def _triangular_scale_fn(self, x): + return 1. + + def _triangular2_scale_fn(self, x): + return 1 / (2. ** (x - 1)) + + def _exp_range_scale_fn(self, x): + return self.gamma**(x) + + def get_lr(self): + """Calculates the learning rate at batch index. This function treats + `self.last_epoch` as the last batch index. + + If `self.cycle_momentum` is ``True``, this function has a side effect of + updating the optimizer's momentum. + """ + cycle = math.floor(1 + self.last_epoch / self.total_size) + x = 1. + self.last_epoch / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + lrs = [] + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + base_height = (max_lr - base_lr) * scale_factor + if self.scale_mode == 'cycle': + lr = base_lr + base_height * self.scale_fn(cycle) + else: + lr = base_lr + base_height * self.scale_fn(self.last_epoch) + lrs.append(lr) + + if self.cycle_momentum: + momentums = [] + for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums): + base_height = (max_momentum - base_momentum) * scale_factor + if self.scale_mode == 'cycle': + momentum = max_momentum - base_height * self.scale_fn(cycle) + else: + momentum = max_momentum - base_height * self.scale_fn(self.last_epoch) + momentums.append(momentum) + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + param_group['momentum'] = momentum + + return lrs |