diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 1a2aebd1..49329fbd 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -115,7 +115,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb( bool greater_than_u_diff[VEC_SIZE]; #ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u_diff, greater_than_u, BoolDiffOp()); + .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); #else BlockAdjacentDifference(temp_storage->block_prim.adj_diff) .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp());