[Kernel] Dynamic Per-Token Activation Quantization #5037
Changes from 77 commits
File 1: int8 quantization CUDA kernels
@@ -3,6 +3,7 @@
 #include <cmath>

 #include "../../dispatch_utils.h"
+#include "../../reduction_utils.cuh"

 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM

@@ -38,6 +39,38 @@ __global__ void static_scaled_int8_quant_kernel(
         float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
   }
 }
+
+template <typename scalar_t, typename scale_type>
+__global__ void dynamic_scaled_int8_quant_kernel(
+    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
+    scale_type scale, const int hidden_size) {
+  const int tid = threadIdx.x;
+  const int token_idx = blockIdx.x;
+
+  float amax_val = 0.0f;
Review comment: nit: would it be more readable as …
Reply: Yes.
+  const float zero = 0.0f;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    float val = (float)input[token_idx * hidden_size + i];
Review comment: nit: It's best to use static_cast instead of C-style casts when possible, since they are checked by the compiler.
+    val = val > zero ? val : -val;
+    if (val > amax_val) amax_val = val;
+  }
+
+  __shared__ float s_amax;
+  const float block_amax_val = blockReduceMax(amax_val);
+  if (tid == 0) {
+    s_amax = block_amax_val;
+    scale[token_idx] = block_amax_val / 127.0f;
+  }
+  __syncthreads();
+
+  float tmp_scale = 127.0f / s_amax;
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        ((float)input[token_idx * hidden_size + i]) * tmp_scale);
+  }
+}
+
 }  // namespace vllm

 void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]

@@ -60,3 +93,22 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
             scale.data_ptr<float>(), hidden_size);
       });
 }
+
+void dynamic_scaled_int8_quant(torch::Tensor& out,     // [..., hidden_size]
+                               torch::Tensor& input,   // [..., hidden_size]
+                               torch::Tensor& scales) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
+        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float*>
+            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
+                                         out.data_ptr<int8_t>(),
+                                         scales.data_ptr<float>(), hidden_size);
+      });
+}
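For reference, the dynamic path above computes one scale per token (one row of the [..., hidden_size] input) from that row's absolute maximum and then quantizes with round-to-nearest saturation to int8. A rough PyTorch sketch of the same math, with a made-up helper name and no claim to match the kernel's launch details:

import torch

def dynamic_per_token_int8_quant_ref(x: torch.Tensor):
    # One scale per token (row), computed from the row's absolute maximum.
    int8_info = torch.iinfo(torch.int8)
    x_f32 = x.to(torch.float32)
    amax = x_f32.abs().amax(dim=-1, keepdim=True)  # per-token abs-max
    scales = amax / 127.0                          # abs-max maps to 127
    # Like the kernel, this assumes each row has at least one nonzero element.
    q = (x_f32 / scales).round().clamp(int8_info.min, int8_info.max)
    return q.to(torch.int8), scales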
File 2: reduction_utils.cuh
@@ -21,29 +21,47 @@
 #include "cuda_compat.h"

 namespace vllm {

+namespace detail {
+
+template <typename T>
+__inline__ __device__ T _max(T a, T b) {
+  return max(a, b);
+}
+
+template <typename T>
+__inline__ __device__ T _sum(T a, T b) {
+  return a + b;
+}
+
+}  // namespace detail
+
+template <typename T>
+using ReduceFnType = T (*)(T, T);
+
+// Helper function to return the next largest power of 2
+static constexpr int _nextPow2(unsigned int num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
Review comment (on _nextPow2): Is there a common place we can put CUDA utils like this? We have the exact same helper fn in …
Reply: I did some sleuthing, but can't find a good place to put it. Should we create a …?
Reply: We definitely need another refactoring for …
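For context, _nextPow2 rounds a lane or block count up to the next power of two so warpReduce can be instantiated with a valid numLanes. A quick Python equivalent of the bit trick, purely for illustration:

def next_pow2(num: int) -> int:
    # Matches _nextPow2: values <= 1 pass through; otherwise round up to the
    # next power of two using the bit length of (num - 1).
    if num <= 1:
        return num
    return 1 << (num - 1).bit_length()

assert [next_pow2(n) for n in (1, 2, 3, 17, 32, 33)] == [1, 2, 4, 32, 32, 64]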

 template <typename T, int numLanes = WARP_SIZE>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
   static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                 "numLanes is not a positive power of 2!");
   static_assert(numLanes <= WARP_SIZE);
 #pragma unroll
   for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
-    val += VLLM_SHFL_XOR_SYNC(val, mask);
-  return val;
-}
+    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));

-// Helper function to return the next largest power of 2
-static constexpr int _nextPow2(unsigned int num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+  return val;
 }

 /* Calculate the sum of all elements in a block */
 template <typename T, int maxBlockSize = 1024>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
   static_assert(maxBlockSize <= 1024);
   if constexpr (maxBlockSize > WARP_SIZE) {
-    val = warpReduceSum<T>(val);
+    val = warpReduce<T>(val, fn);
     // Calculates max number of lanes that need to participate in the last
     // warpReduce
     constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;

@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {

     val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
                                                         : (T)(0.0f);
-    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
+    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
   } else {
     // A single warpReduce is equal to blockReduce
-    val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
+    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
   }
   return val;
 }

+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceMax(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
+}
+
+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceSum(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
+}
+
 }  // namespace vllm
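The refactor keeps the XOR-shuffle butterfly and only threads the combining function through as a parameter, so blockReduceMax and blockReduceSum now share one implementation. The snippet below is a host-side Python illustration of that butterfly pattern over a list of per-lane values; it is not vLLM code and the names are invented:

def butterfly_reduce(lane_vals, fn):
    # Mirrors the XOR-shuffle loop: at each step, lane i combines its value
    # with the value held by lane (i ^ mask), halving mask until it hits zero.
    n = len(lane_vals)
    assert n > 0 and (n & (n - 1)) == 0, "lane count must be a power of two"
    vals = list(lane_vals)
    mask = n >> 1
    while mask:
        vals = [fn(vals[i], vals[i ^ mask]) for i in range(n)]
        mask >>= 1
    return vals  # every lane ends up holding the full reduction

print(butterfly_reduce(list(range(32)), max)[0])                 # 31
print(butterfly_reduce(list(range(32)), lambda a, b: a + b)[0])  # 496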
File 3: int8 quantization kernel tests (Python)
@@ -10,21 +10,52 @@
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
Review comment: Should we add a larger hidden size (> 1024) that's not a nice number as well? I see 5120, but it is a multiple of 256.
Reply: Added hidden sizes 5137 and 8193.
@pytest.mark.parametrize("dtype", DTYPES) | ||
@pytest.mark.parametrize("seed", SEEDS) | ||
@torch.inference_mode() | ||
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, | ||
dtype: torch.dtype, seed: int) -> None: | ||
torch.random.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
int8_traits = torch.iinfo(torch.int8) | ||
|
||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 | ||
|
||
x_token_max, _ = x.max(dim=1) | ||
x_token_max = x_token_max.to(dtype=torch.float32) | ||
scales = (x_token_max / float(127.0))[:, None].to(device="cuda", | ||
dtype=torch.float32) | ||
torch_out = (x / scales).round().clamp(int8_traits.min, | ||
int8_traits.max).to(torch.int8) | ||
|
||
ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") | ||
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") | ||
ops.dynamic_scaled_int8_quant(ops_out, x, scales_out) | ||
|
||
assert torch.allclose(scales_out, scales) | ||
assert torch.allclose(torch_out, ops_out, | ||
atol=1) # big atol to account for rounding errors | ||
|
||
|
||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) | ||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) | ||
@pytest.mark.parametrize("dtype", DTYPES) | ||
@pytest.mark.parametrize("seed", SEEDS) | ||
@pytest.mark.parametrize("scale", SCALE) | ||
@torch.inference_mode() | ||
def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, | ||
seed: int, scale: float) -> None: | ||
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, | ||
dtype: torch.dtype, seed: int, | ||
scale: float) -> None: | ||
torch.random.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
int8_traits = torch.iinfo(torch.int8) | ||
|
||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 | ||
|
||
out1 = (x / scale).round().clamp( | ||
torch.iinfo(torch.int8).min, | ||
torch.iinfo(torch.int8).max).to(torch.int8) | ||
out1 = (x / scale).round().clamp(int8_traits.min, | ||
int8_traits.max).to(torch.int8) | ||
out2 = torch.empty_like(x, dtype=torch.int8) | ||
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") | ||
|
||
|
File 4: Python ops wrapper (scaled_int8_quant)
@@ -264,21 +264,33 @@ def scaled_fp8_quant(


 # int8
-def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: torch.Tensor) -> torch.Tensor:
+def scaled_int8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Quantize the input tensor to int8 and return the quantized tensor.
+    Quantize the input tensor to int8 and return the quantized tensor and scale.

     Args:
         input: The input tensor to be quantized to int8.
-        scale: Scaling factor for the int8 quantization.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.

     Returns:
-        torch.Tensor: Output tensor in int8.
+        Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales.
     """
-    q = torch.empty_like(input, dtype=torch.int8)
-    vllm_ops.static_scaled_int8_quant(q, input, scale)
-    return q
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
Review comment: Can we make the names of the variables used internally in this function match the …
Reply: Renamed.
+        # static-per-tensor quantization.
+        vllm_ops.static_scaled_int8_quant(output, input, scale)
+        return output, scale
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
+    return output, input_scales


 # moe
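With this change a single entry point returns both the quantized tensor and its scales, and omitting the scale argument selects the dynamic per-token path. A brief usage sketch; the import path and the example shapes are assumptions for illustration:

import torch
from vllm import _custom_ops as ops  # assumed import path for the wrapper above

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token quantization: no scale passed; one fp32 scale per row comes back.
x_q, x_scales = ops.scaled_int8_quant(x)  # x_q: int8 [16, 4096], x_scales: fp32 [16, 1]

# Static per-tensor quantization: the provided scale is used and returned unchanged.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_q_static, scale_out = ops.scaled_int8_quant(x, scale)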