diff options
Diffstat (limited to 'torch/optim')
-rw-r--r-- | torch/optim/sgd.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index f6d5a484ef..78cb3e189a 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -50,11 +50,11 @@ class SGD(Optimizer): def __init__(self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False): - if not 0.0 <= lr: + if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= momentum: + if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) - if not 0.0 <= weight_decay: + if weight_decay < 0.0: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) defaults = dict(lr=lr, momentum=momentum, dampening=dampening, |