author    | SsnL <SsnL@users.noreply.github.com> | 2017-11-06 14:20:51 -0500
committer | Soumith Chintala <soumith@gmail.com> | 2017-11-06 14:20:51 -0500
commit    | f76d6c029c443606559122440299b724f1b7ff30 (patch)
tree      | 35e81ad32f098b4207eafe54c05c9c0efb6a87f1 /torch/optim
parent    | c2626f6031bd48b6aff517d3de6bcc8f0bd01271 (diff)
Sparse Adam optimizer for sparse gradients (#3137)
* sparse adam
* Favor dense addition over sparse_mask
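For context, a minimal usage sketch of the new optimizer. This is illustrative only and not part of the commit: the toy embedding layer, sizes, and training snippet are assumptions, written against the current PyTorch API where tensors carry autograd state.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# An embedding layer with sparse=True produces sparse gradients for its
# weight, which is the case SparseAdam is designed for.
emb = nn.Embedding(num_embeddings=1000, embedding_dim=16, sparse=True)
opt = optim.SparseAdam(emb.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

indices = torch.tensor([1, 5, 5, 42])  # only these rows receive gradients
loss = emb(indices).sum()

opt.zero_grad()
loss.backward()   # emb.weight.grad is a sparse tensor
opt.step()        # lazily updates only the touched rows and their moments
```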
Diffstat (limited to 'torch/optim')
-rw-r--r-- | torch/optim/__init__.py    |  1
-rw-r--r-- | torch/optim/adadelta.py    |  8
-rw-r--r-- | torch/optim/adagrad.py     | 13
-rw-r--r-- | torch/optim/adam.py        |  8
-rw-r--r-- | torch/optim/adamax.py      |  6
-rw-r--r-- | torch/optim/asgd.py        |  5
-rw-r--r-- | torch/optim/rmsprop.py     |  9
-rw-r--r-- | torch/optim/rprop.py       |  5
-rw-r--r-- | torch/optim/sgd.py         |  3
-rw-r--r-- | torch/optim/sparse_adam.py | 95
10 files changed, 134 insertions, 19 deletions
diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py
index 2fd9916284..f10601ad2c 100644
--- a/torch/optim/__init__.py
+++ b/torch/optim/__init__.py
@@ -8,6 +8,7 @@ future.
 from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
+from .sparse_adam import SparseAdam
 from .adamax import Adamax
 from .asgd import ASGD
 from .sgd import SGD
diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index c7a23418b1..a37febaab5 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -1,3 +1,5 @@
+import torch
+
 from .optimizer import Optimizer
@@ -40,13 +42,15 @@ class Adadelta(Optimizer):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adadelta does not support sparse gradients')
                 state = self.state[p]

                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = grad.new().resize_as_(grad).zero_()
-                    state['acc_delta'] = grad.new().resize_as_(grad).zero_()
+                    state['square_avg'] = torch.zeros_like(p.data)
+                    state['acc_delta'] = torch.zeros_like(p.data)

                 square_avg, acc_delta = state['square_avg'], state['acc_delta']
                 rho, eps = group['rho'], group['eps']
diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py
index 70b74dbdab..7c152df1b0 100644
--- a/torch/optim/adagrad.py
+++ b/torch/optim/adagrad.py
@@ -1,5 +1,4 @@
 import torch
-
 from .optimizer import Optimizer
@@ -28,7 +27,7 @@ class Adagrad(Optimizer):
             for p in group['params']:
                 state = self.state[p]
                 state['step'] = 0
-                state['sum'] = p.data.new().resize_as_(p.data).zero_()
+                state['sum'] = torch.zeros_like(p.data)

     def share_memory(self):
         for group in self.param_groups:
@@ -59,21 +58,21 @@ class Adagrad(Optimizer):

                 if group['weight_decay'] != 0:
                     if p.grad.data.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients ")
+                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                     grad = grad.add(group['weight_decay'], p.data)

                 clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay'])

-                if p.grad.data.is_sparse:
+                if grad.is_sparse:
                     grad = grad.coalesce()  # the update is non-linear so indices must be unique
                     grad_indices = grad._indices()
                     grad_values = grad._values()
-                    size = torch.Size([x for x in grad.size()])
+                    size = grad.size()

                     def make_sparse(values):
-                        constructor = type(p.grad.data)
+                        constructor = grad.new
                         if grad_indices.dim() == 0 or values.dim() == 0:
-                            return constructor()
+                            return constructor().resize_as_(grad)
                         return constructor(grad_indices, values, size)
                     state['sum'].add_(make_sparse(grad_values.pow(2)))
                     std = state['sum']._sparse_mask(grad)
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 817f2aade0..e600839d20 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -1,4 +1,5 @@
 import math
+import torch
 from .optimizer import Optimizer
@@ -43,15 +44,18 @@ class Adam(Optimizer):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+
                 state = self.state[p]

                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
+                    state['exp_avg'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
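The Adagrad hunk above relies on coalescing before the element-wise square. The sketch below, which is not part of the patch, shows why: duplicate indices in a sparse gradient must be summed before a non-linear operation such as pow(2). It uses the current torch.sparse_coo_tensor constructor for brevity, whereas the diff itself builds sparse tensors through grad.new(indices, values, size).

```python
import torch

# A sparse "gradient" in which index 1 appears twice.
i = torch.tensor([[1, 1, 3]])
v = torch.tensor([2.0, 3.0, 4.0])
g = torch.sparse_coo_tensor(i, v, (5,))

# Squaring first and letting to_dense() sum the duplicates gives the wrong value.
wrong = torch.sparse_coo_tensor(i, v.pow(2), (5,)).to_dense()

# Coalescing first merges the duplicates (index 1 becomes 5.0), then squaring is correct.
gc = g.coalesce()
right = torch.sparse_coo_tensor(gc._indices(), gc._values().pow(2), (5,)).to_dense()

print(wrong[1])   # 13. = 2**2 + 3**2
print(right[1])   # 25. = (2 + 3)**2, what Adagrad's accumulator needs
```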
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index 1cf4a57194..dfec5ac371 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -41,13 +41,15 @@ class Adamax(Optimizer):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adamax does not support sparse gradients')
                 state = self.state[p]

                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
-                    state['exp_inf'] = grad.new().resize_as_(grad).zero_()
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_inf'] = torch.zeros_like(p.data)

                 exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
                 beta1, beta2 = group['betas']
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index f37aaccfbb..f72d1b20c6 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -1,4 +1,5 @@
 import math
+import torch
 from .optimizer import Optimizer
@@ -42,6 +43,8 @@ class ASGD(Optimizer):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('ASGD does not support sparse gradients')
                 state = self.state[p]

                 # State initialization
@@ -49,7 +52,7 @@
                     state['step'] = 0
                     state['eta'] = group['lr']
                     state['mu'] = 1
-                    state['ax'] = grad.new().resize_as_(grad).zero_()
+                    state['ax'] = torch.zeros_like(p.data)

                 state['step'] += 1
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index 94f8d6b6fd..5f760fd624 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -1,3 +1,4 @@
+import torch
 from .optimizer import Optimizer
@@ -50,16 +51,18 @@ class RMSprop(Optimizer):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('RMSprop does not support sparse gradients')
                 state = self.state[p]

                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = grad.new().resize_as_(grad).zero_()
+                    state['square_avg'] = torch.zeros_like(p.data)
                     if group['momentum'] > 0:
-                        state['momentum_buffer'] = grad.new().resize_as_(grad).zero_()
+                        state['momentum_buffer'] = torch.zeros_like(p.data)
                     if group['centered']:
-                        state['grad_avg'] = grad.new().resize_as_(grad).zero_()
+                        state['grad_avg'] = torch.zeros_like(p.data)

                 square_avg = state['square_avg']
                 alpha = group['alpha']
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py
index e01a63348c..86705e6ad4 100644
--- a/torch/optim/rprop.py
+++ b/torch/optim/rprop.py
@@ -1,4 +1,5 @@
 import math
+import torch
 from .optimizer import Optimizer
@@ -36,12 +37,14 @@ class Rprop(Optimizer):
                 if p.grad is None:
                     continue
                 grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Rprop does not support sparse gradients')
                 state = self.state[p]

                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['prev'] = grad.new().resize_as_(grad).zero_()
+                    state['prev'] = torch.zeros_like(p.data)
                     state['step_size'] = grad.new().resize_as_(grad).fill_(group['lr'])

                 etaminus, etaplus = group['etas']
diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py
index f51a516a29..349b6885de 100644
--- a/torch/optim/sgd.py
+++ b/torch/optim/sgd.py
@@ -1,3 +1,4 @@
+import torch
 from .optimizer import Optimizer, required
@@ -86,7 +87,7 @@ class SGD(Optimizer):
                 if momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
-                        buf = param_state['momentum_buffer'] = p.data.new().resize_as_(p.data).zero_()
+                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                         buf.mul_(momentum).add_(d_p)
                     else:
                         buf = param_state['momentum_buffer']
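The state-initialization change repeated across these files swaps the old tensor.new().resize_as_().zero_() idiom for torch.zeros_like(p.data), which allocates a zero buffer with the parameter's own shape and type in one call. A small sketch of the equivalence, illustrative only and not taken from the patch:

```python
import torch

p_data = torch.randn(3, 4)  # stand-in for a parameter's .data

old_style = p_data.new().resize_as_(p_data).zero_()  # allocate, resize, then zero in place
new_style = torch.zeros_like(p_data)                 # same result in a single call

assert old_style.shape == new_style.shape
assert torch.equal(old_style, new_style)
```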
diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py
new file mode 100644
index 0000000000..381dc94e16
--- /dev/null
+++ b/torch/optim/sparse_adam.py
@@ -0,0 +1,95 @@
+import math
+import torch
+from .optimizer import Optimizer
+
+
+class SparseAdam(Optimizer):
+    """Implements lazy version of Adam algorithm suitable for sparse tensors.
+
+    In this variant, only moments that show up in the gradient get updated, and
+    only those portions of the gradient get applied to the parameters.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
+        defaults = dict(lr=lr, betas=betas, eps=eps)
+        super(SparseAdam, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if not grad.is_sparse:
+                    raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                state['step'] += 1
+
+                grad = grad.coalesce()  # the update is non-linear so indices must be unique
+                grad_indices = grad._indices()
+                grad_values = grad._values()
+                size = grad.size()
+
+                def make_sparse(values):
+                    constructor = grad.new
+                    if grad_indices.dim() == 0 or values.dim() == 0:
+                        return constructor().resize_as_(grad)
+                    return constructor(grad_indices, values, size)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                # Decay the first and second moment running average coefficient
+                #      old <- b * old + (1 - b) * new
+                # <==> old += (1 - b) * (new - old)
+                old_exp_avg_values = exp_avg._sparse_mask(grad)._values()
+                exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
+                exp_avg.add_(make_sparse(exp_avg_update_values))
+                old_exp_avg_sq_values = exp_avg_sq._sparse_mask(grad)._values()
+                exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
+                exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
+
+                # Dense addition again is intended, avoiding another _sparse_mask
+                numer = exp_avg_update_values.add_(old_exp_avg_values)
+                denom = exp_avg_sq_update_values.add_(old_exp_avg_sq_values).sqrt_().add_(group['eps'])
+                del exp_avg_update_values, exp_avg_sq_update_values
+
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.add_(make_sparse(-step_size * numer.div_(denom)))
+
+        return loss
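The core of the new step, and the point of the "favor dense addition over sparse_mask" commit note, is the rewritten exponential-moving-average update: instead of materializing b * old + (1 - b) * new, which would need a second _sparse_mask of the dense state, it adds the sparse correction (1 - b) * (new - old) onto the dense state. A dense sketch of that identity and of the bias-corrected step size, illustrative only and not taken from the patch:

```python
import math
import torch

beta1, beta2, eps, lr, step = 0.9, 0.999, 1e-8, 1e-3, 1

old = torch.tensor([0.5, -1.0])  # previous exp_avg values at the touched indices
new = torch.tensor([2.0, 3.0])   # current gradient values at those indices

# old <- beta1 * old + (1 - beta1) * new  is the same as  old += (1 - beta1) * (new - old)
ema_direct = beta1 * old + (1 - beta1) * new
ema_update = old + (new - old).mul(1 - beta1)
assert torch.allclose(ema_direct, ema_update)

# Bias-corrected step size, as computed at the end of SparseAdam.step()
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
print(step_size)  # 0.001 * sqrt(0.001) / 0.1 ≈ 3.16e-4
```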