|   |   |   |
|---|---|---|
| author | Samuel <albanie@users.noreply.github.com> | 2018-04-27 22:47:37 +0100 |
| committer | Soumith Chintala <soumith@gmail.com> | 2018-04-27 17:47:37 -0400 |
| commit | 0c737dff6333c859957c3b26f7a0114f73be12a6 (patch) | |
| tree | c0f371a0b2326175d64af442c34324ea959a90d7 /torch/optim | |
| parent | 6ce376fee374de52554987adf2ed701d509dea9b (diff) | |
| download | pytorch-0c737dff6333c859957c3b26f7a0114f73be12a6.tar.gz pytorch-0c737dff6333c859957c3b26f7a0114f73be12a6.tar.bz2 pytorch-0c737dff6333c859957c3b26f7a0114f73be12a6.zip | |
fix lbfgs variable names (#7037)
Switches the step/direction variable names (the steps and directions are
flipped in the current implementation of the two-loop recursion). This change
does not alter the numerical output of the program, but it should make the
code easier to follow.
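
For context, the two-loop recursion reconstructs the action of the inverse Hessian from a history of steps s_k = x_{k+1} - x_k and gradient differences y_k = g_{k+1} - g_k, with rho_k = 1 / (y_k . s_k); after this patch, `old_dirs` holds the y vectors and `old_stps` holds the s vectors, as the diff below shows. The following is a minimal standalone sketch of that recursion, not the optimizer's internals: the function and list names (`two_loop_direction`, `s_hist`, `y_hist`) are illustrative, and the `add_` calls use the current keyword form `add_(tensor, alpha=...)` rather than the older positional form that appears in the 2018 diff.

```python
# Minimal sketch (not PyTorch's internal API) of the L-BFGS two-loop recursion
# the patched code implements: given the current flattened gradient and a
# history of (s, y) pairs, return an approximation of -H * grad.
import torch

def two_loop_direction(flat_grad, s_hist, y_hist, H_diag):
    num_old = len(s_hist)
    # rho_i = 1 / (y_i . s_i), matching `ro[i]` in the patched code
    ro = [1.0 / y_hist[i].dot(s_hist[i]) for i in range(num_old)]
    al = [None] * num_old

    # backward pass over the history
    q = flat_grad.neg()
    for i in range(num_old - 1, -1, -1):
        al[i] = s_hist[i].dot(q) * ro[i]
        q.add_(y_hist[i], alpha=-al[i])

    # apply the initial (diagonal) Hessian approximation
    r = torch.mul(q, H_diag)

    # forward pass over the history; r is the final search direction
    for i in range(num_old):
        be_i = y_hist[i].dot(r) * ro[i]
        r.add_(s_hist[i], alpha=al[i] - be_i)
    return r
```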
Diffstat (limited to 'torch/optim')
-rw-r--r-- | torch/optim/lbfgs.py | 14 |
1 file changed, 7 insertions, 7 deletions
diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py
index 1ccf1ef658..ac096262a8 100644
--- a/torch/optim/lbfgs.py
+++ b/torch/optim/lbfgs.py
@@ -148,8 +148,8 @@ class LBFGS(Optimizer):
                         old_stps.pop(0)

                     # store new direction/step
-                    old_dirs.append(s)
-                    old_stps.append(y)
+                    old_dirs.append(y)
+                    old_stps.append(s)

                     # update scale of initial Hessian approximation
                     H_diag = ys / y.dot(y)  # (y*y)
@@ -165,20 +165,20 @@ class LBFGS(Optimizer):
                 al = state['al']

                 for i in range(num_old):
-                    ro[i] = 1. / old_stps[i].dot(old_dirs[i])
+                    ro[i] = 1. / old_dirs[i].dot(old_stps[i])

                 # iteration in L-BFGS loop collapsed to use just one buffer
                 q = flat_grad.neg()
                 for i in range(num_old - 1, -1, -1):
-                    al[i] = old_dirs[i].dot(q) * ro[i]
-                    q.add_(-al[i], old_stps[i])
+                    al[i] = old_stps[i].dot(q) * ro[i]
+                    q.add_(-al[i], old_dirs[i])

                 # multiply by initial Hessian
                 # r/d is the final direction
                 d = r = torch.mul(q, H_diag)
                 for i in range(num_old):
-                    be_i = old_stps[i].dot(r) * ro[i]
-                    r.add_(al[i] - be_i, old_dirs[i])
+                    be_i = old_dirs[i].dot(r) * ro[i]
+                    r.add_(al[i] - be_i, old_stps[i])

                 if prev_flat_grad is None:
                     prev_flat_grad = flat_grad.clone()
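
A quick way to exercise the sketch above on a toy quadratic (the matrix `A`, step size, and seed are all illustrative, not part of the commit): with one (s, y) pair and the same initial scaling `H_diag = ys / y.dot(y)` as in the diff, the returned vector should be a descent direction.

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n)
A = A @ A.t() + n * torch.eye(n)        # SPD Hessian of f(x) = 0.5 * x^T A x
grad = lambda x: A @ x                  # gradient of the toy quadratic

x0 = torch.randn(n)
x1 = x0 - 0.1 * grad(x0)                # one gradient step to build the history
s, y = x1 - x0, grad(x1) - grad(x0)
H_diag = y.dot(s) / y.dot(y)            # same initial scaling as in the diff

d = two_loop_direction(grad(x1), [s], [y], H_diag)
assert d.dot(grad(x1)) < 0              # d is a descent direction at x1
```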