summaryrefslogtreecommitdiff
path: root/torch/legacy/optim/rmsprop.py
blob: 351c8c3fe6963368f90ed55d76f9d6af6f298470 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch


def rmsprop(opfunc, x, config, state=None):
    """Perform one RMSprop update step on ``x`` in place.

    ARGS:

    - `opfunc` : a function that takes a single input (X), the point
                 of evaluation, and returns f(X) and df/dX
    - `x`      : the current point; it is updated in place
    - `config` : a dict with configuration parameters for the optimizer
                 (may be None when `state` is given; defaults then apply)
    - `config['learningRate']` : learning rate (default 1e-2)
    - `config['alpha']`        : smoothing constant (default 0.99)
    - `config['epsilon']`      : value added to the square root of `m` to
                                 avoid division by zero (default 1e-8)
    - `config['weightDecay']`  : weight decay (default 0)
    - `state`  : a dict describing the state of the optimizer; it is
                 modified on each call. Defaults to `config` when None.
    - `state['m']`   : leaky sum of squares of parameter gradients,
                       zero-initialised on the first call
    - `state['tmp']` : scratch buffer holding sqrt(m) + epsilon

    RETURN:
    - `x`     : the new x vector (the same object that was passed in)
    - `f(x)`  : the function, evaluated before the update

    RAISES:
    - ValueError if both `config` and `state` are None, because there would
      be nowhere to keep the running statistics between iterations.
    """
    # (0) get/update state
    if config is None and state is None:
        raise ValueError("rmsprop requires a dictionary to retain state between iterations")
    state = state if state is not None else config
    if config is None:
        # Only `state` was supplied: fall back to the default hyper-parameters
        # instead of failing on `None.get(...)`.
        config = {}
    lr = config.get('learningRate', 1e-2)
    alpha = config.get('alpha', 0.99)
    epsilon = config.get('epsilon', 1e-8)
    wd = config.get('weightDecay', 0)

    # (1) evaluate f(x) and df/dx
    fx, dfdx = opfunc(x)

    # (2) weight decay: dfdx += wd * x (in place on the tensor opfunc returned)
    if wd != 0:
        dfdx.add_(wd, x)

    # (3) initialize mean square values and square gradient storage
    if 'm' not in state:
        state['m'] = x.new().resize_as_(dfdx).zero_()
        state['tmp'] = x.new().resize_as_(dfdx)

    # (4) calculate new (leaky) mean squared values:
    #     m = alpha * m + (1 - alpha) * dfdx^2
    state['m'].mul_(alpha)
    state['m'].addcmul_(1.0 - alpha, dfdx, dfdx)

    # (5) perform update: x -= lr * dfdx / (sqrt(m) + epsilon)
    torch.sqrt(state['m'], out=state['tmp']).add_(epsilon)
    x.addcdiv_(-lr, dfdx, state['tmp'])

    # return x*, f(x) before optimization
    return x, fx