Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle refs/convert/parquet and PR revision correctly in hffs #1712

Merged
merged 5 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/huggingface_hub/hf_file_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import os
import re
import tempfile
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -23,6 +24,17 @@
)


# Regex used to match special revisions with "/" in them (see #1710)
SPECIAL_REFS_REVISION_REGEX = re.compile(
r"""
(^refs\/convert\/parquet) # `refs/convert/parquet` revisions
|
(^refs\/pr\/\d+) # PR revisions
""",
re.VERBOSE,
)


@dataclass
class HfFileSystemResolvedPath:
"""Data structure containing information about a resolved Hugging Face file system path."""
Expand All @@ -34,7 +46,7 @@ class HfFileSystemResolvedPath:

def unresolve(self) -> str:
return (
f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_quote(self.revision)}/{self.path_in_repo}"
f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_revision(self.revision)}/{self.path_in_repo}"
.rstrip("/")
)

Expand Down Expand Up @@ -156,7 +168,13 @@ def _align_revision_in_path_with_revision(
if "@" in path:
repo_id, revision_in_path = path.split("@", 1)
if "/" in revision_in_path:
revision_in_path, path_in_repo = revision_in_path.split("/", 1)
match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
if match is not None and revision in (None, match.group()):
# Handle `refs/convert/parquet` and PR revisions separately
path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
revision_in_path = match.group()
else:
revision_in_path, path_in_repo = revision_in_path.split("/", 1)
else:
path_in_repo = ""
revision_in_path = unquote(revision_in_path)
Expand Down Expand Up @@ -262,7 +280,7 @@ def ls(
) -> List[Union[str, Dict[str, Any]]]:
"""List the contents of a directory."""
resolved_path = self.resolve_path(path, revision=revision)
revision_in_path = "@" + safe_quote(resolved_path.revision)
revision_in_path = "@" + safe_revision(resolved_path.revision)
has_revision_in_path = revision_in_path in path
path = resolved_path.unresolve()
if path not in self.dircache or refresh:
Expand Down Expand Up @@ -367,7 +385,7 @@ def modified(self, path: str, **kwargs) -> datetime:
def info(self, path: str, **kwargs) -> Dict[str, Any]:
resolved_path = self.resolve_path(path)
if not resolved_path.path_in_repo:
revision_in_path = "@" + safe_quote(resolved_path.revision)
revision_in_path = "@" + safe_revision(resolved_path.revision)
has_revision_in_path = revision_in_path in path
name = resolved_path.unresolve()
name = name.replace(revision_in_path, "", 1) if not has_revision_in_path else name
Expand Down Expand Up @@ -418,5 +436,9 @@ def _upload_chunk(self, final: bool = False) -> None:
)


def safe_revision(revision: str) -> str:
return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)


def safe_quote(s: str) -> str:
return quote(s, safe="")
66 changes: 58 additions & 8 deletions tests/test_hf_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_list_data_directory_with_revision(self):
self.assertTrue(files[0]["name"].endswith("/data/binary_data_for_pr.bin")) # PR file


@pytest.mark.parametrize("path_in_repo", ["", "foo"])
@pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"])
@pytest.mark.parametrize(
"root_path,revision,repo_type,repo_id,resolved_revision",
[
Expand All @@ -238,6 +238,19 @@ def test_list_data_directory_with_revision(self):
("hf://datasets/squad", None, "dataset", "squad", "main"),
("hf://datasets/squad", "dev", "dataset", "squad", "dev"),
("hf://datasets/squad@dev", None, "dataset", "squad", "dev"),
# Parse with `refs/convert/parquet` and `refs/pr/(\d)+` revisions.
# Regression tests for https://github.com/huggingface/huggingface_hub/issues/1710.
("datasets/squad@refs/convert/parquet", None, "dataset", "squad", "refs/convert/parquet"),
(
"hf://datasets/username/my_dataset@refs/convert/parquet",
None,
"dataset",
"username/my_dataset",
"refs/convert/parquet",
),
("gpt2@refs/pr/2", None, "model", "gpt2", "refs/pr/2"),
("hf://username/my_model@refs/pr/10", None, "model", "username/my_model", "refs/pr/10"),
("hf://username/my_model@refs/pr/10", "refs/pr/10", "model", "username/my_model", "refs/pr/10"),
],
)
def test_resolve_path(
Expand All @@ -251,13 +264,7 @@ def test_resolve_path(
fs = HfFileSystem()
path = root_path + "/" + path_in_repo if path_in_repo else root_path

def mock_repo_info(repo_id: str, *, revision: str, repo_type: str, **kwargs):
if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]:
raise RepositoryNotFoundError(repo_id)
if revision is not None and revision not in ["main", "dev"]:
raise RevisionNotFoundError(revision)

with patch.object(fs._api, "repo_info", mock_repo_info):
with mock_repo_info(fs):
resolved_path = fs.resolve_path(path, revision=revision)
assert (
resolved_path.repo_type,
Expand All @@ -267,6 +274,49 @@ def mock_repo_info(repo_id: str, *, revision: str, repo_type: str, **kwargs):
) == (repo_type, repo_id, resolved_revision, path_in_repo)


@pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"])
@pytest.mark.parametrize(
"path,revision,expected_path",
[
("hf://datasets/squad@dev", None, "datasets/squad@dev"),
("datasets/squad@refs/convert/parquet", None, "datasets/squad@refs/convert/parquet"),
("hf://username/my_model@refs/pr/10", None, "username/my_model@refs/pr/10"),
("username/my_model", "refs/weirdo", "username/my_model@refs%2Fweirdo"), # not a "special revision" -> encode
],
)
def test_unresolve_path(path: str, revision: Optional[str], expected_path: str, path_in_repo: str) -> None:
fs = HfFileSystem()
path = path + "/" + path_in_repo if path_in_repo else path
expected_path = expected_path + "/" + path_in_repo if path_in_repo else expected_path

with mock_repo_info(fs):
assert fs.resolve_path(path, revision=revision).unresolve() == expected_path


def test_resolve_path_with_refs_revision() -> None:
"""
Testing a very specific edge case where a user has a repo with a revisions named "refs" and a file/directory
named "pr/10". We can still process them but the user has to use the `revision` argument to disambiguate between
the two.
"""
fs = HfFileSystem()
with mock_repo_info(fs):
resolved = fs.resolve_path("hf://username/my_model@refs/pr/10", revision="refs")
assert resolved.revision == "refs"
assert resolved.path_in_repo == "pr/10"
assert resolved.unresolve() == "username/my_model@refs/pr/10"


def mock_repo_info(fs: HfFileSystem):
def _inner(repo_id: str, *, revision: str, repo_type: str, **kwargs):
if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]:
raise RepositoryNotFoundError(repo_id)
if revision is not None and revision not in ["main", "dev", "refs"] and not revision.startswith("refs/"):
raise RevisionNotFoundError(revision)

return patch.object(fs._api, "repo_info", _inner)


def test_resolve_path_with_non_matching_revisions():
fs = HfFileSystem()
with pytest.raises(ValueError):
Expand Down
Loading