Skip to content

Commit

Permalink
make refiners.conversion.utils.Hub.expected_sha256 optional
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Nov 12, 2024
1 parent 1f70f43 commit 432a6ff
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
54 changes: 32 additions & 22 deletions src/refiners/conversion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ def download_file_url(url: str, destination: Path) -> None:
logging.debug(f"Downloading {url} to {destination}")

# get the size of the file
response = requests.get(url, stream=True)
response = requests.get(url=url, stream=True)
response.raise_for_status()
total = int(response.headers.get("content-length", 0))
chunk_size = 1024 * 1000 # 1 MiB

# create a progress bar
bar = tqdm(
Expand All @@ -45,7 +46,7 @@ def download_file_url(url: str, destination: Path) -> None:
with destination.open("wb") as f:
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=1024 * 1000):
for chunk in r.iter_content(chunk_size=chunk_size):
size = f.write(chunk)
bar.update(size)
bar.close()
Expand All @@ -63,8 +64,8 @@ def __init__(
self,
repo_id: str,
filename: str,
expected_sha256: str,
revision: str = "main",
expected_sha256: str | None = None,
download_url: str | None = None,
) -> None:
"""Initialize the HubPath.
Expand All @@ -73,14 +74,14 @@ def __init__(
repo_id: The repository identifier on the hub.
filename: The filename of the file in the repository.
revision: The revision of the file on the hf hub.
expected_sha256: The sha256 hash of the file.
expected_sha256: The sha256 hash of the file, to optionally check against the local or remote hash.
download_url: The url to download the file from, if not from the huggingface hub.
"""
self.repo_id = repo_id
self.filename = filename
self.revision = revision
self.expected_sha256 = expected_sha256.lower()
self.override_download_url = download_url
self.expected_sha256 = expected_sha256.lower() if expected_sha256 is not None else None
self.download_url = download_url

@staticmethod
def hub_location():
Expand All @@ -90,16 +91,22 @@ def hub_location():
@property
def hf_url(self) -> str:
"""Return the url to the file on the hf hub."""
assert self.override_download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
return hf_hub_url(
repo_id=self.repo_id,
filename=self.filename,
revision=self.revision,
)

@property
def hf_metadata(self) -> HfFileMetadata:
"""Return the metadata of the file on the hf hub."""
return get_hf_file_metadata(self.hf_url)

@property
def hf_cache_path(self) -> Path:
"""Download the file from the hf hub and return its path in the local hf cache."""
assert self.download_url is None, f"{self.repo_id}/{self.filename} is not available on the hub"
return Path(
hf_hub_download(
repo_id=self.repo_id,
Expand All @@ -108,11 +115,6 @@ def hf_cache_path(self) -> Path:
),
)

@property
def hf_metadata(self) -> HfFileMetadata:
"""Return the metadata of the file on the hf hub."""
return get_hf_file_metadata(self.hf_url)

@property
def hf_sha256_hash(self) -> str:
"""Return the sha256 hash of the file on the hf hub."""
Expand All @@ -127,24 +129,32 @@ def local_path(self) -> Path:
return self.hub_location() / self.repo_id / self.filename

@property
def local_hash(self) -> str:
def local_sh256_hash(self) -> str:
"""Return the sha256 hash of the file in the local hub."""
assert self.local_path.is_file(), f"{self.local_path} does not exist"
# TODO: use https://docs.python.org/3/library/hashlib.html#hashlib.file_digest when support python >= 3.11
return sha256(self.local_path.read_bytes()).hexdigest().lower()

def check_local_hash(self) -> bool:
"""Check if the sha256 hash of the file in the local hub is correct."""
if self.expected_sha256 != self.local_hash:
logging.warning(f"{self.local_path} local sha256 mismatch, {self.local_hash} != {self.expected_sha256}")
if self.expected_sha256 is None:
logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check")
return True
elif self.expected_sha256 != self.local_sh256_hash:
logging.warning(
f"{self.local_path} local sha256 mismatch, {self.local_sh256_hash} != {self.expected_sha256}"
)
return False
else:
logging.debug(f"{self.local_path} local sha256 is correct ({self.local_hash})")
logging.debug(f"{self.local_path} local sha256 is correct ({self.local_sh256_hash})")
return True

def check_remote_hash(self) -> bool:
"""Check if the sha256 hash of the file on the hf hub is correct."""
if self.expected_sha256 != self.hf_sha256_hash:
if self.expected_sha256 is None:
logging.warning(f"{self.repo_id}/{self.filename} has no expected sha256 hash, skipping check")
return True
elif self.expected_sha256 != self.hf_sha256_hash:
logging.warning(
f"{self.local_path} remote sha256 mismatch, {self.hf_sha256_hash} != {self.expected_sha256}"
)
Expand All @@ -154,14 +164,14 @@ def check_remote_hash(self) -> bool:
return True

def download(self) -> None:
"""Download the file from the hf hub or from the override download url."""
self.local_path.parent.mkdir(parents=True, exist_ok=True)
"""Download the file from the hf hub or from the override download url, and save it to the local hub."""
if self.local_path.is_file():
logging.warning(f"{self.local_path} already exists")
elif self.override_download_url is not None:
download_file_url(url=self.override_download_url, destination=self.local_path)
elif self.download_url is not None:
self.local_path.parent.mkdir(parents=True, exist_ok=True)
download_file_url(url=self.download_url, destination=self.local_path)
else:
# TODO: pas assez de message de log quand local_path existe pas et que ça vient du hf cache
self.local_path.parent.mkdir(parents=True, exist_ok=True)
self.local_path.symlink_to(self.hf_cache_path)
assert self.check_local_hash()

Expand Down
2 changes: 1 addition & 1 deletion tests/weight_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_path(hub: Hub, use_local_weights: bool) -> Path:
if use_local_weights:
path = hub.local_path
else:
if hub.override_download_url is not None:
if hub.download_url is not None:
pytest.skip(f"{hub.filename} is not available on Hugging Face Hub")

try:
Expand Down

0 comments on commit 432a6ff

Please sign in to comment.