summaryrefslogtreecommitdiff
path: root/src/caffe/util/math_functions.cpp
blob: c3c0a69ccbf386a0fc3bf3b548842b46e402bebe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
// Copyright 2013 Yangqing Jia
// Copyright 2014 kloudkl@github

//#include <mkl.h>
#include <eigen3/Eigen/Dense>
#include <boost/random.hpp>

#include <cublas_v2.h>
#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

// Alignment promise made to Eigen when wrapping raw buffers in Eigen::Map.
// The functions below receive plain pointers whose alignment is not
// established anywhere in this file; a caller may pass an interior pointer
// (a buffer base plus an element offset) that is not 16-byte aligned.
// Declaring the maps Eigen::Aligned would let Eigen assert (debug) or issue
// aligned SIMD loads/stores (release) on such pointers, so the maps are
// declared Unaligned, which is valid for any pointer.
const int data_alignment = Eigen::Unaligned;
typedef Eigen::Map<const Eigen::VectorXf, data_alignment> const_map_vector_float_t;
typedef Eigen::Map<Eigen::VectorXf, data_alignment> map_vector_float_t;
typedef Eigen::Map<const Eigen::VectorXd, data_alignment> const_map_vector_double_t;
typedef Eigen::Map<Eigen::VectorXd, data_alignment> map_vector_double_t;

template<>
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  // Row-major C = alpha * op(A) * op(B) + beta * C, where op(A) is M x K,
  // op(B) is K x N and C is M x N. Leading dimensions are the row lengths of
  // A and B as they are stored (i.e. before op() is applied).
  const bool a_is_plain = (TransA == CblasNoTrans);
  const bool b_is_plain = (TransB == CblasNoTrans);
  const int a_row_len = a_is_plain ? K : M;
  const int b_row_len = b_is_plain ? N : K;
  const int c_row_len = N;
  cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha,
      A, a_row_len, B, b_row_len, beta, C, c_row_len);
}

template<>
void caffe_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C) {
  // Row-major C = alpha * op(A) * op(B) + beta * C, where op(A) is M x K,
  // op(B) is K x N and C is M x N. Leading dimensions are the row lengths of
  // A and B as they are stored (i.e. before op() is applied).
  const bool a_is_plain = (TransA == CblasNoTrans);
  const bool b_is_plain = (TransB == CblasNoTrans);
  const int a_row_len = a_is_plain ? K : M;
  const int b_row_len = b_is_plain ? N : K;
  const int c_row_len = N;
  cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha,
      A, a_row_len, B, b_row_len, beta, C, c_row_len);
}

template <>
void caffe_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C) {
  // Row-major C = alpha * op(A) * op(B) + beta * C computed with cuBLAS.
  // cuBLAS is column-major, and in that view the row-major buffer C is C^T.
  // Because C^T = op(B)^T * op(A)^T, the call below swaps the operands and
  // the M/N extents instead of transposing any data.
  const bool a_is_plain = (TransA == CblasNoTrans);
  const bool b_is_plain = (TransB == CblasNoTrans);
  const int a_row_len = a_is_plain ? K : M;
  const int b_row_len = b_is_plain ? N : K;
  const cublasOperation_t a_op = a_is_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t b_op = b_is_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  CUBLAS_CHECK(cublasSgemm(Caffe::cublas_handle(), b_op, a_op,
      N, M, K, &alpha, B, b_row_len, A, a_row_len, &beta, C, N));
}

template <>
void caffe_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C) {
  // Row-major C = alpha * op(A) * op(B) + beta * C computed with cuBLAS.
  // cuBLAS is column-major, and in that view the row-major buffer C is C^T.
  // Because C^T = op(B)^T * op(A)^T, the call below swaps the operands and
  // the M/N extents instead of transposing any data.
  const bool a_is_plain = (TransA == CblasNoTrans);
  const bool b_is_plain = (TransB == CblasNoTrans);
  const int a_row_len = a_is_plain ? K : M;
  const int b_row_len = b_is_plain ? N : K;
  const cublasOperation_t a_op = a_is_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t b_op = b_is_plain ? CUBLAS_OP_N : CUBLAS_OP_T;
  CUBLAS_CHECK(cublasDgemm(Caffe::cublas_handle(), b_op, a_op,
      N, M, K, &alpha, B, b_row_len, A, a_row_len, &beta, C, N));
}

template <>
void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
    const int N, const float alpha, const float* A, const float* x,
    const float beta, float* y) {
  // y = alpha * op(A) * x + beta * y, with A stored row-major as M x N (so
  // its leading dimension is the row length N) and unit vector strides.
  const int row_length = N;
  const int unit_stride = 1;
  cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, row_length,
      x, unit_stride, beta, y, unit_stride);
}

template <>
void caffe_cpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
    const int N, const double alpha, const double* A, const double* x,
    const double beta, double* y) {
  // y = alpha * op(A) * x + beta * y, with A stored row-major as M x N (so
  // its leading dimension is the row length N) and unit vector strides.
  const int row_length = N;
  const int unit_stride = 1;
  cblas_dgemv(CblasRowMajor, TransA, M, N, alpha, A, row_length,
      x, unit_stride, beta, y, unit_stride);
}

template <>
void caffe_gpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
    const int N, const float alpha, const float* A, const float* x,
    const float beta, float* y) {
  // y = alpha * op(A) * x + beta * y for a row-major M x N matrix A.
  // In cuBLAS's column-major view the same buffer is the N x M matrix A^T,
  // so the requested transposition is inverted before the call.
  cublasOperation_t column_major_op = CUBLAS_OP_N;
  if (TransA == CblasNoTrans) {
    column_major_op = CUBLAS_OP_T;
  }
  CUBLAS_CHECK(cublasSgemv(Caffe::cublas_handle(), column_major_op, N, M,
      &alpha, A, N, x, 1, &beta, y, 1));
}

template <>
void caffe_gpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
    const int N, const double alpha, const double* A, const double* x,
    const double beta, double* y) {
  // y = alpha * op(A) * x + beta * y for a row-major M x N matrix A.
  // In cuBLAS's column-major view the same buffer is the N x M matrix A^T,
  // so the requested transposition is inverted before the call.
  cublasOperation_t column_major_op = CUBLAS_OP_N;
  if (TransA == CblasNoTrans) {
    column_major_op = CUBLAS_OP_T;
  }
  CUBLAS_CHECK(cublasDgemv(Caffe::cublas_handle(), column_major_op, N, M,
      &alpha, A, N, x, 1, &beta, y, 1));
}

template <>
void caffe_axpy<float>(const int N, const float alpha, const float* X,
    float* Y) {
  // Y += alpha * X over N contiguous elements (BLAS saxpy).
  const int unit_stride = 1;
  cblas_saxpy(N, alpha, X, unit_stride, Y, unit_stride);
}

template <>
void caffe_axpy<double>(const int N, const double alpha, const double* X,
    double* Y) {
  // Y += alpha * X over N contiguous elements (BLAS daxpy).
  const int unit_stride = 1;
  cblas_daxpy(N, alpha, X, unit_stride, Y, unit_stride);
}


template <>
void caffe_gpu_axpy<float>(const int N, const float alpha, const float* X,
    float* Y) {
  // Y += alpha * X over N elements via cuBLAS; alpha is handed over as the
  // address of this host-side argument.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasSaxpy(Caffe::cublas_handle(), N, &alpha,
      X, unit_stride, Y, unit_stride));
}

template <>
void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
    double* Y) {
  // Y += alpha * X over N elements via cuBLAS; alpha is handed over as the
  // address of this host-side argument.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha,
      X, unit_stride, Y, unit_stride));
}

template <>
void caffe_axpby<float>(const int N, const float alpha, const float* X,
    const float beta, float* Y) {
  // Y := alpha * X + beta * Y over N elements (Eigen stand-in for MKL's
  // cblas_saxpby).
  //
  // Evaluated as one coefficient-wise expression, so Y is read and written
  // in a single pass. The result is also correct when X and Y refer to the
  // same buffer: the earlier "Y *= beta; Y += alpha * X" sequence rescaled
  // X as a side effect and produced (beta + alpha * beta) * Y in that case.
  map_vector_float_t(Y, N) =
      alpha * const_map_vector_float_t(X, N) + beta * map_vector_float_t(Y, N);
}

template <>
void caffe_axpby<double>(const int N, const double alpha, const double* X,
    const double beta, double* Y) {
  // Y := alpha * X + beta * Y over N elements (Eigen stand-in for MKL's
  // cblas_daxpby).
  //
  // Evaluated as one coefficient-wise expression, so Y is read and written
  // in a single pass. The result is also correct when X and Y refer to the
  // same buffer: the earlier "Y *= beta; Y += alpha * X" sequence rescaled
  // X as a side effect and produced (beta + alpha * beta) * Y in that case.
  map_vector_double_t(Y, N) =
      alpha * const_map_vector_double_t(X, N) + beta * map_vector_double_t(Y, N);
}

template <>
void caffe_copy<float>(const int N, const float* X, float* Y) {
  // Copies N contiguous elements from X into Y (BLAS scopy).
  const int unit_stride = 1;
  cblas_scopy(N, X, unit_stride, Y, unit_stride);
}

template <>
void caffe_copy<double>(const int N, const double* X, double* Y) {
  // Copies N contiguous elements from X into Y (BLAS dcopy).
  const int unit_stride = 1;
  cblas_dcopy(N, X, unit_stride, Y, unit_stride);
}

template <>
void caffe_gpu_copy<float>(const int N, const float* X, float* Y) {
  // Copies N elements from X into Y via cuBLAS with unit strides.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), N,
      X, unit_stride, Y, unit_stride));
}

template <>
void caffe_gpu_copy<double>(const int N, const double* X, double* Y) {
  // Copies N elements from X into Y via cuBLAS with unit strides.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), N,
      X, unit_stride, Y, unit_stride));
}

template <>
void caffe_scal<float>(const int N, const float alpha, float *X) {
  // In-place X *= alpha over N contiguous elements (BLAS sscal).
  const int unit_stride = 1;
  cblas_sscal(N, alpha, X, unit_stride);
}

template <>
void caffe_scal<double>(const int N, const double alpha, double *X) {
  // In-place X *= alpha over N contiguous elements (BLAS dscal).
  const int unit_stride = 1;
  cblas_dscal(N, alpha, X, unit_stride);
}

template <>
void caffe_gpu_scal<float>(const int N, const float alpha, float *X) {
  // In-place X *= alpha over N elements via cuBLAS; alpha is handed over as
  // the address of this host-side argument.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X,
      unit_stride));
}

template <>
void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
  // In-place X *= alpha over N elements via cuBLAS; alpha is handed over as
  // the address of this host-side argument.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X,
      unit_stride));
}

template <>
void caffe_gpu_axpby<float>(const int N, const float alpha, const float* X,
    const float beta, float* Y) {
  // Y := alpha * X + beta * Y, done as two cuBLAS passes: first scale Y by
  // beta in place, then accumulate alpha * X into it.
  // NOTE(review): because Y is scaled before X is read, the result is not
  // alpha * X + beta * Y when X and Y refer to the same buffer.
  caffe_gpu_scal(N, beta, Y);
  caffe_gpu_axpy(N, alpha, X, Y);
}

template <>
void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
    const double beta, double* Y) {
  // Y := alpha * X + beta * Y, done as two cuBLAS passes: first scale Y by
  // beta in place, then accumulate alpha * X into it.
  // NOTE(review): because Y is scaled before X is read, the result is not
  // alpha * X + beta * Y when X and Y refer to the same buffer.
  caffe_gpu_scal(N, beta, Y);
  caffe_gpu_axpy(N, alpha, X, Y);
}

template <>
void caffe_sqr<float>(const int n, const float* a, float* y) {
  // Element-wise square: y[i] = a[i] * a[i] for i in [0, n). This replaces
  // MKL's vsSqr, which squares its input; the earlier Eigen port called
  // .sqrt() and therefore computed square roots instead.
  map_vector_float_t(y, n) = const_map_vector_float_t(a, n).array().square();
}

template <>
void caffe_sqr<double>(const int n, const double* a, double* y) {
  // Element-wise square: y[i] = a[i] * a[i] for i in [0, n). This replaces
  // MKL's vdSqr, which squares its input; the earlier Eigen port called
  // .sqrt() and therefore computed square roots instead.
  map_vector_double_t(y, n) = const_map_vector_double_t(a, n).array().square();
}

template <>
void caffe_add<float>(const int n, const float* a, const float* b,
    float* y) {
  // Element-wise sum y[i] = a[i] + b[i] (Eigen stand-in for MKL vsAdd).
  const const_map_vector_float_t lhs(a, n);
  const const_map_vector_float_t rhs(b, n);
  map_vector_float_t dst(y, n);
  dst = lhs + rhs;
}

template <>
void caffe_add<double>(const int n, const double* a, const double* b,
    double* y) {
  // Element-wise sum y[i] = a[i] + b[i] (Eigen stand-in for MKL vdAdd).
  const const_map_vector_double_t lhs(a, n);
  const const_map_vector_double_t rhs(b, n);
  map_vector_double_t dst(y, n);
  dst = lhs + rhs;
}

template <>
void caffe_sub<float>(const int n, const float* a, const float* b,
    float* y) {
  // Element-wise difference y[i] = a[i] - b[i] (Eigen stand-in for vsSub).
  const const_map_vector_float_t minuend(a, n);
  const const_map_vector_float_t subtrahend(b, n);
  map_vector_float_t dst(y, n);
  dst = minuend - subtrahend;
}

template <>
void caffe_sub<double>(const int n, const double* a, const double* b,
    double* y) {
  // Element-wise difference y[i] = a[i] - b[i] (Eigen stand-in for vdSub).
  const const_map_vector_double_t minuend(a, n);
  const const_map_vector_double_t subtrahend(b, n);
  map_vector_double_t dst(y, n);
  dst = minuend - subtrahend;
}

template <>
void caffe_mul<float>(const int n, const float* a, const float* b,
    float* y) {
  // Element-wise (Hadamard) product y[i] = a[i] * b[i] (stand-in for vsMul).
  // .array() switches Eigen from matrix algebra to coefficient-wise ops.
  const const_map_vector_float_t lhs(a, n);
  const const_map_vector_float_t rhs(b, n);
  map_vector_float_t dst(y, n);
  dst = lhs.array() * rhs.array();
}

template <>
void caffe_mul<double>(const int n, const double* a, const double* b,
    double* y) {
  // Element-wise (Hadamard) product y[i] = a[i] * b[i] (stand-in for vdMul).
  // .array() switches Eigen from matrix algebra to coefficient-wise ops.
  const const_map_vector_double_t lhs(a, n);
  const const_map_vector_double_t rhs(b, n);
  map_vector_double_t dst(y, n);
  dst = lhs.array() * rhs.array();
}

template <>
void caffe_div<float>(const int n, const float* a, const float* b,
    float* y) {
  // Element-wise quotient y[i] = a[i] / b[i] (stand-in for MKL vsDiv).
  const const_map_vector_float_t numer(a, n);
  const const_map_vector_float_t denom(b, n);
  map_vector_float_t dst(y, n);
  dst = numer.array() / denom.array();
}

template <>
void caffe_div<double>(const int n, const double* a, const double* b,
    double* y) {
  // Element-wise quotient y[i] = a[i] / b[i] (stand-in for MKL vdDiv).
  const const_map_vector_double_t numer(a, n);
  const const_map_vector_double_t denom(b, n);
  map_vector_double_t dst(y, n);
  dst = numer.array() / denom.array();
}

template <>
void caffe_powx<float>(const int n, const float* a, const float b,
    float* y) {
  // Element-wise power with a scalar exponent: y[i] = a[i]^b (vsPowx).
  const const_map_vector_float_t base(a, n);
  map_vector_float_t dst(y, n);
  dst = base.array().pow(b);
}

template <>
void caffe_powx<double>(const int n, const double* a, const double b,
    double* y) {
  // Element-wise power with a scalar exponent: y[i] = a[i]^b (vdPowx).
  const const_map_vector_double_t base(a, n);
  map_vector_double_t dst(y, n);
  dst = base.array().pow(b);
}

template <>
void caffe_vRngUniform<float>(const int n, float* r,
    const float a, const float b) {
  // Fills r[0..n) with uniform samples between a and b drawn from the shared
  // Caffe generator. This is the Boost replacement for MKL's vsRngUniform;
  // Boost.Random documents the range as half-open [a, b), as MKL does.
  boost::uniform_real<float> uniform(a, b);
  Caffe::random_generator_t &engine = Caffe::vsl_stream();

  int idx = 0;
  while (idx < n) {
    r[idx] = uniform(engine);
    ++idx;
  }
}

template <>
void caffe_vRngUniform<double>(const int n, double* r,
    const double a, const double b) {
  // Fills r[0..n) with uniform samples between a and b drawn from the shared
  // Caffe generator. This is the Boost replacement for MKL's vdRngUniform;
  // Boost.Random documents the range as half-open [a, b), as MKL does.
  boost::uniform_real<double> uniform(a, b);
  Caffe::random_generator_t &engine = Caffe::vsl_stream();

  int idx = 0;
  while (idx < n) {
    r[idx] = uniform(engine);
    ++idx;
  }
}

template <>
void caffe_vRngGaussian<float>(const int n, float* r, const float a,
    const float sigma) {
  // Fills r[0..n) with normal samples of mean a and standard deviation sigma
  // from the shared Caffe generator. Boost's normal_distribution takes the
  // same (mean, standard deviation) pair as MKL's vsRngGaussian did, though
  // its sampling algorithm (and hence the exact sequence) may differ from
  // MKL's Box-Muller method.
  boost::normal_distribution<float> gaussian(a, sigma);
  Caffe::random_generator_t &engine = Caffe::vsl_stream();

  int idx = 0;
  while (idx < n) {
    r[idx] = gaussian(engine);
    ++idx;
  }
}


template <>
void caffe_vRngGaussian<double>(const int n, double* r, const double a,
    const double sigma) {
  // Fills r[0..n) with normal samples of mean a and standard deviation sigma
  // from the shared Caffe generator. Boost's normal_distribution takes the
  // same (mean, standard deviation) pair as MKL's vdRngGaussian did, though
  // its sampling algorithm (and hence the exact sequence) may differ from
  // MKL's Box-Muller method.
  boost::normal_distribution<double> gaussian(a, sigma);
  Caffe::random_generator_t &engine = Caffe::vsl_stream();

  int idx = 0;
  while (idx < n) {
    r[idx] = gaussian(engine);
    ++idx;
  }
}


template <typename Dtype>
void caffe_vRngBernoulli(const int n, Dtype* r, const double p)
{
    // FIXME check if parameters are handled in the same way ?
    boost::bernoulli_distribution<Dtype> random_distribution(p);
    Caffe::random_generator_t &generator = Caffe::vsl_stream();

    for(int i = 0; i < n; i += 1)
    {
        r[i] = random_distribution(generator);
    }
}

template void caffe_vRngBernoulli<int>(const int n, int* r, const double p);


template <>
void caffe_exp<float>(const int n, const float* a, float* y) {
  // Element-wise exponential y[i] = exp(a[i]) (Eigen stand-in for vsExp).
  const const_map_vector_float_t src(a, n);
  map_vector_float_t dst(y, n);
  dst = src.array().exp();
}

template <>
void caffe_exp<double>(const int n, const double* a, double* y) {
  // Element-wise exponential y[i] = exp(a[i]) (Eigen stand-in for vdExp).
  const const_map_vector_double_t src(a, n);
  map_vector_double_t dst(y, n);
  dst = src.array().exp();
}

template <>
float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
  // Inner product sum_i x[i] * y[i] over n contiguous elements (BLAS sdot).
  const int unit_stride = 1;
  const float result = cblas_sdot(n, x, unit_stride, y, unit_stride);
  return result;
}

template <>
double caffe_cpu_dot<double>(const int n, const double* x, const double* y) {
  // Inner product sum_i x[i] * y[i] over n contiguous elements (BLAS ddot).
  const int unit_stride = 1;
  const double result = cblas_ddot(n, x, unit_stride, y, unit_stride);
  return result;
}

template <>
void caffe_gpu_dot<float>(const int n, const float* x, const float* y,
    float* out) {
  // Inner product of n elements via cuBLAS, written through `out`.
  // NOTE(review): whether `out` must be host or device memory depends on
  // the handle's pointer mode, which is configured outside this function.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasSdot(Caffe::cublas_handle(), n,
      x, unit_stride, y, unit_stride, out));
}

template <>
void caffe_gpu_dot<double>(const int n, const double* x, const double* y,
    double * out) {
  // Inner product of n elements via cuBLAS, written through `out`.
  // NOTE(review): whether `out` must be host or device memory depends on
  // the handle's pointer mode, which is configured outside this function.
  const int unit_stride = 1;
  CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n,
      x, unit_stride, y, unit_stride, out));
}

template <>
int caffe_hamming_distance<float>(const int n, const float* x,
                                  const float* y) {
  // Sums, over all n pairs, the number of differing bits between the 32-bit
  // unsigned integer conversions of x[i] and y[i]. The conversion is by
  // value (truncation), not a bit reinterpretation; entries are presumably
  // non-negative integer-valued codes, since a float whose truncated value
  // does not fit in uint32_t makes the conversion undefined.
  int total_bits = 0;
  for (int idx = 0; idx < n; ++idx) {
    const uint32_t lhs_bits = static_cast<uint32_t>(x[idx]);
    const uint32_t rhs_bits = static_cast<uint32_t>(y[idx]);
    total_bits += __builtin_popcount(lhs_bits ^ rhs_bits);
  }
  return total_bits;
}

template <>
int caffe_hamming_distance<double>(const int n, const double* x,
                                   const double* y) {
  // Sums, over all n pairs, the number of differing bits between the 64-bit
  // unsigned integer conversions of x[i] and y[i]. The conversion is by
  // value (truncation), not a bit reinterpretation; entries are presumably
  // non-negative integer-valued codes, since a double whose truncated value
  // does not fit in uint64_t makes the conversion undefined.
  //
  // __builtin_popcountll takes unsigned long long (at least 64 bits on every
  // target). The previous __builtin_popcountl takes unsigned long, which is
  // only 32 bits on LLP64 (64-bit Windows) and 32-bit targets and would
  // silently drop the upper half of the 64-bit XOR.
  int dist = 0;
  for (int i = 0; i < n; ++i) {
    const uint64_t diff = static_cast<uint64_t>(x[i]) ^
                          static_cast<uint64_t>(y[i]);
    dist += __builtin_popcountll(diff);
  }
  return dist;
}

}  // namespace caffe