Skip to content

Commit

Permalink
bugfix: suppress alignment warning of sampling kernels (#297)
Browse files Browse the repository at this point in the history
We declare multiple kernels in `sampling.cuh`, and they all use
dynamic shared memory through an extern variable of the same name but with
different alignment requirements (e.g. some are alignof 4, some are
alignof 64).

In this PR we use different names for extern variables that have
different alignment requirements to suppress the warning.
  • Loading branch information
yzh119 authored Jun 11, 2024
1 parent aff4cf0 commit 1250b68
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
temp_storage.data.sampled_id = d - 1;
__syncthreads();

Expand Down Expand Up @@ -171,10 +172,11 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
Expand Down Expand Up @@ -264,10 +266,11 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
}
const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

vec_t<DType, VEC_SIZE> probs_vec;
DType aggregate;
Expand Down Expand Up @@ -454,9 +457,9 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float
const uint32_t row_idx = bx;

extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem[];
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem);
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.data.max_val = DType(0);
vec_t<DType, VEC_SIZE> probs_vec;
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
Expand Down Expand Up @@ -543,9 +546,9 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32
const uint32_t row_idx = bx;

extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem[];
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem);
reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.data.max_val = DType(0);
vec_t<DType, VEC_SIZE> probs_vec;
DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
Expand Down Expand Up @@ -674,10 +677,11 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;

extern __shared__ __align__(alignof(
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
extern __shared__ __align__(
alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

uint32_t pos = 0;
for (pos = 0; pos < num_speculative_tokens; ++pos) {
Expand Down

0 comments on commit 1250b68

Please sign in to comment.