summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTongzhou Wang <SsnL@users.noreply.github.com>2018-04-03 21:29:19 -0400
committerSoumith Chintala <soumith@gmail.com>2018-04-03 21:29:18 -0400
commita2880531ea4c3ba0130738e5fa66ae31f293cbfb (patch)
treef977c5cc30527063cb25d1514068d25c0597b6a2
parent06a697785cd1f6cf1ae0bb29a20a76ab98a61733 (diff)
downloadpytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.tar.gz
pytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.tar.bz2
pytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.zip
fix SGD lr check (#6244)
-rw-r--r--test/test_optim.py4
-rw-r--r--torch/optim/sgd.py6
2 files changed, 7 insertions, 3 deletions
diff --git a/test/test_optim.py b/test/test_optim.py
index 386438960a..736d1b51b4 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -260,6 +260,10 @@ class TestOptim(TestCase):
self._build_params_dict_single(weight, bias, lr=1e-2),
lr=1e-3)
)
+ self._test_basic_cases(
+ lambda weight, bias: optim.SGD(
+ self._build_params_dict_single(weight, bias, lr=1e-2))
+ )
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
optim.SGD(None, lr=1e-2, momentum=-0.5)
diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py
index f6d5a484ef..78cb3e189a 100644
--- a/torch/optim/sgd.py
+++ b/torch/optim/sgd.py
@@ -50,11 +50,11 @@ class SGD(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
- if not 0.0 <= lr:
+ if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= momentum:
+ if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
- if not 0.0 <= weight_decay:
+ if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,