Skip to content

Commit

Permalink
feat(safetensors): Use model_info to determine whether to download pt…
Browse files Browse the repository at this point in the history
…h or safetensors

The logic here will prefer pth over safetensors unless the model's config
explicitly states a preference for safetensors over pth. If only one of the
two is found, the download will use whichever is present.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
  • Loading branch information
gabe-l-hart committed Oct 3, 2024
1 parent 4628416 commit 2fc163c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
30 changes: 28 additions & 2 deletions torchchat/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,44 @@
def _download_hf_snapshot(
model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
):
from huggingface_hub import snapshot_download
from huggingface_hub import model_info, snapshot_download
from requests.exceptions import HTTPError

# Download and store the HF model artifacts.
print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
try:
# Fetch the info about the model's repo
model_info = model_info(model_config.distribution_path, token=hf_token)
model_fnames = [f.rfilename for f in model_info.siblings]

# Check the model config for preference between safetensors and pth
has_pth = any(f.endswith(".pth") for f in model_fnames)
has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)

# If told to prefer safetensors, ignore pth files
if model_config.prefer_safetensors:
if not has_safetensors:
print(
f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
file=sys.stderr,
)
exit(1)
ignore_patterns = "*.pth"

# If the model has both, prefer pth files over safetensors
elif has_pth and has_safetensors:
ignore_patterns = "*safetensors*"

# Otherwise, download everything
else:
ignore_patterns = None

snapshot_download(
model_config.distribution_path,
local_dir=artifact_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns="*safetensors*",
ignore_patterns=ignore_patterns,
)
except HTTPError as e:
if e.response.status_code == 401: # Missing HuggingFace CLI login.
Expand Down
1 change: 1 addition & 0 deletions torchchat/model_config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ModelConfig:
checkpoint_file: str = field(default="model.pth")
tokenizer_file: str = field(default="tokenizer.model")
transformer_params_key: str = field(default=None)
prefer_safetensors: bool = field(default=False)


# Keys are stored in lowercase.
Expand Down

0 comments on commit 2fc163c

Please sign in to comment.