Skip to content

Commit

Permalink
feat(granite): Add support for loading state_dict from safetensors
Browse files Browse the repository at this point in the history
Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
  • Loading branch information
gabe-l-hart committed Oct 3, 2024
1 parent e855b72 commit 4628416
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion torchchat/cli/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,33 @@ def permute(w, n_heads):

merged_result = {}
for file in sorted(bin_files):
state_dict = torch.load(

# The state_dict can be loaded from either a torch zip file or
# safetensors. We take our best guess from the name and try all
# possibilities
load_pt_mmap = lambda: torch.load(
str(file), map_location="cpu", mmap=True, weights_only=True
)
load_pt_no_mmap = lambda: torch.load(
str(file), map_location="cpu", mmap=False, weights_only=True
)
def load_safetensors():
import safetensors.torch
with open(file, "rb") as handle:
return safetensors.torch.load(handle.read())
if "safetensors" in str(file):
loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
else:
loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]

state_dict = None
for loader in loaders:
try:
state_dict = loader()
break
except Exception:
continue
assert state_dict is not None, f"Unable to load tensors from {file}"
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
Expand Down

0 comments on commit 4628416

Please sign in to comment.