-
Notifications
You must be signed in to change notification settings - Fork 583
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update model_loader deps and qqq quantization deps (#2220)
- Loading branch information
Showing
58 changed files
with
2,363 additions
and
366 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import logging | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class DeviceConfig:
    """Resolves and holds the torch device the runtime should run on."""

    # Resolved torch.device; annotated Optional to mirror the declared type.
    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        """Validate *device* and resolve it into a ``torch.device``.

        Args:
            device: Device type name; one of ``"cuda"``, ``"xpu"``, ``"hpu"``.

        Raises:
            RuntimeError: If *device* is not a supported device type.
        """
        # Guard clause: reject unknown device types up front.
        if device not in ("cuda", "xpu", "hpu"):
            raise RuntimeError(f"Not supported device type: {device}")
        self.device_type = device
        self.device = torch.device(self.device_type)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum | ||
import json | ||
import logging | ||
from dataclasses import dataclass, field | ||
from typing import List, Optional, Union | ||
|
||
from sglang.srt.utils import is_hip | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LoadFormat(str, enum.Enum):
    """Supported on-disk formats for model weights.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values, e.g. ``LoadFormat.AUTO == "auto"``.
    """

    AUTO = "auto"
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"
    DUMMY = "dummy"
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"
    MISTRAL = "mistral"
|
||
|
||
@dataclass
class LoadConfig:
    """Configuration for locating and loading model weights.

    Attributes:
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format
            and fall back to the pytorch bin format if safetensors format
            is not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which
            is mainly for profiling.
            "bitsandbytes" will load nf4 type weights.
        download_dir: Directory to download and load the weights, default
            to the default cache directory of huggingface.
        model_loader_extra_config: Extra loader options; accepts a dict, a
            JSON string (parsed in ``__post_init__``), or None (treated
            as an empty dict).
        ignore_patterns: The list of patterns to ignore when loading the
            model. Default to "original/**/*" to avoid repeated loading
            of llama's checkpoints.
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        """Normalize fields after dataclass init.

        Raises:
            ValueError: If ``load_format`` is not a known ``LoadFormat``
                (or is unsupported on ROCm).
            json.JSONDecodeError: If ``model_loader_extra_config`` is a
                string but not valid JSON.
        """
        # Accept a JSON string, a dict, or None for the extra config.
        extra_config = self.model_loader_extra_config or {}
        if isinstance(extra_config, str):
            extra_config = json.loads(extra_config)
        # Fix: always write the normalized value back. The original only
        # assigned in the str branch, so an explicit None survived as None
        # despite the field's default_factory=dict contract.
        self.model_loader_extra_config = extra_config
        self._verify_load_format()

        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            # Avoid re-downloading llama's duplicate "original" checkpoints.
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Coerce a string ``load_format`` into a ``LoadFormat`` member.

        No-op when ``load_format`` is already a ``LoadFormat``.
        """
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        # Raises ValueError for unknown format strings.
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        # Cheap membership test first: is_hip() is only consulted when the
        # format is actually ROCm-restricted (the list is currently empty).
        if load_format in rocm_not_supported_load_format and is_hip():
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,6 +42,7 @@ | |
"Fp8LinearMethod", | ||
"MarlinLinearMethod", | ||
"GPTQLinearMethod", | ||
"QQQLinearMethod", | ||
] | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.