summaryrefslogtreecommitdiff
path: root/caffe2/image/transform_gpu.cu
blob: bb557429f5ad6229e66824b3a2d4070ccfdbf9ea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include "caffe2/core/context_gpu.h"
#include "caffe2/image/transform_gpu.h"
#include "caffe2/utils/conversions.h"

/**
 *
 * Copyright (c) 2016, NVIDIA CORPORATION, All rights reserved
 * Distributed under 2-clause BSD license; see accompanying LICENSE file
 *
 **/

namespace caffe2 {

namespace {

// Per-image layout/dtype conversion with per-channel normalization:
//   input:  In  (instantiated with uint8_t in this file), NHWC
//   output: Out (instantiated with float / float16 in this file), NCHW
//   out[n][c][h][w] = (in[n][h][w][c] - mean[c]) * std[c]
//
// Launch layout: one block per image (gridDim.x is expected to be N; extra
// blocks exit immediately). The 2D block (blockDim.x, blockDim.y) tiles the
// W x H plane with strided loops, so any block shape is valid. No shared
// memory is used. `mean` and `std` must be device-readable arrays of at
// least C floats.
//
// NOTE(review): `std` is applied as a multiplier, so callers presumably pass
// the reciprocal of the per-channel standard deviation — confirm at the
// call site.
template <typename In, typename Out>
__global__ void transform_kernel(
    const int N,
    const int C,
    const int H,
    const int W,
    const float* mean,
    const float* std,
    const In* in,
    Out* out) {
  const int n = blockIdx.x;
  if (n >= N) {
    return;
  }

  // Elements per image. A single image is assumed to have fewer than 2^31
  // elements; the batch offset is computed in 64-bit so that large batches
  // (N * C * H * W >= 2^31) do not overflow `int`.
  const int nStride = C * H * W;
  const size_t imageOffset = static_cast<size_t>(n) * static_cast<size_t>(nStride);

  // pointers to data for this image
  const In* input_ptr = in + imageOffset;
  Out* output_ptr = out + imageOffset;

  // Either the reads or the writes must be uncoalesced. Threads in x walk w,
  // so writes (CHW, contiguous in w) are coalesced and reads are strided by C.
  for (int c = 0; c < C; ++c) {
    // Per-channel constants are invariant over the H x W plane; load them once
    // instead of on every element (the compiler cannot hoist them itself
    // because `output_ptr` might alias `mean`/`std`).
    const float channel_mean = mean[c];
    const float channel_scale = std[c];
    for (int h = threadIdx.y; h < H; h += blockDim.y) {
      for (int w = threadIdx.x; w < W; w += blockDim.x) {
        const int in_idx = c + C * w + C * W * h;   // HWC
        const int out_idx = c * H * W + h * W + w;  // CHW

        output_ptr[out_idx] = convert::To<float, Out>(
            (convert::To<In, float>(input_ptr[in_idx]) - channel_mean) *
            channel_scale);
      }
    }
  }
}

}  // namespace

// Converts a batch of images in X, laid out NHWC with element type T_IN, into
// Y, laid out NCHW with element type T_OUT, computing
//   Y[n][c][h][w] = (X[n][h][w][c] - mean[c]) * std[c].
//
// X is read as 4-D (N, H, W, C). mean and std are read as float arrays indexed
// by channel inside the kernel, so they must be device tensors with at least
// C elements. Y is resized to (N, C, H, W).
//
// The kernel is enqueued asynchronously on context->cuda_stream(); this
// function does not synchronize. Launch-configuration errors are raised via
// CUDA_ENFORCE. Returns true.
template <typename T_IN, typename T_OUT, class Context>
bool TransformOnGPU(
    Tensor& X,
    Tensor* Y,
    Tensor& mean,
    Tensor& std,
    Context* context) {
  // data comes in as NHWC
  const int N = X.dim32(0), C = X.dim32(3), H = X.dim32(1), W = X.dim32(2);
  // data goes out as NCHW
  Y->Resize(std::vector<int>{N, C, H, W});

  // Fetch the typed pointers before any early exit so dtype mismatches are
  // still reported for empty batches.
  auto* input_data = X.template data<T_IN>();
  auto* output_data = Y->template mutable_data<T_OUT>();

  // An empty batch has nothing to convert, and launching a grid of 0 blocks
  // is an invalid configuration whose error would surface later, at an
  // unrelated CUDA call.
  if (N == 0) {
    return true;
  }

  // One block per image; the 16x16 block strides over each W x H plane.
  transform_kernel<T_IN, T_OUT>
      <<<N, dim3(16, 16), 0, context->cuda_stream()>>>(
          N, C, H, W, mean.template data<float>(), std.template data<float>(),
          input_data, output_data);
  // Kernel launches do not return an error code; check for launch failures
  // here so they are attributed to this operation.
  CUDA_ENFORCE(cudaGetLastError());
  return true;
}

// Explicit instantiations of TransformOnGPU for the element-type pairs that
// are supported on CUDAContext: uint8_t input converted to float or float16.
template bool TransformOnGPU<uint8_t, float, CUDAContext>(
    Tensor&, Tensor*, Tensor&, Tensor&, CUDAContext*);

template bool TransformOnGPU<uint8_t, float16, CUDAContext>(
    Tensor&, Tensor*, Tensor&, Tensor&, CUDAContext*);

}  // namespace caffe2