1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
|
#include "gftrl_op.h"

#include <cmath>
#include <cstdint>
#include <limits>
namespace caffe2 {
// Computes one coordinate of the group-FTRL update.
//
// w, n, z:    this coordinate's current weight and its two accumulators.
// g:          this coordinate's gradient.
// nw, nn, nz: receive the updated weight, n and z.
// z_norm:     L2 norm of the updated z values over the coordinate's whole
//             group, supplied by the caller.
// OutputDim:  number of coordinates in the group; scales the group-lasso
//             threshold lambda1 * sqrt(OutputDim).
// params:     alphaInv, beta, lambda1 and lambda2 hyper-parameters.
//
// The outputs may share storage with the inputs: n is last read before nn
// is written, and z / w are read before nz / nw are written.
template <typename T>
inline void gftrl_compute(
    const T& w,
    const T& n,
    const T& z,
    const T& g,
    T& nw,
    T& nn,
    T& nz,
    const T& z_norm,
    const int OutputDim,
    const GFtrlParams<T>& params) {
  const auto new_n = n + g * g;
  // std::sqrt everywhere: keeps T = float on the float overload instead of
  // a silent promotion through the C library's ::sqrt(double).
  const auto sqrt_new_n = std::sqrt(new_n);
  const auto sigma = (sqrt_new_n - std::sqrt(n)) * params.alphaInv;
  nn = new_n;
  nz = z + g - sigma * w;
  // Group-lasso threshold: the weight is zeroed unless the group's z norm
  // exceeds lambda1 * sqrt(group size).
  const auto threshold = params.lambda1 * std::sqrt(OutputDim);
  if (z_norm > threshold) {
    // Shrink the updated z by the group threshold, then divide by the
    // per-coordinate learning-rate term plus the L2 penalty.
    nw = nz * (threshold / z_norm - 1) /
        ((params.beta + sqrt_new_n) * params.alphaInv + params.lambda2);
  } else {
    nw = 0.0;
  }
}
// Runs one group-FTRL step over a row-major OutputDim x InputDim matrix.
// Each column j (one input feature across all output nodes) is one group
// for the group-lasso penalty.
//
// Layout (from the indexing below): w, g and new_w hold OutputDim * InputDim
// elements, with element (i, j) at idx = i * InputDim + j. nz / new_nz
// interleave the accumulators: n at 2 * idx, z at 2 * idx + 1.
//
// new_w / new_nz may point at w / nz. Pass 1 only reads. In pass 2 each
// coordinate's inputs are consumed by gftrl_compute before its outputs are
// written. Every index belongs to exactly one column, so no later read sees
// an already-updated value.
template <typename Context, typename T>
void gftrl_update(
    int OutputDim, // # of output nodes
    int InputDim, // # of input features
    const T* w,
    const T* nz,
    const T* g,
    T* new_w,
    T* new_nz,
    const GFtrlParams<T>& params,
    Context* /*context*/) {
  // 64-bit indices: i * InputDim + j and especially 2 * idx can overflow int
  // for large matrices even when both dimensions fit in an int.
  const std::int64_t row_stride = InputDim;
  for (std::int64_t j = 0; j < InputDim; ++j) {
    // Pass 1 (read-only): L2 norm of column j's *updated* z values, using the
    // same n/z update formula as gftrl_compute.
    T z_sq_sum = 0.0;
    for (std::int64_t i = 0; i < OutputDim; ++i) {
      const std::int64_t idx = i * row_stride + j;
      const auto old_n = nz[2 * idx];
      const auto new_n = old_n + g[idx] * g[idx];
      const auto sigma =
          (std::sqrt(new_n) - std::sqrt(old_n)) * params.alphaInv;
      const auto new_z = nz[2 * idx + 1] + g[idx] - sigma * w[idx];
      z_sq_sum += new_z * new_z;
    }
    const T z_norm = std::sqrt(z_sq_sum);
    // Pass 2: write the updated weight and accumulators for column j.
    for (std::int64_t i = 0; i < OutputDim; ++i) {
      const std::int64_t idx = i * row_stride + j;
      gftrl_compute(
          w[idx],
          nz[2 * idx],
          nz[2 * idx + 1],
          g[idx],
          new_w[idx],
          new_nz[2 * idx],
          new_nz[2 * idx + 1],
          z_norm,
          OutputDim,
          params);
    }
  }
}
// Runs one group-FTRL step.
//
// Inputs: VAR (weights), N_Z (accumulators, exactly twice as many elements
// as GRAD), GRAD, and, when present, ALPHA. ALPHA must hold a single value;
// its reciprocal is stored in params_.alphaInv and persists for later runs.
// Outputs OUTPUT_VAR / OUTPUT_N_Z are resized like VAR / N_Z.
// GRAD's first dimension is the number of output nodes; the remaining
// elements per row are the input features.
template <typename T, typename Context>
bool GFtrlOp<T, Context>::RunOnDevice() {
  // run time learning rate override
  if (ALPHA < InputSize()) {
    CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued");
    params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
  }
  const auto& grad = Input(GRAD);
  CAFFE_ENFORCE_EQ(grad.numel(), Input(VAR).numel());
  CAFFE_ENFORCE_EQ(grad.numel() * 2, Input(N_Z).numel());
  Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
  Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
  // Fetch output storage up front so the outputs are typed even when the
  // early return below is taken.
  T* const new_w = Output(OUTPUT_VAR)->template mutable_data<T>();
  T* const new_nz = Output(OUTPUT_N_Z)->template mutable_data<T>();
  // Nothing to update. This also avoids dividing by grad.size(0) == 0 below.
  if (grad.numel() == 0) {
    return true;
  }
  CAFFE_ENFORCE_GE(
      grad.dim(), 1, "gradient must have at least one dimension");
  const std::int64_t output_dim = grad.size(0);
  const std::int64_t input_dim = grad.numel() / output_dim;
  // gftrl_update takes int dimensions; reject sizes that would be truncated.
  CAFFE_ENFORCE_LE(
      output_dim,
      std::numeric_limits<int>::max(),
      "too many output nodes for GFtrl");
  CAFFE_ENFORCE_LE(
      input_dim,
      std::numeric_limits<int>::max(),
      "too many input features for GFtrl");
  gftrl_update<Context>(
      static_cast<int>(output_dim), // # of output nodes
      static_cast<int>(input_dim), // # of input features
      Input(VAR).template data<T>(),
      Input(N_Z).template data<T>(),
      grad.template data<T>(),
      new_w,
      new_nz,
      params_,
      &context_);
  return true;
}
namespace {

// CPU implementation, instantiated for float.
REGISTER_CPU_OPERATOR(GFtrl, GFtrlOp<float, CPUContext>);

// Three required inputs plus one optional (the run-time alpha override read
// in RunOnDevice). Two outputs; output 0 may reuse input 0's storage and
// output 1 may reuse input 1's.
OPERATOR_SCHEMA(GFtrl)
    .NumInputs(3, 4)
    .NumOutputs(2)
    .AllowInplace({{0, 0}, {1, 1}});

// The operator applies a parameter update; no gradient is defined for it.
SHOULD_NOT_DO_GRADIENT(GFtrl);

} // namespace
} // namespace caffe2
|