From ba36b5520ab6759045abfd89d1d108f861053fb1 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 16 Dec 2024 15:04:16 -0800 Subject: [PATCH] Revert "Small fixes for torchao quant" (#2493) --- python/sglang/srt/layers/torchao_utils.py | 3 +-- python/sglang/srt/model_executor/model_runner.py | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 1fdda4fad4..910309da97 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -26,12 +26,11 @@ def apply_torchao_config_to_model( quantize_, ) from torchao.quantization.observer import PerRow, PerTensor - from torchao.quantization.quant_api import _is_linear if filter_fn is None: def filter_fn(module, fqn): - return _is_linear(module) and "proj" in fqn + return "proj" in fqn if torchao_config == "" or torchao_config is None: return model diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a3f62f250e..db024c5c7f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -157,10 +157,6 @@ def __init__( self.sampler = Sampler() self.load_model() - apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] - ) - # Apply torch TP if the model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) if self.tp_size > 1 and supports_torch_tp: @@ -169,6 +165,10 @@ def __init__( else: self.torch_tp_applied = False + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] + ) + # Init memory pool and attention backends if server_args.lora_paths is not None: self.init_lora_manager()