path: root/torch/optim
author     Sam Pepose <sampepose@fb.com>  2019-03-27 19:47:43 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-03-27 19:56:04 -0700
commit     8635078d9ed47266e6df89ad5ef887b598a76f25 (patch)
tree       78704f54331627703849b5342182d22df217d77a /torch/optim
parent     54abfda12434692ea5902ce5f94062fdac7fde61 (diff)
Adds Cyclical Learning Rate and Momentum (#18001)
Summary: This implements a cyclical learning rate (CLR) schedule with an optional inverse cyclical momentum. More info about CLR: https://github.com/bckenstler/CLR

This is finishing what #2016 started. Resolves #1909.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18001

Differential Revision: D14451845

Pulled By: sampepose

fbshipit-source-id: 8f682e0c3dee3a73bd2b14cc93fcf5f0e836b8c9
Diffstat (limited to 'torch/optim')
-rw-r--r--  torch/optim/lr_scheduler.py  214
1 file changed, 214 insertions, 0 deletions
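
For readers who want to try the new scheduler, here is a minimal usage sketch. It is only an illustration: the placeholder model, the data-free loop, and the chosen base_lr/max_lr/step_size_up values are not part of this commit.

    import torch
    from torch import nn, optim

    model = nn.Linear(10, 2)  # placeholder model
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Cycle the learning rate between 0.001 and 0.01, spending 2000 iterations on the
    # rising half of each cycle; momentum is cycled inversely between 0.8 and 0.9.
    scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.001, max_lr=0.01,
                                            step_size_up=2000)

    for batch_idx in range(10):
        # forward/backward on a real batch would go here
        optimizer.step()
        scheduler.step()  # CLR adjusts the learning rate after every batch
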
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 200e2c6ecf..36507940e9 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -4,6 +4,7 @@ import torch
from torch._six import inf
from collections import Counter
from functools import partial
+
from .optimizer import Optimizer
@@ -427,3 +428,216 @@ class ReduceLROnPlateau(object):
    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
+
+
+class CyclicLR(_LRScheduler):
+    """Sets the learning rate of each parameter group according to
+    cyclical learning rate policy (CLR). The policy cycles the learning
+    rate between two boundaries with a constant frequency, as detailed in
+    the paper `Cyclical Learning Rates for Training Neural Networks`_.
+    The distance between the two boundaries can be scaled on a per-iteration
+    or per-cycle basis.
+
+    Cyclical learning rate policy changes the learning rate after every batch.
+    `step` should be called after a batch has been used for training.
+
+    This class has three built-in policies, as put forth in the paper:
+    "triangular":
+        A basic triangular cycle with no amplitude scaling.
+    "triangular2":
+        A basic triangular cycle that scales the initial amplitude by half each cycle.
+    "exp_range":
+        A cycle that scales the initial amplitude by gamma**(cycle iterations) at each
+        cycle iteration.
+
+    This implementation was adapted from the github repo `bckenstler/CLR`_.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        base_lr (float or list): Initial learning rate, which is the
+            lower boundary in the cycle for each parameter group.
+        max_lr (float or list): Upper learning rate boundary in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_lr - base_lr).
+            The lr at any cycle is the sum of base_lr
+            and some scaling of the amplitude; therefore
+            max_lr may not actually be reached, depending on the
+            scaling function.
+        step_size_up (int): Number of training iterations in the
+            increasing half of a cycle. Default: 2000
+        step_size_down (int): Number of training iterations in the
+            decreasing half of a cycle. If step_size_down is None,
+            it is set to step_size_up. Default: None
+        mode (str): One of {triangular, triangular2, exp_range}.
+            Values correspond to the policies detailed above.
+            If scale_fn is not None, this argument is ignored.
+            Default: 'triangular'
+        gamma (float): Constant in the 'exp_range' scaling function:
+            gamma**(cycle iterations)
+            Default: 1.0
+        scale_fn (function): Custom scaling policy defined by a single
+            argument lambda function, where
+            0 <= scale_fn(x) <= 1 for all x >= 0.
+            If specified, then 'mode' is ignored.
+            Default: None
+        scale_mode (str): One of {'cycle', 'iterations'}.
+            Defines whether scale_fn is evaluated on
+            cycle number or cycle iterations (training
+            iterations since start of cycle).
+            Default: 'cycle'
+        cycle_momentum (bool): If ``True``, momentum is cycled inversely
+            to learning rate between 'base_momentum' and 'max_momentum'.
+            Default: True
+        base_momentum (float or list): Initial momentum, which is the
+            lower boundary in the cycle for each parameter group.
+            Default: 0.8
+        max_momentum (float or list): Upper momentum boundary in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            The momentum at any cycle is the difference of max_momentum
+            and some scaling of the amplitude; therefore
+            base_momentum may not actually be reached, depending on the
+            scaling function. Default: 0.9
+        last_epoch (int): The index of the last batch. This parameter is used when
+            resuming a training job. Since `step()` should be invoked after each
+            batch instead of after each epoch, this number represents the total
+            number of *batches* computed, not the total number of epochs computed.
+            When last_epoch=-1, the schedule is started from the beginning.
+            Default: -1
+
+    Example:
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
+        >>> data_loader = torch.utils.data.DataLoader(...)
+        >>> for epoch in range(10):
+        >>>     for batch in data_loader:
+        >>>         train_batch(...)
+        >>>         scheduler.step()
+
+
+    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
+    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
+    """
+
+    def __init__(self,
+                 optimizer,
+                 base_lr,
+                 max_lr,
+                 step_size_up=2000,
+                 step_size_down=None,
+                 mode='triangular',
+                 gamma=1.,
+                 scale_fn=None,
+                 scale_mode='cycle',
+                 cycle_momentum=True,
+                 base_momentum=0.8,
+                 max_momentum=0.9,
+                 last_epoch=-1):
+
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+
+        base_lrs = self._format_param('base_lr', optimizer, base_lr)
+        if last_epoch == -1:
+            for lr, group in zip(base_lrs, optimizer.param_groups):
+                group['lr'] = lr
+
+        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
+
+        step_size_up = float(step_size_up)
+        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
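+        # total_size is the number of iterations in one full cycle; step_ratio is the
+        # fraction of that cycle spent in the increasing half.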
+        self.total_size = step_size_up + step_size_down
+        self.step_ratio = step_size_up / self.total_size
+
+        if mode not in ['triangular', 'triangular2', 'exp_range'] \
+                and scale_fn is None:
+            raise ValueError('mode is invalid and scale_fn is None')
+
+        self.mode = mode
+        self.gamma = gamma
+
+        if scale_fn is None:
+            if self.mode == 'triangular':
+                self.scale_fn = self._triangular_scale_fn
+                self.scale_mode = 'cycle'
+            elif self.mode == 'triangular2':
+                self.scale_fn = self._triangular2_scale_fn
+                self.scale_mode = 'cycle'
+            elif self.mode == 'exp_range':
+                self.scale_fn = self._exp_range_scale_fn
+                self.scale_mode = 'iterations'
+        else:
+            self.scale_fn = scale_fn
+            self.scale_mode = scale_mode
+
+        self.cycle_momentum = cycle_momentum
+        if cycle_momentum:
+            if 'momentum' not in optimizer.defaults:
+                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+
+            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+            if last_epoch == -1:
+                for momentum, group in zip(base_momentums, optimizer.param_groups):
+                    group['momentum'] = momentum
+            self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
+            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+
+        super(CyclicLR, self).__init__(optimizer, last_epoch)
+
+    def _format_param(self, name, optimizer, param):
+        """Return correctly formatted lr/momentum for each param group."""
+        if isinstance(param, (list, tuple)):
+            if len(param) != len(optimizer.param_groups):
+                raise ValueError("expected {} values for {}, got {}".format(
+                    len(optimizer.param_groups), name, len(param)))
+            return param
+        else:
+            return [param] * len(optimizer.param_groups)
+
+    def _triangular_scale_fn(self, x):
+        return 1.
+
+    def _triangular2_scale_fn(self, x):
+        return 1 / (2. ** (x - 1))
+
+    def _exp_range_scale_fn(self, x):
+        return self.gamma**(x)
+
+    def get_lr(self):
+        """Calculates the learning rate at batch index. This function treats
+        `self.last_epoch` as the last batch index.
+
+        If `self.cycle_momentum` is ``True``, this function has a side effect of
+        updating the optimizer's momentum.
+        """
+        cycle = math.floor(1 + self.last_epoch / self.total_size)
+        x = 1. + self.last_epoch / self.total_size - cycle
+        if x <= self.step_ratio:
+            scale_factor = x / self.step_ratio
+        else:
+            scale_factor = (x - 1) / (self.step_ratio - 1)
+
+        lrs = []
+        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
+            base_height = (max_lr - base_lr) * scale_factor
+            if self.scale_mode == 'cycle':
+                lr = base_lr + base_height * self.scale_fn(cycle)
+            else:
+                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
+            lrs.append(lr)
+
+        if self.cycle_momentum:
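+            # Momentum is cycled inversely to the learning rate: it starts each cycle
+            # at max_momentum and dips toward base_momentum as the learning rate rises.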
+            momentums = []
+            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
+                base_height = (max_momentum - base_momentum) * scale_factor
+                if self.scale_mode == 'cycle':
+                    momentum = max_momentum - base_height * self.scale_fn(cycle)
+                else:
+                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
+                momentums.append(momentum)
+            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
+                param_group['momentum'] = momentum
+
+        return lrs
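
The triangular waveform that get_lr computes can be checked in isolation. The standalone helper below is only a sketch that mirrors the formula above for a single parameter group in the default 'triangular' mode (where scale_fn returns 1.); the function name and the tiny step sizes are hypothetical, chosen to make one cycle easy to print, and are not part of this commit.

    import math

    def triangular_lr(iteration, base_lr=0.001, max_lr=0.01,
                      step_size_up=4.0, step_size_down=4.0):
        # Mirrors CyclicLR.get_lr for one param group with scale_fn(x) == 1.
        total_size = step_size_up + step_size_down
        step_ratio = step_size_up / total_size
        cycle = math.floor(1 + iteration / total_size)   # 1-indexed cycle number
        x = 1. + iteration / total_size - cycle          # position within the cycle, in [0, 1)
        if x <= step_ratio:
            scale_factor = x / step_ratio                # rising half: 0 -> 1
        else:
            scale_factor = (x - 1) / (step_ratio - 1)    # falling half: 1 -> 0
        return base_lr + (max_lr - base_lr) * scale_factor

    # One full 8-iteration cycle: lr climbs from base_lr to max_lr and back down.
    print([round(triangular_lr(i), 5) for i in range(9)])
    # [0.001, 0.00325, 0.0055, 0.00775, 0.01, 0.00775, 0.0055, 0.00325, 0.001]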