Skip to content

Commit

Permalink
Handle refs/convert/parquet and PR revision correctly in hffs (#1712)
Browse files Browse the repository at this point in the history
* Handle  correctly in hffs

* styling

* fix unresolve + more tests

* add corner case
  • Loading branch information
Wauplin authored Oct 5, 2023
1 parent 90c1182 commit cbcd8b2
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
30 changes: 26 additions & 4 deletions src/huggingface_hub/hf_file_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import os
import re
import tempfile
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -23,6 +24,17 @@
)


# Regex used to match special revisions with "/" in them (see #1710)
SPECIAL_REFS_REVISION_REGEX = re.compile(
r"""
(^refs\/convert\/parquet) # `refs/convert/parquet` revisions
|
(^refs\/pr\/\d+) # PR revisions
""",
re.VERBOSE,
)


@dataclass
class HfFileSystemResolvedPath:
"""Data structure containing information about a resolved Hugging Face file system path."""
Expand All @@ -34,7 +46,7 @@ class HfFileSystemResolvedPath:

def unresolve(self) -> str:
return (
f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_quote(self.revision)}/{self.path_in_repo}"
f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_revision(self.revision)}/{self.path_in_repo}"
.rstrip("/")
)

Expand Down Expand Up @@ -156,7 +168,13 @@ def _align_revision_in_path_with_revision(
if "@" in path:
repo_id, revision_in_path = path.split("@", 1)
if "/" in revision_in_path:
revision_in_path, path_in_repo = revision_in_path.split("/", 1)
match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path)
if match is not None and revision in (None, match.group()):
# Handle `refs/convert/parquet` and PR revisions separately
path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/")
revision_in_path = match.group()
else:
revision_in_path, path_in_repo = revision_in_path.split("/", 1)
else:
path_in_repo = ""
revision_in_path = unquote(revision_in_path)
Expand Down Expand Up @@ -262,7 +280,7 @@ def ls(
) -> List[Union[str, Dict[str, Any]]]:
"""List the contents of a directory."""
resolved_path = self.resolve_path(path, revision=revision)
revision_in_path = "@" + safe_quote(resolved_path.revision)
revision_in_path = "@" + safe_revision(resolved_path.revision)
has_revision_in_path = revision_in_path in path
path = resolved_path.unresolve()
if path not in self.dircache or refresh:
Expand Down Expand Up @@ -367,7 +385,7 @@ def modified(self, path: str, **kwargs) -> datetime:
def info(self, path: str, **kwargs) -> Dict[str, Any]:
resolved_path = self.resolve_path(path)
if not resolved_path.path_in_repo:
revision_in_path = "@" + safe_quote(resolved_path.revision)
revision_in_path = "@" + safe_revision(resolved_path.revision)
has_revision_in_path = revision_in_path in path
name = resolved_path.unresolve()
name = name.replace(revision_in_path, "", 1) if not has_revision_in_path else name
Expand Down Expand Up @@ -418,5 +436,9 @@ def _upload_chunk(self, final: bool = False) -> None:
)


def safe_revision(revision: str) -> str:
return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision)


def safe_quote(s: str) -> str:
return quote(s, safe="")
66 changes: 58 additions & 8 deletions tests/test_hf_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_list_data_directory_with_revision(self):
self.assertTrue(files[0]["name"].endswith("/data/binary_data_for_pr.bin")) # PR file


@pytest.mark.parametrize("path_in_repo", ["", "foo"])
@pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"])
@pytest.mark.parametrize(
"root_path,revision,repo_type,repo_id,resolved_revision",
[
Expand All @@ -238,6 +238,19 @@ def test_list_data_directory_with_revision(self):
("hf://datasets/squad", None, "dataset", "squad", "main"),
("hf://datasets/squad", "dev", "dataset", "squad", "dev"),
("hf://datasets/squad@dev", None, "dataset", "squad", "dev"),
# Parse with `refs/convert/parquet` and `refs/pr/(\d)+` revisions.
# Regression tests for https://github.com/huggingface/huggingface_hub/issues/1710.
("datasets/squad@refs/convert/parquet", None, "dataset", "squad", "refs/convert/parquet"),
(
"hf://datasets/username/my_dataset@refs/convert/parquet",
None,
"dataset",
"username/my_dataset",
"refs/convert/parquet",
),
("gpt2@refs/pr/2", None, "model", "gpt2", "refs/pr/2"),
("hf://username/my_model@refs/pr/10", None, "model", "username/my_model", "refs/pr/10"),
("hf://username/my_model@refs/pr/10", "refs/pr/10", "model", "username/my_model", "refs/pr/10"),
],
)
def test_resolve_path(
Expand All @@ -251,13 +264,7 @@ def test_resolve_path(
fs = HfFileSystem()
path = root_path + "/" + path_in_repo if path_in_repo else root_path

def mock_repo_info(repo_id: str, *, revision: str, repo_type: str, **kwargs):
if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]:
raise RepositoryNotFoundError(repo_id)
if revision is not None and revision not in ["main", "dev"]:
raise RevisionNotFoundError(revision)

with patch.object(fs._api, "repo_info", mock_repo_info):
with mock_repo_info(fs):
resolved_path = fs.resolve_path(path, revision=revision)
assert (
resolved_path.repo_type,
Expand All @@ -267,6 +274,49 @@ def mock_repo_info(repo_id: str, *, revision: str, repo_type: str, **kwargs):
) == (repo_type, repo_id, resolved_revision, path_in_repo)


@pytest.mark.parametrize("path_in_repo", ["", "file.txt", "path/to/file"])
@pytest.mark.parametrize(
"path,revision,expected_path",
[
("hf://datasets/squad@dev", None, "datasets/squad@dev"),
("datasets/squad@refs/convert/parquet", None, "datasets/squad@refs/convert/parquet"),
("hf://username/my_model@refs/pr/10", None, "username/my_model@refs/pr/10"),
("username/my_model", "refs/weirdo", "username/my_model@refs%2Fweirdo"), # not a "special revision" -> encode
],
)
def test_unresolve_path(path: str, revision: Optional[str], expected_path: str, path_in_repo: str) -> None:
fs = HfFileSystem()
path = path + "/" + path_in_repo if path_in_repo else path
expected_path = expected_path + "/" + path_in_repo if path_in_repo else expected_path

with mock_repo_info(fs):
assert fs.resolve_path(path, revision=revision).unresolve() == expected_path


def test_resolve_path_with_refs_revision() -> None:
"""
Testing a very specific edge case where a user has a repo with a revisions named "refs" and a file/directory
named "pr/10". We can still process them but the user has to use the `revision` argument to disambiguate between
the two.
"""
fs = HfFileSystem()
with mock_repo_info(fs):
resolved = fs.resolve_path("hf://username/my_model@refs/pr/10", revision="refs")
assert resolved.revision == "refs"
assert resolved.path_in_repo == "pr/10"
assert resolved.unresolve() == "username/my_model@refs/pr/10"


def mock_repo_info(fs: HfFileSystem):
def _inner(repo_id: str, *, revision: str, repo_type: str, **kwargs):
if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]:
raise RepositoryNotFoundError(repo_id)
if revision is not None and revision not in ["main", "dev", "refs"] and not revision.startswith("refs/"):
raise RevisionNotFoundError(revision)

return patch.object(fs._api, "repo_info", _inner)


def test_resolve_path_with_non_matching_revisions():
fs = HfFileSystem()
with pytest.raises(ValueError):
Expand Down

0 comments on commit cbcd8b2

Please sign in to comment.