[Kernel] Dynamic Per-Token Activation Quantization #5037
Changes from all commits:
File 1 — the int8 quantization kernel tests (hunk @@ -4,27 +4,59 @@):

 from vllm._C import ops

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
+                8193]  # Arbitrary values for testing
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
 SEEDS = [0]
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                   dtype: torch.dtype, seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
+
+    x_token_max, _ = x.max(dim=1)
+    x_token_max = x_token_max.to(dtype=torch.float32)
+    scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
+                                                      dtype=torch.float32)
+    torch_out = (x / scales).round().clamp(int8_traits.min,
+                                           int8_traits.max).to(torch.int8)
+
+    ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
+    scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
+    ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)
+
+    assert torch.allclose(scales_out, scales)
+    assert torch.allclose(torch_out, ops_out,
+                          atol=1)  # big atol to account for rounding errors
+
+
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("scale", SCALE)
 @torch.inference_mode()
-def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
-               seed: int, scale: float) -> None:
+def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                  dtype: torch.dtype, seed: int,
+                                  scale: float) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

-    out1 = (x / scale).round().clamp(
-        torch.iinfo(torch.int8).min,
-        torch.iinfo(torch.int8).max).to(torch.int8)
+    out1 = (x / scale).round().clamp(int8_traits.min,
+                                     int8_traits.max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
     scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")

Review comment (on HIDDEN_SIZES): Should we also add a larger hidden size (> 1024) that is not a nice number? I see 5120, but it is a multiple of 256.
Reply: Added hidden sizes 5137 and 8193.
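For readers skimming the diff, here is a minimal, self-contained sketch of the per-token reference math that the new test compares the kernel against. It uses plain PyTorch only (no vLLM ops); the tensor shapes and values are illustrative, not taken from the PR.

import torch

int8_traits = torch.iinfo(torch.int8)

# Toy activations: 4 "tokens" with hidden size 8 (shapes/values are illustrative).
# Like the test's inputs, these are non-negative, so max() matches abs().max().
x = torch.rand(4, 8, dtype=torch.float32) * 1000

# Dynamic per-token scale: one scale per row, chosen so each row's max maps to 127.
x_token_max, _ = x.max(dim=1)
scales = (x_token_max / 127.0)[:, None]

# Quantize: divide by the per-token scale, round, and clamp to the int8 range.
x_q = (x / scales).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)

# Round trip: per-element error is at most half a quantization step (scales / 2).
x_dq = x_q.to(torch.float32) * scales
assert torch.all((x - x_dq).abs() <= scales / 2 + 1e-3)

Because the scale is recomputed per token from the live activations, no calibration is needed; the test simply checks that the kernel reproduces these scales and the quantized values within rounding tolerance.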
File 2 — the Python wrapper around the custom int8 quantization ops (hunk @@ -264,21 +264,33 @@ def scaled_fp8_quant():

 # int8
-def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: torch.Tensor) -> torch.Tensor:
+def scaled_int8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Quantize the input tensor to int8 and return the quantized tensor.
+    Quantize the input tensor to int8 and return the quantized tensor and scale.

     Args:
         input: The input tensor to be quantized to int8.
-        scale: Scaling factor for the int8 quantization.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.

     Returns:
-        torch.Tensor: Output tensor in int8.
+        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
     """
-    q = torch.empty_like(input, dtype=torch.int8)
-    vllm_ops.static_scaled_int8_quant(q, input, scale)
-    return q
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        vllm_ops.static_scaled_int8_quant(output, input, scale)
+        return output, scale
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
+    return output, input_scales


 # moe

Review comment (on the new function body): Can we make the names of the variables used internally in this function match the ...
Reply: Renamed.
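As a usage sketch of the new single entry point: it dispatches on whether a scale is passed, doing static per-tensor quantization when it is and dynamic per-token quantization when it is not. The import path below is an assumption (the diff does not name the wrapper file), and a CUDA device plus a vLLM build exposing these ops are assumed.

import torch
from vllm import _custom_ops as ops  # assumed import path for the wrapper

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Static per-tensor quantization: pass an explicit float32 scale; the same
# scale tensor is returned alongside the int8 output.
static_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
x_q_static, scale_out = ops.scaled_int8_quant(x, static_scale)

# Dynamic per-token quantization: omit the scale; the kernel computes one
# float32 scale per token and returns it with shape (num_tokens, 1), i.e.
# (input.numel() // input.shape[-1], 1).
x_q_dynamic, token_scales = ops.scaled_int8_quant(x)
assert token_scales.shape == (x.shape[0], 1)

Returning a (tensor, scales) tuple in both branches lets callers treat the static and dynamic paths uniformly downstream.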
Review comment: Is there a common place we can put CUDA utils like this? We have the exact same helper fn in csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu.
Reply: I did some sleuthing, but can't find a good place to put it. Should we create a math_utils.cuh file? @robertgshaw2-neuralmagic @mgoin
Reply: We definitely need another refactoring for csrc/quantization ... but I don't have an out-of-box solution for this ATM.