summaryrefslogtreecommitdiff
path: root/third_party/nccl/src/primitives.h
blob: bcaeca8f901edd5026e27bc218f0a721f9d492d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
/*************************************************************************
 * Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#ifndef PRIMITIVES_H_
#define PRIMITIVES_H_

#include <type_traits>
#include "copy_kernel.h" // for FuncPassA
#include "reduce_kernel.h" // for reduction funcs


/* Defines primitive operations: Copy, Reduce, DoubleCopy, and ReduceCopy.
 *
 * In order to reduce the repetition of template arguments, the operations
 * are bundled as static methods of the Primitives class.
 *
 * Each primitive operation copies/reduces a contiguous buffer and syncs
 * an optional set of flags against a sub-step counter. The sync value is
 * based on the step parameter. Sync flags must be of type WaitFlag or
 * PostFlag. The primitive routines wait for all WaitFlag args to attain
 * at least a value of SUBSTEPS*(step-1)+substep+1 (i.e. completion of
 * corresponding substep by previous step) before executing the transfer.
 * After each substep is transferred, all PostFlag arguments get updated to
 * the value SUBSTEPS*step+substep+1.
 */


// Consumer-side sync flag: wait(val) spins until the flag reaches val + shift.
class WaitFlag {
  volatile int * const ptr;   // flag word read on every poll (volatile: never cached)
  const int offset;           // added to every requested wait value
  public:
  __device__ __forceinline__
  WaitFlag(volatile int * const flag, const int shift) : ptr(flag), offset(shift) { }
  __device__ __forceinline__
  void wait(int val) {
    const int target = val + offset;
    // Busy-wait until the flag value is at least the target.
    for (;;) {
      if (*ptr >= target) break;
    }
  }
};


// Producer-side sync flag: post(val) stores val + shift into the flag word.
class PostFlag {
  volatile int * const ptr;   // flag word written on every post (volatile store)
  const int offset;           // added to every posted value
  public:
  __device__ __forceinline__
  PostFlag(volatile int* const flag, const int shift) : ptr(flag), offset(shift) { }
  __device__ __forceinline__
  void post(int val) {
    const int published = val + offset;
    *ptr = published;
  }
};


// Returns true if any argument's type is exactly T. Only the argument types
// are inspected; the values are ignored, so the result is a compile-time
// constant the optimizer can fold.
// e.g. AnyAre<WaitFlag>(Flag1, Flag2, ...)
template<typename T> __device__ __forceinline__
bool AnyAre() { return false; }  // empty pack: nothing matches

template<typename T, typename FIRST_T, typename... TAIL_Ts>
__device__ __forceinline__
bool AnyAre(FIRST_T, TAIL_Ts... rest) {
  if (std::is_same<T, FIRST_T>::value) {
    return true;
  }
  return AnyAre<T>(rest...);
}


// Calls wait(val) on every WaitFlag in the argument list, in order; PostFlag
// arguments are skipped. The non-template base case must stay declared first
// so the recursive templates can find it when the pack runs out.
__device__ __forceinline__
void WaitOnFlags(int) { }

template <typename... REST_Ts> __device__ __forceinline__
void WaitOnFlags(int val, WaitFlag waiter, REST_Ts... rest) {
  waiter.wait(val);
  WaitOnFlags(val, rest...);
}

template <typename... REST_Ts> __device__ __forceinline__
void WaitOnFlags(int val, PostFlag, REST_Ts... rest) {
  // Nothing to wait for on a PostFlag; continue with the remaining flags.
  WaitOnFlags(val, rest...);
}


// Calls post(val) on every PostFlag in the argument list, in order; WaitFlag
// arguments are ignored. The non-template base case is declared first so the
// recursive templates can terminate on an empty pack.
__device__ __forceinline__
void PostToFlags(int) { }

template <typename... REST_Ts> __device__ __forceinline__
void PostToFlags(int val, WaitFlag, REST_Ts... rest) {
  // WaitFlags are never posted to; continue with the remaining flags.
  PostToFlags(val, rest...);
}

template <typename... REST_Ts> __device__ __forceinline__
void PostToFlags(int val, PostFlag poster, REST_Ts... rest) {
  poster.post(val);
  PostToFlags(val, rest...);
}


// Pointer offset that also accepts nullptr_t: a real pointer yields ptr + i,
// while nullptr_t yields nullptr. This lets "absent" optional operands go
// through the same offset code as present ones.
template <typename Tptr> __device__ __forceinline__
Tptr ptradd(Tptr ptr, int i) {
  Tptr shifted = ptr + i;
  return shifted;
}

// Preferred over the template for nullptr_t (exact non-template match).
__device__ __forceinline__
std::nullptr_t ptradd(std::nullptr_t, int) {
  return nullptr;
}


// Implementation of primitive types.
//
// Template parameters:
//   THREADS  - threads with threadIdx.x < THREADS move the data; threads at or
//              above THREADS only take part in posting PostFlags. When any
//              PostFlag is passed, the block presumably must be launched with
//              more than THREADS threads, or nobody posts (TODO confirm at
//              launch sites). THREADS is also the named-barrier count, so it
//              should be a multiple of the warp size.
//   UNROLL   - unroll factor forwarded to ReduceOrCopy.
//   SUBSTEPS - number of sub-slices each call is split into; flags are waited
//              on / posted once per sub-slice.
//   T        - element type.
//   REDOP    - reduction functor applied when a second source is present.
template <int THREADS, int UNROLL, int SUBSTEPS, typename T, typename REDOP=FuncSum<T> >
class Primitives {
  static_assert(SUBSTEPS > 0, "SUBSTEPS must be positive");

  private:
  // Shared engine behind all public primitives. It processes [0, len) in
  // SUBSTEPS sub-slices. For each sub-slice it:
  //   1. waits on all WaitFlags (if any) for value SUBSTEPS*step + sub + 1,
  //   2. copies (no src2) or reduces (src2 present) the slice into dst1 and,
  //      if present, dst2; elements at offsets >= maxoffset are skipped,
  //   3. posts SUBSTEPS*step + sub + 1 to all PostFlags (if any).
  template <typename SRC2_T, // either const T* or nullptr_t
            typename DST2_T, // either T* or nullptr_t
            typename... SYNC_Ts> // either WaitFlag or PostFlag
  static __device__ __forceinline__ void
  GenericOp(const T*     src1,
            const SRC2_T src2,
                  T*     dst1,
                  DST2_T dst2,
            int len, int maxoffset, int step, SYNC_Ts... flags) {

    enum { noSrc2 = std::is_same<SRC2_T, std::nullptr_t>::value };
    enum { noDst2 = std::is_same<DST2_T, std::nullptr_t>::value };
    static_assert(noSrc2 || std::is_same<SRC2_T, const T*>::value,
        "src2 must be of type const T* or nullptr_t");
    static_assert(noDst2 || std::is_same<DST2_T, T*>::value,
        "dst2 must be of type T* or nullptr_t");

    // Without a second source the operation is a pass-through copy.
    using OpType = typename std::conditional<noSrc2, FuncPassA<T>, REDOP>::type;

    if (threadIdx.x < THREADS) {
      const int sliceSize = len / SUBSTEPS;
      int sliceOffset = 0;
      #pragma unroll 1
      for (int sub=0; sub<SUBSTEPS; ++sub) {
        if (AnyAre<WaitFlag>(flags...)) {
          // A single thread spins on the flags, then releases the other data
          // threads through named barrier 1. Only the THREADS data threads
          // participate, so the posting thread(s) are not held here.
          if (threadIdx.x == 0) {
            WaitOnFlags(SUBSTEPS*step + sub + 1, flags...);
          }
          asm volatile ("bar.sync 1, %0;" :: "r"(THREADS));
        }
        // The last sub-slice also absorbs the len % SUBSTEPS remainder so
        // that every element in [0, len) is processed. Previously the tail
        // was silently skipped when len was not a multiple of SUBSTEPS.
        const int subLen = (sub == SUBSTEPS-1) ? (len - sliceOffset) : sliceSize;
        ReduceOrCopy
            <
             UNROLL,
             THREADS,
             OpType,
             T,
             !noDst2, // HAS_DEST1: ReduceOrCopy's second destination (dst2)
             !noSrc2  // HAS_SRC1: ReduceOrCopy's second source (src2)
            >
            (
             threadIdx.x,
             ptradd(dst1, sliceOffset),
             ptradd(dst2, sliceOffset),
             ptradd(src1, sliceOffset),
             ptradd(src2, sliceOffset),
             min(subLen, maxoffset-sliceOffset)
            );
        if (AnyAre<PostFlag>(flags...)) {
          // Pairs with the __syncthreads() in the posting branch below, so
          // posting happens only after this sub-slice has been written.
          __syncthreads();
        }
        sliceOffset += sliceSize;
      }
    } else {
      // Posting thread(s): one block-wide barrier per sub-slice, matching the
      // data threads above, then make the writes visible system-wide and post.
      for(int sub=0; sub<SUBSTEPS; ++sub) {
        if (AnyAre<PostFlag>(flags...)) {
          __syncthreads();
          __threadfence_system();
          PostToFlags(SUBSTEPS*step + sub + 1, flags...);
        }
      }
    }
  }

  public:
  // Copies src into dst for up to len elements (bounded by maxOffset),
  // syncing the given WaitFlag/PostFlag arguments once per substep.
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  Copy(const T* src, T* dst,
      int len, int maxOffset, int step, SYNC_Ts... flags) {
    GenericOp(src, nullptr, dst, nullptr, len, maxOffset, step, flags...);
  }

  // Copies src into both dst1 and dst2, with the same slicing and flag
  // semantics as Copy.
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  DoubleCopy(const T* src, T* dst1, T* dst2,
      int len, int maxOffset, int step, SYNC_Ts... flags) {
    GenericOp(src, nullptr, dst1, dst2, len, maxOffset, step, flags...);
  }

  // Combines src1 and src2 with REDOP (through ReduceOrCopy) and writes the
  // result to dst.
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  Reduce(const T* src1, const T* src2, T* dst,
      int len, int maxOffset, int step, SYNC_Ts... flags) {
    GenericOp(src1, src2, dst, nullptr, len, maxOffset, step, flags...);
  }

  // Combines src1 and src2 with REDOP and writes the result to both dst1 and
  // dst2.
  template <typename... SYNC_Ts>
  static __device__ __forceinline__ void
  ReduceCopy(const T* src1, const T* src2, T* dst1, T* dst2,
      int len, int maxOffset, int step, SYNC_Ts... flags) {
    GenericOp(src1, src2, dst1, dst2, len, maxOffset, step, flags...);
  }
};

#endif // end include guard