Safetensors (#1255)
* feat(granite): Add support for finding weight mapping files with other names

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(granite): Add support for loading state_dict from safetensors

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(safetensors): Use model_info to determine whether to download pth or safetensors

The logic here prefers pth over safetensors unless the model's config
explicitly prefers safetensors. If only one of the two formats is
present, the download uses whichever is found.
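
For illustration, this selection rule reduces to a small pure function (a sketch only; the name and signature are made up, not torchchat API):

    # Sketch of the preference rule described above.
    def choose_ignore_patterns(has_pth, has_safetensors, prefer_safetensors):
        if prefer_safetensors:       # config explicitly prefers safetensors
            if not has_safetensors:
                raise RuntimeError("prefer_safetensors set, but no safetensors files")
            return "*.pth"           # skip pth files
        if has_pth and has_safetensors:
            return "*safetensors*"   # default: prefer pth
        return None                  # one format only: download whatever is present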

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
gabe-l-hart authored Oct 4, 2024
1 parent d8c0aaf commit 766bee9
Showing 3 changed files with 61 additions and 4 deletions.
34 changes: 32 additions & 2 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import glob
 import json
 import os
 import re
@@ -41,7 +42,12 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json = model_dir / "pytorch_model.bin.index.json"
+    model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
+    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
+    if len(model_map_json_matches):
+        model_map_json = model_map_json_matches[0]
+    else:
+        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -96,9 +102,33 @@ def permute(w, n_heads):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(
+
+        # The state_dict can be loaded from either a torch zip file or
+        # safetensors. We take our best guess from the name and try all
+        # possibilities
+        load_pt_mmap = lambda: torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
+        load_pt_no_mmap = lambda: torch.load(
+            str(file), map_location="cpu", mmap=False, weights_only=True
+        )
+        def load_safetensors():
+            import safetensors.torch
+            with open(file, "rb") as handle:
+                return safetensors.torch.load(handle.read())
+        if "safetensors" in str(file):
+            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
+        else:
+            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
+
+        state_dict = None
+        for loader in loaders:
+            try:
+                state_dict = loader()
+                break
+            except Exception:
+                continue
+        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
     final_result = {}
     for key, value in merged_result.items():
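
As context for the new `*.index.json` glob, sharded Hugging Face checkpoints ship an index file mapping each weight to its shard. A quick illustration of what the glob matches (the directory path and its contents are hypothetical):

    import glob
    from pathlib import Path

    # Hypothetical local checkpoint directory
    model_dir = Path("checkpoints/example-model")

    # A safetensors checkpoint typically ships "model.safetensors.index.json",
    # a torch one "pytorch_model.bin.index.json"; either matches this glob.
    matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
    print(matches)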
30 changes: 28 additions & 2 deletions torchchat/cli/download.py
@@ -22,18 +22,44 @@
 def _download_hf_snapshot(
     model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
 ):
-    from huggingface_hub import snapshot_download
+    from huggingface_hub import model_info, snapshot_download
     from requests.exceptions import HTTPError
 
     # Download and store the HF model artifacts.
     print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
+        # Fetch the info about the model's repo
+        model_info = model_info(model_config.distribution_path, token=hf_token)
+        model_fnames = [f.rfilename for f in model_info.siblings]
+
+        # Check the model config for preference between safetensors and pth
+        has_pth = any(f.endswith(".pth") for f in model_fnames)
+        has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
+
+        # If told to prefer safetensors, ignore pth files
+        if model_config.prefer_safetensors:
+            if not has_safetensors:
+                print(
+                    f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
+                    file=sys.stderr,
+                )
+                exit(1)
+            ignore_patterns = "*.pth"
+
+        # If the model has both, prefer pth files over safetensors
+        elif has_pth and has_safetensors:
+            ignore_patterns = "*safetensors*"
+
+        # Otherwise, download everything
+        else:
+            ignore_patterns = None
+
         snapshot_download(
             model_config.distribution_path,
             local_dir=artifact_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*",
+            ignore_patterns=ignore_patterns,
         )
     except HTTPError as e:
         if e.response.status_code == 401:  # Missing HuggingFace CLI login.
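
For a sense of what the new `model_info` call returns, a minimal sketch (the repo id is illustrative; anonymous access works for public repos):

    from huggingface_hub import model_info

    # Enumerate a repo's files without downloading anything.
    info = model_info("ibm-granite/granite-3b-code-base", token=None)
    fnames = [s.rfilename for s in info.siblings]

    has_pth = any(f.endswith(".pth") for f in fnames)
    has_safetensors = any(f.endswith(".safetensors") for f in fnames)
    print(has_pth, has_safetensors)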
1 change: 1 addition & 0 deletions torchchat/model_config/model_config.py
@@ -46,6 +46,7 @@ class ModelConfig:
     checkpoint_file: str = field(default="model.pth")
     tokenizer_file: str = field(default="tokenizer.model")
     transformer_params_key: str = field(default=None)
+    prefer_safetensors: bool = field(default=False)
 
 
 # Keys are stored in lowercase.
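
A hypothetical sketch of opting a model into safetensors via the new field (the import path follows this file's location; the alias and repo id are made up, and `name`/`distribution_path` are assumed to be fields, as their use in download.py suggests):

    from torchchat.model_config.model_config import ModelConfig

    config = ModelConfig(
        name="granite-code",                                   # illustrative alias
        distribution_path="ibm-granite/granite-3b-code-base",  # illustrative repo
        prefer_safetensors=True,                               # new field: skip *.pth downloads
    )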
