summaryrefslogtreecommitdiff
path: root/torch/optim
diff options
context:
space:
mode:
authorTongzhou Wang <SsnL@users.noreply.github.com>2018-04-03 21:29:19 -0400
committerSoumith Chintala <soumith@gmail.com>2018-04-03 21:29:18 -0400
commita2880531ea4c3ba0130738e5fa66ae31f293cbfb (patch)
treef977c5cc30527063cb25d1514068d25c0597b6a2 /torch/optim
parent06a697785cd1f6cf1ae0bb29a20a76ab98a61733 (diff)
downloadpytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.tar.gz
pytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.tar.bz2
pytorch-a2880531ea4c3ba0130738e5fa66ae31f293cbfb.zip
fix SGD lr check (#6244)
Diffstat (limited to 'torch/optim')
-rw-r--r--torch/optim/sgd.py6
1 file changed, 3 insertions, 3 deletions
diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py
index f6d5a484ef..78cb3e189a 100644
--- a/torch/optim/sgd.py
+++ b/torch/optim/sgd.py
@@ -50,11 +50,11 @@ class SGD(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
- if not 0.0 <= lr:
+ if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= momentum:
+ if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
- if not 0.0 <= weight_decay:
+ if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,