diff options
author | Yan Wang <me@yanwang.me> | 2017-06-03 15:00:11 +0800 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-06-05 23:47:56 -0400 |
commit | a76098ac1532d5e9ee24b4776258ae731627f8e3 (patch) | |
tree | fa0f995b03f3589aa795215276adebc617cbe41c /torch/optim | |
parent | 2ce5875a4d0d1c7d0deea99a28b6acfcc86106d2 (diff) | |
download | pytorch-a76098ac1532d5e9ee24b4776258ae731627f8e3.tar.gz pytorch-a76098ac1532d5e9ee24b4776258ae731627f8e3.tar.bz2 pytorch-a76098ac1532d5e9ee24b4776258ae731627f8e3.zip |
fix optimizer when given single parameters (instead of an iterable)
When I use the named_parametes to modify the lr and weight decay, I will face a bug. Because the value of the named_parameters return is torch.nn.paramter.Parameter, not a generator of the Parameter.
Diffstat (limited to 'torch/optim')
-rw-r--r-- | torch/optim/optimizer.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index ebeddd1129..12285e913b 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -33,7 +33,10 @@ class Optimizer(object): param_set = set() for group in self.param_groups: - group['params'] = list(group['params']) + if isinstance(group['params'], torch.autograd.Variable): + group['params'] = [group['params']] + else: + group['params'] = list(group['params']) group_set = set(group['params']) if not param_set.isdisjoint(group_set): raise ValueError("some parameters appear in more than one " |