forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTHCReduceApplyUtils.cuh
117 lines (104 loc) · 4.08 KB
/
THCReduceApplyUtils.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#ifndef THC_REDUCE_APPLY_UTILS_INC
#define THC_REDUCE_APPLY_UTILS_INC
#include <cuda.h>
#include <assert.h>
#include <THC/THCGeneral.h>
#include <THC/THCTensor.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorInfo.cuh>
// Access mode of a tensor argument: distinguishes tensors that are both
// read and written from tensors that are only read.
enum TensorArgType {
  ReadWrite,  // the tensor argument is read and written
  ReadOnly    // the tensor argument is only read
};
// Reduce N values concurrently across the block, i.e. suppose N = 2, and
// there are 4 threads holding (1, 2), (3, 4), (5, 6), (7, 8): on return,
// threadVals for thread 0 is (1 + 3 + 5 + 7, 2 + 4 + 6 + 8) = (16, 20).
//
// Parameters:
//   smem       - scratch buffer visible to every thread of the block
//                (callers presumably pass shared memory); it is indexed up
//                to element N * numVals - 1
//   threadVals - in: the N values contributed by this thread;
//                out: on threadIdx.x == 0 only, the N reduced values. Other
//                threads must not rely on their threadVals afterwards
//                (threads of the first warp may hold partial results)
//   numVals    - number of leading threads (by threadIdx.x) that contribute.
//                It must be identical for all threads of the block because
//                it decides whether the barriers below are reached. When it
//                is 0, every thread's threadVals is set to `init` and the
//                function returns without touching smem
//   reduceOp   - binary functor combining two T values
//   init       - value used for first-warp lanes that have no input of their
//                own (expected to be the identity of reduceOp)
//
// Only threadIdx.x is used for indexing, so a 1-D thread block is assumed.
//
// If smem is not used again, there is no need to __syncthreads before this
// call. However, if smem will be used, e.g., this function is called in a
// loop, then __syncthreads is needed either before or afterwards to prevent
// non-0 threads overwriting smem in the next iteration before thread 0 has
// read from it.
template <typename T, typename ReduceOp, int N>
__device__ void reduceNValuesInBlock(T *smem,
                                     T threadVals[N],
                                     const unsigned int numVals,
                                     ReduceOp reduceOp,
                                     T init) {
  if (numVals == 0) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      threadVals[i] = init;
    }
    return;
  }

  // Stage 1: publish the inputs.
  // We store each of the N values contiguously, so if N = 2, all values for
  // the first threadVal for each thread in the block are stored followed by
  // all of the values for the second threadVal for each thread in the block
  if (threadIdx.x < numVals) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      smem[i * numVals + threadIdx.x] = threadVals[i];
    }
  }
  __syncthreads();

  // Number of lanes in the final reduction --> this is used to determine
  // where to put the outputs of each of the n things we are reducing. If
  // nLP = 32, then we have the 32 outputs for the first threadVal,
  // followed by the 32 outputs for the second threadVal, etc.
  const unsigned int numLanesParticipating = min(numVals, warpSize);

  // Stage 2: when there are more inputs than lanes in a warp, the first warp
  // folds everything beyond the first warpSize inputs into per-lane partials.
  if (numVals > warpSize && ((threadIdx.x / warpSize) == 0)) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      threadVals[i] = threadIdx.x < numVals ? threadVals[i] : init;
    }

    for (int i = warpSize + threadIdx.x; i < numVals; i += warpSize) {
#pragma unroll
      for (int j = 0; j < N; ++j) {
        threadVals[j] = reduceOp(threadVals[j], smem[j * numVals + i]);
      }
    }

    // The partials are written back with stride numLanesParticipating while
    // the inputs above were read with stride numVals, so one lane's store can
    // target a slot another lane still has to read (e.g. N = 3, numVals = 40:
    // lane 0 reads smem[72], lane 8 stores to smem[72]). The loop above has a
    // per-lane trip count, so lanes are not guaranteed to be converged here;
    // make sure every lane has finished reading before any lane writes.
#if !defined(__HIP_PLATFORM_HCC__) && !defined(USE_ROCM)
    __syncwarp();
#endif

#pragma unroll
    for (int i = 0; i < N; ++i) {
      smem[i * numLanesParticipating + threadIdx.x] = threadVals[i];
    }
  }
  __syncthreads();

  // Stage 3: thread 0 folds the remaining numLanesParticipating values per
  // slot. In both cases above the layout in smem now has a stride of
  // numLanesParticipating between consecutive slots (it equals numVals when
  // stage 2 was skipped), and thread 0 already holds element 0 of every slot
  // in its registers, hence j starts at 1.
  if (threadIdx.x == 0) {
    if (numLanesParticipating == 32) {
      // Common case with a compile-time trip count so it can fully unroll.
#pragma unroll
      for (int i = 0; i < N; ++i) {
#pragma unroll
        for (int j = 1; j < 32; ++j) {
          threadVals[i] = reduceOp(threadVals[i], smem[i * 32 + j]);
        }
      }
    } else {
#pragma unroll
      for (int i = 0; i < N; ++i) {
        for (int j = 1; j < numLanesParticipating; ++j) {
          threadVals[i] =
              reduceOp(threadVals[i], smem[i * numLanesParticipating + j]);
        }
      }
    }
  }
}
// Single-value block-wide reduction helper: forwards to
// reduceNValuesInBlock with N = 1. Only threadIdx.x == 0 returns the reduced
// value; the value returned on other threads must not be relied upon.
//
// Synchronization: if smem is not touched again, no __syncthreads is needed
// before this call. If smem is reused (e.g. this function is called in a
// loop), a __syncthreads is required either before or after the call so that
// non-0 threads cannot overwrite smem for the next iteration before thread 0
// has finished reading it.
template <typename T, typename ReduceOp>
__device__ T reduceBlock(T* smem,
                         const unsigned int numVals,
                         T threadVal,
                         ReduceOp reduceOp,
                         T init) {
  // Treat the by-value argument as a one-element array for the N-ary helper.
  T* const singleSlot = &threadVal;
  reduceNValuesInBlock<T, ReduceOp, 1>(
      smem, singleSlot, numVals, reduceOp, init);
  return *singleSlot;
}
// Make sure the given tensor doesn't have too many dimensions.
// Declaration only; the definition lives outside this header.
// NOTE(review): `arg` is presumably the argument index reported when the
// check fails -- confirm against the definition.
void THCCheckTensorDims(THCState* state, THCudaTensor* tensor, int arg);
// Produces a grid with at least one point per tile; the result is written
// to `grid`.
// Declaration only; the definition lives outside this header.
// NOTE(review): the bool return presumably signals whether a valid grid
// could be produced for `gridTiles` -- confirm against the definition.
TORCH_CUDA_CU_API bool THC_getGridFromTiles(ptrdiff_t gridTiles, dim3& grid);
#endif // THC_REDUCE_APPLY_UTILS_INC