[Feature] support torchao for qwen2 models #2219

Open
tricky61 opened this issue Nov 27, 2024 · 11 comments

Labels: good first issue (Good for newcomers), help wanted (Extra attention is needed)

Comments

@tricky61

I used one A30 card with Qwen2-7B-Instruct, and the speed with quantization seems no different:

python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100
Benchmark ...
Prefill. latency: 0.03508 s, throughput: 5700.84 token/s
Decode. latency: 0.01952 s, throughput: 51.23 token/s
Decode. latency: 0.01947 s, throughput: 51.37 token/s
Decode. latency: 0.01939 s, throughput: 51.58 token/s
Decode. latency: 0.01933 s, throughput: 51.74 token/s
Decode. latency: 0.01928 s, throughput: 51.87 token/s
Decode. median latency: 0.01924 s, median throughput: 51.98 token/s
Total. latency: 1.942 s, throughput: 154.52 token/s

python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --enable-torch-compile
Benchmark ...
Prefill. latency: 0.03655 s, throughput: 5471.84 token/s
Decode. latency: 0.01852 s, throughput: 54.00 token/s
Decode. latency: 0.01847 s, throughput: 54.14 token/s
Decode. latency: 0.01845 s, throughput: 54.21 token/s
Decode. latency: 0.01843 s, throughput: 54.26 token/s
Decode. latency: 0.01838 s, throughput: 54.39 token/s
Decode. median latency: 0.01836 s, median throughput: 54.46 token/s
Total. latency: 1.855 s, throughput: 161.71 token/s

python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --enable-torch-compile --torchao-config int8wo
Benchmark ...
Prefill. latency: 0.04469 s, throughput: 4475.31 token/s
Decode. latency: 0.01860 s, throughput: 53.77 token/s
Decode. latency: 0.01849 s, throughput: 54.09 token/s
Decode. latency: 0.01844 s, throughput: 54.24 token/s
Decode. latency: 0.01841 s, throughput: 54.32 token/s
Decode. latency: 0.01837 s, throughput: 54.45 token/s
Decode. median latency: 0.01836 s, median throughput: 54.46 token/s
Total. latency: 1.863 s, throughput: 160.99 token/s

python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --enable-torch-compile --torchao-config int4wo
Benchmark ...
Prefill. latency: 0.03558 s, throughput: 5621.52 token/s
Decode. latency: 0.01855 s, throughput: 53.91 token/s
Decode. latency: 0.01852 s, throughput: 54.01 token/s
Decode. latency: 0.01845 s, throughput: 54.20 token/s
Decode. latency: 0.01842 s, throughput: 54.28 token/s
Decode. latency: 0.01841 s, throughput: 54.33 token/s
Decode. median latency: 0.01837 s, median throughput: 54.44 token/s
Total. latency: 1.855 s, throughput: 161.72 token/s

@HandH1998
Collaborator

I believe this issue is related to the implementation of TorchAO. For example, with int8wo, it is necessary to convert the int8 weights to the same data type as the activations (refer to https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py#L214), usually fp16, before performing the GEMM operation. When the batch size is small, the primary benefit of weight quantization is the reduction of memory access. However, int8wo does not actually decrease the memory access overhead compared to the original fp16 pipeline, resulting in no speedup. The same issue occurs with int4wo.
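
To make this concrete, here is a minimal hand-written sketch (not torchao's actual kernel; the function name and shapes are only for illustration) of what the plain-layout int8 weight-only path effectively does:

```python
import torch

def int8wo_linear_reference(x: torch.Tensor, w_int8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Illustrative dequantize-then-matmul, mimicking the plain-layout int8wo path."""
    # x: [batch, in_features] activations in fp16/bf16
    # w_int8: [out_features, in_features] quantized weights
    # scales: [out_features] per-output-channel scales
    w = w_int8.to(x.dtype) * scales.unsqueeze(1)  # int8 -> fp16/bf16 before the GEMM
    return x @ w.t()                              # the GEMM itself still runs in fp16/bf16

# Unless the dequantization is fused into the GEMM kernel so that only int8 bytes
# cross the memory bus, the per-token memory traffic during decode stays roughly
# the same as loading the original fp16 weights, hence no speedup at batch size 1.
```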

To achieve speedup, you might consider using AWQ or GPTQ, which can effectively reduce weight memory access. Additionally, these methods are supported by SGLang.
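
For example, if an AWQ checkpoint of the model is available (e.g. Qwen/Qwen2-7B-Instruct-AWQ) and this SGLang version exposes the --quantization flag, the benchmark command would be something like:

python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct-AWQ --quantization awq --batch-size 1 --input-len 200 --output-len 100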

@tricky61
Author

> I believe this issue is related to the implementation of TorchAO. For example, with int8wo, it is necessary to convert the int8 weights to the same data type as the activations (refer to https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py#L214), usually fp16, before performing the GEMM operation. When the batch size is small, the primary benefit of weight quantization is the reduction of memory access. However, int8wo does not actually decrease the memory access overhead compared to the original fp16 pipeline, resulting in no speedup. The same issue occurs with int4wo.
>
> To achieve speedup, you might consider using AWQ or GPTQ, which can effectively reduce weight memory access. Additionally, these methods are supported by SGLang.

I see a speedup with int8wo or int4wo in #1341, though.

@HandH1998
Collaborator

There is no speedup with torch.compile for int8wo and int4wo, referring to the second part of the results from #1341.

@merrymercy
Contributor

merrymercy commented Dec 1, 2024

It seems torchao has not been applied to qwen. Can you copy this line from llama.py

apply_torchao_config_(self, params_dict, set(["proj.weight"]))
to qwen.py and qwen2.py?

Please send a pull request after you find it works!
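
For reference, a minimal sketch of what that change could look like in qwen2.py; the import path and the placement at the end of load_weights mirror llama.py from around this time and may differ in your checkout:

```python
# Sketch only: mirror the torchao hook that llama.py uses in qwen2.py.
import torch.nn as nn
from sglang.srt.layers.torchao_utils import apply_torchao_config_  # assumed location; copy whatever llama.py imports

class Qwen2ForCausalLM(nn.Module):
    # ... existing __init__ / forward stay unchanged ...

    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        # ... existing weight-loading loop stays unchanged ...
        # After all weights are loaded, quantize every "*proj.weight" tensor
        # according to --torchao-config (int8wo, int4wo-128, fp8wo, ...).
        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
```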

@merrymercy merrymercy changed the title Qwen2 no speed up with quantization? [Feature] support torchao for qwen2 models Dec 1, 2024
@merrymercy merrymercy added the good first issue and help wanted labels Dec 1, 2024
@tricky61
Author

tricky61 commented Dec 2, 2024

> It seems torchao has not been applied to qwen. Can you copy this line from llama.py
>
> apply_torchao_config_(self, params_dict, set(["proj.weight"]))
>
> to qwen.py and qwen2.py?
> Please send a pull request after you find it works!

Yes, it seems the torchao-config works now, but int8wo's throughput decreases while int4wo-128's increases.
python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --torchao-config int8wo
Benchmark ...
Prefill. latency: 0.09426 s, throughput: 2121.75 token/s
Decode. latency: 0.07580 s, throughput: 13.19 token/s
Decode. latency: 0.07538 s, throughput: 13.27 token/s
Decode. latency: 0.07553 s, throughput: 13.24 token/s
Decode. latency: 0.07549 s, throughput: 13.25 token/s
Decode. latency: 0.07530 s, throughput: 13.28 token/s
Decode. median latency: 0.07523 s, median throughput: 13.29 token/s
Total. latency: 7.551 s, throughput: 39.73 token/s

python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --torchao-config int4wo-128
Benchmark ...
Prefill. latency: 0.31131 s, throughput: 642.44 token/s
Decode. latency: 0.00983 s, throughput: 101.68 token/s
Decode. latency: 0.00957 s, throughput: 104.50 token/s
Decode. latency: 0.00963 s, throughput: 103.81 token/s
Decode. latency: 0.00962 s, throughput: 103.98 token/s
Decode. latency: 0.00980 s, throughput: 102.05 token/s
Decode. median latency: 0.00979 s, median throughput: 102.19 token/s
Total. latency: 1.282 s, throughput: 234.03 token/s

By the way, when adding --enable-torch-compile, it seems to error:
[rank0]: Exception: Capture cuda graph failed: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(1, 3584), dtype=torch.bfloat16), AffineQuantizedTensor(layout_tensor=FakeTensor(..., device='cuda:0', size=(576, 32, 32, 4), dtype=torch.int32), block_size=(1, 128), shape=torch.Size([4608, 3584]), device=cuda:0, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(4608,), dtype=torch.bfloat16,
[rank0]: requires_grad=True))), **{}):
[rank0]: 'FakeTensor' object has no attribute 'layout_type'

[rank0]: from user code:
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 38, in inner
[rank0]: return fn(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 292, in forward
[rank0]: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 257, in forward
[rank0]: hidden_states, residual = layer(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 209, in forward
[rank0]: hidden_states = self.self_attn(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 155, in forward
[rank0]: qkv, _ = self.qkv_proj(hidden_states)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/layers/linear.py", line 380, in forward
[rank0]: output_parallel = self.quant_method.apply(self, input_, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 135, in apply
[rank0]: return F.linear(x, layer.weight, bias)

[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]: import torch._dynamo
[rank0]: torch._dynamo.config.suppress_errors = True

@tricky61
Author

tricky61 commented Dec 2, 2024

> There is no speedup with torch.compile for int8wo and int4wo, referring to the second part of the results from #1341.

OK, and when I use --enable-torch-compile with int8wo or int4wo-128, it errors:
[rank0]: Exception: Capture cuda graph failed: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(1, 3584), dtype=torch.bfloat16), AffineQuantizedTensor(layout_tensor=FakeTensor(..., device='cuda:0', size=(576, 32, 32, 4), dtype=torch.int32), block_size=(1, 128), shape=torch.Size([4608, 3584]), device=cuda:0, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(4608,), dtype=torch.bfloat16,
[rank0]: requires_grad=True))), **{}):
[rank0]: 'FakeTensor' object has no attribute 'layout_type'

[rank0]: from user code:
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 38, in inner
[rank0]: return fn(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 292, in forward
[rank0]: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 257, in forward
[rank0]: hidden_states, residual = layer(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 209, in forward
[rank0]: hidden_states = self.self_attn(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 155, in forward
[rank0]: qkv, _ = self.qkv_proj(hidden_states)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/layers/linear.py", line 380, in forward
[rank0]: output_parallel = self.quant_method.apply(self, input_, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 135, in apply
[rank0]: return F.linear(x, layer.weight, bias)

[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]: import torch._dynamo
[rank0]: torch._dynamo.config.suppress_errors = True

@tricky61
Author

tricky61 commented Dec 2, 2024

> It seems torchao has not been applied to qwen. Can you copy this line from llama.py
>
> apply_torchao_config_(self, params_dict, set(["proj.weight"]))
>
> to qwen.py and qwen2.py?
> Please send a pull request after you find it works!

@HandH1998
Another question: the A30 does not support fp8. When I run TensorRT-LLM's fp8 benchmark, it cannot run at all. SGLang can run it, but the throughput drops significantly. I had expected it to fail to run, as with TensorRT-LLM.
python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --torchao-config fp8wo
Benchmark ...
Prefill. latency: 0.15525 s, throughput: 1288.20 token/s
Decode. latency: 0.13796 s, throughput: 7.25 token/s
Decode. latency: 0.13617 s, throughput: 7.34 token/s
Decode. latency: 0.13601 s, throughput: 7.35 token/s
Decode. latency: 0.13598 s, throughput: 7.35 token/s
Decode. latency: 0.13597 s, throughput: 7.35 token/s
Decode. median latency: 0.13691 s, median throughput: 7.30 token/s
Total. latency: 13.714 s, throughput: 21.87 token/s

@tricky61
Author

tricky61 commented Dec 2, 2024

@HandH1998 @merrymercy
By the way, int8dq fails with CUDA graph capture.

[rank0]: File "/home/service/var/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 273, in capture_one_batch_size
[rank0]: out = run_once()
[rank0]: File "/home/service/var/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 258, in run_once
[rank0]: return forward(input_ids, forward_batch.positions, forward_batch)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 292, in forward
[rank0]: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 257, in forward
[rank0]: hidden_states, residual = layer(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 209, in forward
[rank0]: hidden_states = self.self_attn(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/models/qwen2.py", line 155, in forward
[rank0]: qkv, _ = self.qkv_proj(hidden_states)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/service/var/sglang/python/sglang/srt/layers/linear.py", line 380, in forward
[rank0]: output_parallel = self.quant_method.apply(self, input_, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 135, in apply
[rank0]: return F.linear(x, layer.weight, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 374, in dispatch__torch_function

[rank0]: return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 357, in wrapper
[rank0]: return func(f, types, args, kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 102, in _
[rank0]: return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/quantization/linear_activation_quantized_tensor.py", line 73, in quantized_linear_op
[rank0]: return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 374, in dispatch__torch_function

[rank0]: return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/utils.py", line 357, in wrapper
[rank0]: return func(f, types, args, kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 1754, in _
[rank0]: return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 231, in _quantized_linear_op
[rank0]: return impl(input_tensor, weight_tensor, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 1372, in _linear_int8_act_int8_weight_impl
[rank0]: y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/kernel/intmm.py", line 137, in int_scaled_matmul
[rank0]: c = safe_int_mm(a, b)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torchao/kernel/intmm.py", line 39, in safe_int_mm
[rank0]: if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 463, in __repr__
[rank0]: return torch._tensor_str._str(self, tensor_contents=tensor_contents)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor_str.py", line 698, in _str
[rank0]: return _str_intern(self, tensor_contents=tensor_contents)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor_str.py", line 618, in _str_intern
[rank0]: tensor_str = _tensor_str(self, indent)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor_str.py", line 350, in _tensor_str
[rank0]: formatter = _Formatter(get_summarized_data(self) if summarize else self)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor_str.py", line 134, in init
[rank0]: value_str = f"{value}"
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 986, in format
[rank0]: return self.item().format(format_spec)
[rank0]: RuntimeError: CUDA error: operation not permitted when stream is capturing
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

@merrymercy
Contributor

Please use torch 2.5 by running pip install vllm==0.6.4.post1.

@tricky61
Author

tricky61 commented Dec 5, 2024

@merrymercy @HandH1998
Is it normal that with batch-size >= 8 it is slower than fp16, or that it runs out of CUDA memory?

@JamesSand
Contributor

Hi @tricky61 @merrymercy, do you have any ideas about the above error?
I have also encountered the problem that int8dq is incompatible with CUDA graph capture. The error is exactly the same as mentioned above.

I have tried to upgrade torch to 2.5 by pip install vllm==0.6.4.post1.

My pip show torch result is:

Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team

My pip show vllm result is:

Name: vllm
Version: 0.6.4.post1
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
Home-page: https://github.com/vllm-project/vllm
Author: vLLM Team

My script to reproduce the error is

python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int8dq \
    --port 30000 --host 0.0.0.0

When I disable CUDA graph capture, the server can run, but it seems extremely slow.
Here is my script for disabling CUDA graph capture:

python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int8dq \
    --disable-cuda-graph \
    --port 30000 --host 0.0.0.0
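
In case it helps quantify the slowdown, a minimal timing request against the server launched above might look like this (it uses SGLang's native /generate HTTP endpoint; adjust the prompt and max_new_tokens as needed):

```python
# Minimal timing sketch for a single request to the SGLang server above.
import time

import requests

payload = {
    "text": "Explain CUDA graphs in one paragraph.",
    "sampling_params": {"temperature": 0, "max_new_tokens": 128},
}

start = time.time()
resp = requests.post("http://localhost:30000/generate", json=payload)
elapsed = time.time() - start

print(resp.json()["text"])
print(f"request latency: {elapsed:.2f} s")
```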
