diff options
author | Tongzhou Wang <SsnL@users.noreply.github.com> | 2018-04-03 21:29:19 -0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2018-04-03 21:29:18 -0400 |
commit | a2880531ea4c3ba0130738e5fa66ae31f293cbfb (patch) | |
tree | f977c5cc30527063cb25d1514068d25c0597b6a2 | |
parent | 06a697785cd1f6cf1ae0bb29a20a76ab98a61733 (diff) | |
download | pytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.tar.gz pytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.tar.bz2 pytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.zip |
fix SGD lr check (#6244)
-rw-r--r-- | test/test_optim.py | 4 | ||||
-rw-r--r-- | torch/optim/sgd.py | 6 |
2 files changed, 7 insertions, 3 deletions
diff --git a/test/test_optim.py b/test/test_optim.py index 386438960a..736d1b51b4 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -260,6 +260,10 @@ class TestOptim(TestCase): self._build_params_dict_single(weight, bias, lr=1e-2), lr=1e-3) ) + self._test_basic_cases( + lambda weight, bias: optim.SGD( + self._build_params_dict_single(weight, bias, lr=1e-2)) + ) with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"): optim.SGD(None, lr=1e-2, momentum=-0.5) diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index f6d5a484ef..78cb3e189a 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -50,11 +50,11 @@ class SGD(Optimizer): def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False): - if not 0.0 <= lr: + if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= momentum: + if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) - if not 0.0 <= weight_decay: + if weight_decay < 0.0: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) defaults = dict(lr=lr, momentum=momentum, dampening=dampening, |