summaryrefslogtreecommitdiff
path: root/torch/optim
diff options
context:
space:
mode:
Diffstat (limited to 'torch/optim')
-rw-r--r--torch/optim/sgd.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py
index f6d5a484ef..78cb3e189a 100644
--- a/torch/optim/sgd.py
+++ b/torch/optim/sgd.py
@@ -50,11 +50,11 @@ class SGD(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
- if not 0.0 <= lr:
+ if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= momentum:
+ if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
- if not 0.0 <= weight_decay:
+ if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,