import torch
from functools import reduce
from .optimizer import Optimizer


class LBFGS(Optimizer):
"""Implements L-BFGS algorithm.
.. warning::
This optimizer doesn't support per-parameter options and parameter
groups (there can be only one).
.. warning::
Right now all parameters have to be on a single device. This will be
improved in the future.
.. note::
This is a very memory intensive optimizer (it requires additional
``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
try reducing the history size, or use a different algorithm.
Arguments:
lr (float): learning rate (default: 1)
max_iter (int): maximal number of iterations per optimization step
(default: 20)
max_eval (int): maximal number of function evaluations per optimization
step (default: max_iter * 1.25).
tolerance_grad (float): termination tolerance on first order optimality
(default: 1e-5).
tolerance_change (float): termination tolerance on function value/parameter
changes (default: 1e-9).
history_size (int): update history size (default: 100).
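
    Example (an illustrative sketch; ``model``, ``loss_fn``, ``input`` and
    ``target`` are assumed to be defined elsewhere)::

        optimizer = LBFGS(model.parameters(), lr=1, history_size=10)

        def closure():
            optimizer.zero_grad()
            loss = loss_fn(model(input), target)
            loss.backward()
            return loss

        optimizer.step(closure)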
"""
def __init__(self, params, lr=1, max_iter=20, max_eval=None,
tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
line_search_fn=None):
if max_eval is None:
max_eval = max_iter * 5 // 4
defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
history_size=history_size, line_search_fn=line_search_fn)
super(LBFGS, self).__init__(params, defaults)
if len(self.param_groups) != 1:
raise ValueError("LBFGS doesn't support per-parameter options "
"(parameter groups)")
self._params = self.param_groups[0]['params']
        self._numel_cache = None

def _numel(self):
if self._numel_cache is None:
self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

def _gather_flat_grad(self):
return torch.cat(
            tuple(param.grad.data.view(-1) for param in self._params), 0)

def _add_grad(self, step_size, update):
offset = 0
for p in self._params:
numel = p.numel()
p.data.add_(step_size, update[offset:offset+numel])
offset += numel
        assert offset == self._numel()

def step(self, closure):
"""Performs a single optimization step.
Arguments:
closure (callable): A closure that reevaluates the model
and returns the loss.
"""
assert len(self.param_groups) == 1
group = self.param_groups[0]
lr = group['lr']
max_iter = group['max_iter']
max_eval = group['max_eval']
tolerance_grad = group['tolerance_grad']
tolerance_change = group['tolerance_change']
line_search_fn = group['line_search_fn']
history_size = group['history_size']
state = self.state['global_state']
state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

# evaluate initial f(x) and df/dx
orig_loss = closure()
loss = orig_loss.data[0]
current_evals = 1
state['func_evals'] += 1
flat_grad = self._gather_flat_grad()
abs_grad_sum = flat_grad.abs().sum()
if abs_grad_sum <= tolerance_grad:
            return loss

# variables cached in state (for tracing)
d = state.get('d')
t = state.get('t')
old_dirs = state.get('old_dirs')
old_stps = state.get('old_stps')
H_diag = state.get('H_diag')
prev_flat_grad = state.get('prev_flat_grad')
prev_loss = state.get('prev_loss')
        n_iter = 0

# optimize for a max of max_iter iterations
while n_iter < max_iter:
            # keep track of the number of iterations
            n_iter += 1
            state['n_iter'] += 1

############################################################
# compute gradient descent direction
############################################################
if state['n_iter'] == 1:
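                # first iteration: no curvature information yet, so fall back
                # to plain steepest descent with an identity initial Hessian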
d = flat_grad.neg()
old_dirs = []
old_stps = []
H_diag = 1
else:
                # do an L-BFGS update (update the curvature memory)
                y = flat_grad.sub(prev_flat_grad)  # y = g_{k+1} - g_k
                s = d.mul(t)                       # s = t * d, the step just taken
                ys = y.dot(s)                      # y^T s
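                # only update the history when the curvature condition
                # y^T s > 0 holds (with a small margin), which keeps the
                # inverse Hessian approximation positive definite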
if ys > 1e-10:
# updating memory
if len(old_dirs) == history_size:
# shift history by one (limited-memory)
old_dirs.pop(0)
old_stps.pop(0)
                    # store the new curvature pair; note that old_dirs holds the
                    # steps s_i while old_stps holds the gradient differences y_i
                    old_dirs.append(s)
                    old_stps.append(y)
                    # update the scale of the initial inverse Hessian approximation:
                    # H_diag = (y^T s) / (y^T y)
                    H_diag = ys / y.dot(y)
# compute the approximate (L-BFGS) inverse Hessian
# multiplied by the gradient
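            # (standard two-loop recursion: ro[i] is rho_i = 1/(y_i^T s_i); the
            # backward loop accumulates al[i] = alpha_i while reducing q, and
            # the forward loop applies the scaled initial inverse Hessian
            # H_diag and adds the corrections back; q starts as the negative
            # gradient, so the result r is the search direction d)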
num_old = len(old_dirs)
if 'ro' not in state:
state['ro'] = [None] * history_size
state['al'] = [None] * history_size
ro = state['ro']
al = state['al']
for i in range(num_old):
ro[i] = 1. / old_stps[i].dot(old_dirs[i])
# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
for i in range(num_old-1, -1, -1):
al[i] = old_dirs[i].dot(q) * ro[i]
q.add_(-al[i], old_stps[i])
# multiply by initial Hessian
# r/d is the final direction
d = r = torch.mul(q, H_diag)
for i in range(num_old):
be_i = old_stps[i].dot(r) * ro[i]
r.add_(al[i] - be_i, old_dirs[i])
if prev_flat_grad is None:
prev_flat_grad = flat_grad.clone()
else:
prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

############################################################
# compute step length
############################################################
            # directional derivative along the search direction
            gtd = flat_grad.dot(d)  # g^T d
# check that progress can be made along that direction
if gtd > -tolerance_change:
break
            # reset the initial guess for the step size: the very first step is
            # scaled by min(1, 1/||g||_1) so a large initial gradient does not
            # produce an overly large update; later steps start from lr
if state['n_iter'] == 1:
t = min(1., 1. / abs_grad_sum) * lr
else:
t = lr
# optional line search: user function
ls_func_evals = 0
if line_search_fn is not None:
# perform line search, using user function
raise RuntimeError("line search function is not supported yet")
else:
# no line search, simply move with fixed-step
self._add_grad(t, d)
if n_iter != max_iter:
                # re-evaluate the function only if this is not the last iteration;
                # in a stochastic setting there is no point in re-evaluating it here
loss = closure().data[0]
flat_grad = self._gather_flat_grad()
abs_grad_sum = flat_grad.abs().sum()
ls_func_evals = 1
            # update the function evaluation count
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

############################################################
# check conditions
############################################################
if n_iter == max_iter:
break
if current_evals >= max_eval:
break
if abs_grad_sum <= tolerance_grad:
break
if d.mul(t).abs_().sum() <= tolerance_change:
break
if abs(loss - prev_loss) < tolerance_change:
break
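
        # cache everything that has to carry over to the next call to step():
        # search direction, step length, curvature history, previous gradient
        # and previous loss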
state['d'] = d
state['t'] = t
state['old_dirs'] = old_dirs
state['old_stps'] = old_stps
state['H_diag'] = H_diag
state['prev_flat_grad'] = prev_flat_grad
state['prev_loss'] = prev_loss
return orig_loss