summaryrefslogtreecommitdiff
path: root/torch/optim
diff options
context:
space:
mode:
authorMarcin Elantkowski <marcin.elantkowski@gmail.com>2018-02-20 06:59:14 +0100
committerSoumith Chintala <soumith@gmail.com>2018-02-20 00:59:14 -0500
commitd2ff733cb109d32049381884302650463b71d77e (patch)
treeecb7a5746f4175dc9b1c65ef6ae6c6a9506b8739 /torch/optim
parent596470011b745328c491ba685541fcb02057bf52 (diff)
downloadpytorch-d2ff733cb109d32049381884302650463b71d77e.tar.gz
pytorch-d2ff733cb109d32049381884302650463b71d77e.tar.bz2
pytorch-d2ff733cb109d32049381884302650463b71d77e.zip
Make ReduceLROnPlateau serializable. (#5300)
* replace lambdas with partial * flake8
Diffstat (limited to 'torch/optim')
-rw-r--r--torch/optim/lr_scheduler.py36
1 files changed, 23 insertions, 13 deletions
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index b9999d2ffe..9ce2988b20 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1,5 +1,7 @@
import math
from bisect import bisect_right
+from functools import partial
+
from .optimizer import Optimizer
@@ -322,22 +324,30 @@ class ReduceLROnPlateau(object):
def in_cooldown(self):
return self.cooldown_counter > 0
- def _init_is_better(self, mode, threshold, threshold_mode):
- if mode not in {'min', 'max'}:
- raise ValueError('mode ' + mode + ' is unknown!')
- if threshold_mode not in {'rel', 'abs'}:
- raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
+ def _cmp(self, mode, threshold_mode, threshold, a, best):
if mode == 'min' and threshold_mode == 'rel':
rel_epsilon = 1. - threshold
- self.is_better = lambda a, best: a < best * rel_epsilon
- self.mode_worse = float('Inf')
+ return a < best * rel_epsilon
+
elif mode == 'min' and threshold_mode == 'abs':
- self.is_better = lambda a, best: a < best - threshold
- self.mode_worse = float('Inf')
+ return a < best - threshold
+
elif mode == 'max' and threshold_mode == 'rel':
rel_epsilon = threshold + 1.
- self.is_better = lambda a, best: a > best * rel_epsilon
- self.mode_worse = -float('Inf')
+ return a > best * rel_epsilon
+
else: # mode == 'max' and epsilon_mode == 'abs':
- self.is_better = lambda a, best: a > best + threshold
- self.mode_worse = -float('Inf')
+ return a > best + threshold
+
+ def _init_is_better(self, mode, threshold, threshold_mode):
+ if mode not in {'min', 'max'}:
+ raise ValueError('mode ' + mode + ' is unknown!')
+ if threshold_mode not in {'rel', 'abs'}:
+ raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
+
+ if mode == 'min':
+ self.mode_worse = float('inf')
+ else: # mode == 'max':
+ self.mode_worse = (-float('inf'))
+
+ self.is_better = partial(self._cmp, mode, threshold_mode, threshold)