summaryrefslogtreecommitdiff
path: root/caffe2/operators/elu_op.cu
blob: 95f2e1ebd74aaf47894212ec7e120613e4a130ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include "caffe2/operators/elu_op.h"

#include <algorithm>
#include <functional>
#include <numeric>

#include "caffe2/core/context_gpu.h"

namespace caffe2 {

namespace {

template <typename T>
__global__ void EluCUDAKernel(const int N, const T alpha, const T* X, T* Y);

// Forward ELU, specialized for float only:
//   Y[i] = X[i]                        if X[i] >= 0 (or NaN)
//   Y[i] = alpha * (exp(X[i]) - 1)     if X[i] <  0
// Visits every index in [0, N) via CUDA_1D_KERNEL_LOOP. Each element is read
// into a register before Y[i] is written, so X == Y (in-place) is safe; for
// that reason the pointers are not marked __restrict__.
template <>
__global__ void
EluCUDAKernel<float>(const int N, const float alpha, const float* X, float* Y) {
  CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 350
    // sm_35+: fetch through the read-only data cache.
    const float x = __ldg(X + i);
#else
    const float x = X[i];
#endif
    if (x < 0) {
      Y[i] = alpha * (expf(x) - 1.0f);
    } else {
      Y[i] = x;
    }
  }
}

// Backward ELU expressed in terms of the forward output Y:
//   dX[i] = dY[i] * (Y[i] + alpha)   if Y[i] < 0
//   dX[i] = dY[i]                    otherwise
// For x < 0, d/dx[alpha * (exp(x) - 1)] = alpha * exp(x) = Y + alpha, so the
// input X is not needed. (Assumes alpha > 0, so that Y < 0 exactly when X < 0.)
// Visits every index in [0, N) via CUDA_1D_KERNEL_LOOP. Both inputs for
// element i are loaded before dX[i] is written, so aliasing dX with dY or Y
// element-for-element is safe.
template <typename T>
__global__ void EluGradientCUDAKernel(
    const int N,
    const T alpha,
    const T* dY,
    const T* Y,
    T* dX) {
  CUDA_1D_KERNEL_LOOP(i, N) {
#if __CUDA_ARCH__ >= 350
    // sm_35+: fetch both inputs through the read-only data cache.
    const T y = __ldg(Y + i);
    const T dy = __ldg(dY + i);
#else
    const T y = Y[i];
    const T dy = dY[i];
#endif
    dX[i] = (y < 0) ? dy * (y + alpha) : dy;
  }
}

} // namespace

template <>
template <typename T>
bool EluFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
  // Computes Y = ELU(X) elementwise over N elements on the context's stream,
  // using this functor's alpha. Returns true; the launch is asynchronous and is
  // not error-checked here.
  //
  // Empty input: nothing to compute. Skipping the launch avoids passing a
  // zero-block grid to CUDA (rejected as an invalid launch configuration) in
  // case CAFFE_GET_BLOCKS(0) evaluates to 0, and saves a pointless launch.
  if (N <= 0) {
    return true;
  }
  EluCUDAKernel<T>
      <<<CAFFE_GET_BLOCKS(N),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(N, alpha, X, Y);
  return true;
}

template <>
template <typename T>
bool EluGradientFunctor<CUDAContext>::Forward(
    const std::vector<int>& Y_dims,
    const std::vector<int>& /* dY_dims */,
    const T* Y,
    const T* dY,
    T* dX,
    CUDAContext* context) const {
  // Computes dX from the forward output Y and upstream gradient dY on the
  // context's stream, using this functor's alpha. Returns true; the launch is
  // asynchronous and is not error-checked here.
  //
  // Element count of Y. An empty Y_dims (scalar) yields 1. dY_dims is ignored
  // because the kernel indexes Y and dY with the same flat index.
  // NOTE(review): presumably the caller guarantees dY has Y's shape; confirm.
  // The count is an int because the kernel takes an int N.
  const int size = std::accumulate(
      Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
  // Empty tensor: nothing to compute. Skipping the launch avoids a possible
  // zero-block grid, which CUDA rejects as an invalid launch configuration.
  if (size <= 0) {
    return true;
  }
  EluGradientCUDAKernel<T>
      <<<CAFFE_GET_BLOCKS(size),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(size, alpha, dY, Y, dX);
  return true;
}

// CUDA registration of the forward op: a unary elementwise op over float
// tensors whose per-call work is EluFunctor<CUDAContext>::operator().
REGISTER_CUDA_OPERATOR(
    Elu,
    UnaryElementwiseWithArgsOp<
        TensorTypes<float>,
        CUDAContext,
        EluFunctor<CUDAContext>>);
// CUDA registration of the gradient op: a binary elementwise op over float
// tensors whose per-call work is EluGradientFunctor<CUDAContext>::Forward,
// which receives Y first and dY second.
REGISTER_CUDA_OPERATOR(
    EluGradient,
    BinaryElementwiseWithArgsOp<
        TensorTypes<float>,
        CUDAContext,
        EluGradientFunctor<CUDAContext>>);

} // namespace caffe2