From 2fc163ca18af5e84438d13609ce75358e751842f Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Wed, 2 Oct 2024 09:17:32 -0600
Subject: [PATCH] feat(safetensors): Use model_info to determine whether to
 download pth or safetensors

The logic here will prefer pth over safetensors unless the model's config
explicitly states a preference for safetensors over pth. If only one of the
two is found, the download will use whichever is present.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart
---
 torchchat/cli/download.py              | 30 ++++++++++++++++++++++++--
 torchchat/model_config/model_config.py |  1 +
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py
index 6ac3e8d9d..14dfeb062 100644
--- a/torchchat/cli/download.py
+++ b/torchchat/cli/download.py
@@ -22,18 +22,44 @@ def _download_hf_snapshot(
     model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
 ):
-    from huggingface_hub import snapshot_download
+    from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
+        # Fetch the info about the model's repo
+        repo_info = model_info(model_config.distribution_path, token=hf_token)
+        model_fnames = [f.rfilename for f in repo_info.siblings]
+
+        # Determine which checkpoint formats are present in the repo
+        has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
+
+        # If told to prefer safetensors, ignore pth files
+        if model_config.prefer_safetensors:
+            if not has_safetensors:
+                print(
+                    f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Aborting download.",
+                    file=sys.stderr,
+                )
+                exit(1)
+            ignore_patterns = "*.pth"
+
+        # If the model has both, prefer pth files over safetensors
+        elif has_pth and has_safetensors:
+            ignore_patterns = "*safetensors*"
+
+        # Otherwise, download everything
+        else:
+            ignore_patterns = None
+
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*",
+            ignore_patterns=ignore_patterns,
         )
     except HTTPError as e:
         if e.response.status_code == 401:  # Missing HuggingFace CLI login.
diff --git a/torchchat/model_config/model_config.py b/torchchat/model_config/model_config.py
index 584a87a74..540804ada 100644
--- a/torchchat/model_config/model_config.py
+++ b/torchchat/model_config/model_config.py
@@ -46,6 +46,7 @@ class ModelConfig:
     checkpoint_file: str = field(default="model.pth")
     tokenizer_file: str = field(default="tokenizer.model")
     transformer_params_key: str = field(default=None)
+    prefer_safetensors: bool = field(default=False)
 
     # Keys are stored in lowercase.