1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
#include "caffe2/operators/swish_op.h"
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
namespace {

// Forward Swish kernel: Y[i] = X[i] * sigmoid(X[i]), evaluated as
// X[i] / (1 + exp(-X[i])), for every i in [0, N).
// The index loop is provided by CUDA_1D_KERNEL_LOOP (caffe2 macro; presumably
// a grid-stride loop -- see caffe2/core/common_gpu.h).
template <typename T>
__global__ void SwishCUDAKernel(const int N, const T* X, T* Y) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    // On SM 3.5+ the input is fetched through the read-only data cache.
#if __CUDA_ARCH__ >= 350
    const T x = __ldg(X + i);
#else
    const T x = X[i];
#endif
    Y[i] = x / (T(1) + exp(-x));
  }
}

// Backward Swish kernel. With s = sigmoid(x) and y = x * s:
//   dy/dx = s + x * s * (1 - s) = y + s * (1 - y)
// so dX = dY * (Y + (1 - Y) / (1 + exp(-X))), reusing the forward output Y
// instead of recomputing x * s.
template <typename T>
__global__ void SwishGradientCUDAKernel(
    const int N,
    const T* X,
    const T* Y,
    const T* dY,
    T* dX) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    // Gather the three per-element operands (read-only cache on SM 3.5+).
#if __CUDA_ARCH__ >= 350
    const T x = __ldg(X + i);
    const T y = __ldg(Y + i);
    const T grad_out = __ldg(dY + i);
#else
    const T x = X[i];
    const T y = Y[i];
    const T grad_out = dY[i];
#endif
    dX[i] = grad_out * (y + (T(1) - y) / (T(1) + exp(-x)));
  }
}

} // namespace
template <>
template <typename T>
bool SwishFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
  // Computes Y[i] = swish(X[i]) for N elements by enqueuing SwishCUDAKernel on
  // the context's CUDA stream (asynchronous; no synchronization here).
  // Always returns true.
  //
  // Empty input: there is nothing to compute. Skipping the launch also avoids
  // a 0-block grid, which CUDA rejects with cudaErrorInvalidConfiguration
  // (CAFFE_GET_BLOCKS(0) may evaluate to 0 depending on the caffe2 version).
  if (N <= 0) {
    return true;
  }
  SwishCUDAKernel<T>
      <<<CAFFE_GET_BLOCKS(N),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context->cuda_stream()>>>(N, X, Y);
  return true;
}
template <>
template <typename T>
bool SwishGradientOp<CUDAContext>::DoRunWithType() {
  // Swish backward for element type T:
  //   DX = DY * (Y + (1 - Y) / (1 + exp(-X)))
  // computed elementwise on this op's CUDA stream. X and Y are presumably the
  // forward input and output, and DY the incoming gradient. All three must
  // have the same number of elements; DX is resized to Y's shape.
  auto& Xin = Input(X);
  auto& Yin = Input(Y);
  auto& DYin = Input(DY);
  auto* DXout = Output(DX);
  CAFFE_ENFORCE_EQ(Xin.size(), Yin.size());
  CAFFE_ENFORCE_EQ(DYin.size(), Yin.size());
  DXout->ResizeLike(Yin);
  // NOTE(review): size() is narrowed to int because the kernel takes an int
  // element count; tensors with >= 2^31 elements are not supported here.
  const int n = Xin.size();
  const T* x = Xin.template data<T>();
  const T* y = Yin.template data<T>();
  const T* dy = DYin.template data<T>();
  // Obtain DX's typed buffer before any early return, so DX is handled the
  // same way whether or not it is empty.
  T* dx = DXout->template mutable_data<T>();
  // Empty tensors: nothing to compute. Skipping the launch also avoids a
  // 0-block grid, which CUDA rejects with cudaErrorInvalidConfiguration
  // (CAFFE_GET_BLOCKS(0) may evaluate to 0 depending on the caffe2 version).
  if (n == 0) {
    return true;
  }
  SwishGradientCUDAKernel<T>
      <<<CAFFE_GET_BLOCKS(n),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(n, x, y, dy, dx);
  return true;
}
template <>
bool SwishGradientOp<CUDAContext>::RunOnDevice() {
  // Select the DoRunWithType<T> instantiation that matches the element type
  // of input X; float and double are the types listed for CUDA.
  using SupportedTypes = TensorTypes<float, double>;
  const auto& dispatch_input = Input(X);
  return DispatchHelper<SupportedTypes>::call(this, dispatch_input);
}
namespace {
// CUDA implementation type of the forward "Swish" operator: UnaryElementwiseOp
// over float/double tensors, applying SwishFunctor<CUDAContext>.
using SwishCUDAOp = UnaryElementwiseOp<
    TensorTypes<float, double>,
    CUDAContext,
    SwishFunctor<CUDAContext>>;
} // namespace

// Register the forward op, then its gradient op, with the CUDA registry.
REGISTER_CUDA_OPERATOR(Swish, SwishCUDAOp);
REGISTER_CUDA_OPERATOR(SwishGradient, SwishGradientOp<CUDAContext>);
} // namespace caffe2
|