diff options
Diffstat (limited to 'torch/optim/optimizer.py')
-rw-r--r-- | torch/optim/optimizer.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index ebeddd1129..12285e913b 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -33,7 +33,10 @@ class Optimizer(object): param_set = set() for group in self.param_groups: - group['params'] = list(group['params']) + if isinstance(group['params'], torch.autograd.Variable): + group['params'] = [group['params']] + else: + group['params'] = list(group['params']) group_set = set(group['params']) if not param_set.isdisjoint(group_set): raise ValueError("some parameters appear in more than one " |