author     Yan Wang <me@yanwang.me>    2017-06-03 15:00:11 +0800
committer  Soumith Chintala <soumith@gmail.com>    2017-06-05 23:47:56 -0400
commit     a76098ac1532d5e9ee24b4776258ae731627f8e3 (patch)
tree       fa0f995b03f3589aa795215276adebc617cbe41c /torch/optim
parent     2ce5875a4d0d1c7d0deea99a28b6acfcc86106d2 (diff)
fix optimizer when given a single parameter (instead of an iterable)
When I use named_parameters to set per-parameter lr and weight decay, I hit a bug: the values returned by named_parameters are torch.nn.parameter.Parameter objects, not an iterable of Parameters, so the optimizer fails when a group's 'params' entry is a single Parameter.
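For context, a minimal sketch (not part of the commit; the module, optimizer, and hyperparameter values below are illustrative, standard PyTorch names) of the usage the author describes, where each parameter group's 'params' entry is a single Parameter taken from named_parameters() rather than an iterable:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)

# One parameter group per named parameter, with per-parameter lr / weight decay.
# Each group's 'params' value is a single Parameter, not an iterable of Parameters.
groups = [
    {'params': param,
     'lr': 0.1 if 'bias' in name else 0.01,
     'weight_decay': 0.0 if 'bias' in name else 1e-4}
    for name, param in model.named_parameters()
]

# Before this fix, Optimizer.__init__ called list(group['params']) on the
# Parameter itself, which iterates over the tensor rather than wrapping it
# in a one-element list, and the optimizer setup breaks.
optimizer = optim.SGD(groups, lr=0.01)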
Diffstat (limited to 'torch/optim')
-rw-r--r--  torch/optim/optimizer.py  |  5
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index ebeddd1129..12285e913b 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -33,7 +33,10 @@ class Optimizer(object):
         param_set = set()
         for group in self.param_groups:
-            group['params'] = list(group['params'])
+            if isinstance(group['params'], torch.autograd.Variable):
+                group['params'] = [group['params']]
+            else:
+                group['params'] = list(group['params'])
             group_set = set(group['params'])
             if not param_set.isdisjoint(group_set):
                 raise ValueError("some parameters appear in more than one "
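With the change above, a minimal usage sketch (assuming standard PyTorch modules; the values are illustrative, and Parameter is a subclass of torch.autograd.Variable, so the isinstance check catches it) showing that a group's 'params' may now be a single Parameter, while iterables keep working as before:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.SGD([
    {'params': model.weight, 'lr': 0.1},   # single Parameter: wrapped in a one-element list by the new branch
    {'params': [model.bias], 'lr': 0.2},   # iterable of Parameters: handled as before
], lr=0.01)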