From 0f89712ecb07d0836a5c94ff3a863b196c540565 Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Mon, 8 Apr 2024 12:46:49 +0000 Subject: [PATCH] Add Support for Decompressing Models from HF Hub (#2212) commit --- src/sparsetensors/utils/safetensors_load.py | 90 +++++++++++---------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/src/sparsetensors/utils/safetensors_load.py b/src/sparsetensors/utils/safetensors_load.py index c82d3e43..4d71482a 100644 --- a/src/sparsetensors/utils/safetensors_load.py +++ b/src/sparsetensors/utils/safetensors_load.py @@ -22,8 +22,8 @@ __all__ = [ - "get_safetensors_header", "get_safetensors_folder", + "get_safetensors_header", "match_param_name", "merge_names", "get_weight_mappings", @@ -31,6 +31,48 @@ ] +def get_safetensors_folder( + pretrained_model_name_or_path: str, cache_dir: Optional[str] = None +) -> str: + """ + Given a Hugging Face stub or a local path, return the folder containing the + safetensors weight files + + :param pretrained_model_name_or_path: local path to model or HF stub + :param cache_dir: optional cache dir to search through, if none is specified the + model will be searched for in the default TRANSFORMERS_CACHE + :return: local folder containing model data + """ + if os.path.exists(pretrained_model_name_or_path): + # argument is a path to a local folder + return pretrained_model_name_or_path + + safetensors_path = cached_file( + pretrained_model_name_or_path, + SAFE_WEIGHTS_NAME, + cache_dir=cache_dir, + _raise_exceptions_for_missing_entries=False, + ) + index_path = cached_file( + pretrained_model_name_or_path, + SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + _raise_exceptions_for_missing_entries=False, + ) + if safetensors_path is not None: + # found a single cached safetensors file + return os.path.split(safetensors_path)[0] + if index_path is not None: + # found a cached safetensors weight index file + return os.path.split(index_path)[0] + + # model weights could not be found locally or cached from HF Hub + raise ValueError( + "Could not locate safetensors weight or index file from " + f"{pretrained_model_name_or_path}." + ) + + def get_safetensors_header(safetensors_path: str) -> Dict[str, str]: """ Extracts the metadata from a safetensors file as JSON @@ -106,6 +148,10 @@ def get_weight_mappings(model_path: str) -> Dict[str, str]: with open(index_path, "r", encoding="utf-8") as f: index = json.load(f) header = index["weight_map"] + else: + raise ValueError( + f"Could not find a safetensors weight or index file at {model_path}" + ) # convert weight locations to full paths for key, value in header.items(): @@ -148,45 +194,3 @@ def get_nested_weight_mappings( nested_weight_mappings[dense_param][param_name] = weight_mappings[key] return nested_weight_mappings - - -def get_safetensors_folder( - pretrained_model_name_or_path: str, cache_dir: Optional[str] = None -) -> str: - """ - Given a Hugging Face stub or a local path, return the folder containing the - safetensors weight files - - :param pretrained_model_name_or_path: local path to model or HF stub - :param cache_dir: optional cache dir to search through, if none is specified the - model will be searched for in the default TRANSFORMERS_CACHE - :return: local folder containing model data - """ - if os.path.exists(pretrained_model_name_or_path): - # argument is a path to a local folder - return pretrained_model_name_or_path - - safetensors_path = cached_file( - pretrained_model_name_or_path, - SAFE_WEIGHTS_NAME, - cache_dir=cache_dir, - _raise_exceptions_for_missing_entries=False, - ) - index_path = cached_file( - pretrained_model_name_or_path, - SAFE_WEIGHTS_INDEX_NAME, - cache_dir=cache_dir, - _raise_exceptions_for_missing_entries=False, - ) - if safetensors_path is not None: - # found a single cached safetensors file - return os.path.split(safetensors_path)[0] - if index_path is not None: - # found a cached safetensors weight index file - return os.path.split(index_path)[0] - - # model weights could not be found locally or cached from HF Hub - raise ValueError( - "Could not locate safetensors weight or index file from " - f"{pretrained_model_name_or_path}." - )