diff options
author | Marcin Elantkowski <marcin.elantkowski@gmail.com> | 2018-02-20 06:59:14 +0100 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2018-02-20 00:59:14 -0500 |
commit | d2ff733cb109d32049381884302650463b71d77e (patch) | |
tree | ecb7a5746f4175dc9b1c65ef6ae6c6a9506b8739 /torch/optim | |
parent | 596470011b745328c491ba685541fcb02057bf52 (diff) | |
download | pytorch-d2ff733cb109d32049381884302650463b71d77e.tar.gz pytorch-d2ff733cb109d32049381884302650463b71d77e.tar.bz2 pytorch-d2ff733cb109d32049381884302650463b71d77e.zip |
Make ReduceLROnPlateau serializable. (#5300)
* replace lambdas with partial
* flake8
Diffstat (limited to 'torch/optim')
-rw-r--r-- | torch/optim/lr_scheduler.py | 36 |
1 files changed, 23 insertions, 13 deletions
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index b9999d2ffe..9ce2988b20 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,5 +1,7 @@ import math from bisect import bisect_right +from functools import partial + from .optimizer import Optimizer @@ -322,22 +324,30 @@ class ReduceLROnPlateau(object): def in_cooldown(self): return self.cooldown_counter > 0 - def _init_is_better(self, mode, threshold, threshold_mode): - if mode not in {'min', 'max'}: - raise ValueError('mode ' + mode + ' is unknown!') - if threshold_mode not in {'rel', 'abs'}: - raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') + def _cmp(self, mode, threshold_mode, threshold, a, best): if mode == 'min' and threshold_mode == 'rel': rel_epsilon = 1. - threshold - self.is_better = lambda a, best: a < best * rel_epsilon - self.mode_worse = float('Inf') + return a < best * rel_epsilon + elif mode == 'min' and threshold_mode == 'abs': - self.is_better = lambda a, best: a < best - threshold - self.mode_worse = float('Inf') + return a < best - threshold + elif mode == 'max' and threshold_mode == 'rel': rel_epsilon = threshold + 1. - self.is_better = lambda a, best: a > best * rel_epsilon - self.mode_worse = -float('Inf') + return a > best * rel_epsilon + else: # mode == 'max' and epsilon_mode == 'abs': - self.is_better = lambda a, best: a > best + threshold - self.mode_worse = -float('Inf') + return a > best + threshold + + def _init_is_better(self, mode, threshold, threshold_mode): + if mode not in {'min', 'max'}: + raise ValueError('mode ' + mode + ' is unknown!') + if threshold_mode not in {'rel', 'abs'}: + raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') + + if mode == 'min': + self.mode_worse = float('inf') + else: # mode == 'max': + self.mode_worse = (-float('inf')) + + self.is_better = partial(self._cmp, mode, threshold_mode, threshold) |