diff --git a/docs/source/package_reference/hf_api.mdx b/docs/source/package_reference/hf_api.mdx index 0549e6c693..0e8dfabb38 100644 --- a/docs/source/package_reference/hf_api.mdx +++ b/docs/source/package_reference/hf_api.mdx @@ -25,16 +25,30 @@ models = hf_api.list_models() Using the `HfApi` class directly enables you to set a different endpoint to that of the Hugging Face's Hub. +### HfApi + [[autodoc]] HfApi +### ModelInfo + [[autodoc]] huggingface_hub.hf_api.ModelInfo +### DatasetInfo + [[autodoc]] huggingface_hub.hf_api.DatasetInfo +### SpaceInfo + [[autodoc]] huggingface_hub.hf_api.SpaceInfo +### RepoFile + [[autodoc]] huggingface_hub.hf_api.RepoFile +### CommitInfo + +[[autodoc]] huggingface_hub.hf_api.CommitInfo + ## `create_commit` API Below are the supported values for [`CommitOperation`]: @@ -56,10 +70,18 @@ It does this using the [`HfFolder`] utility, which saves data at the root of the Some helpers to filter repositories on the Hub are available in the `huggingface_hub` package. +### DatasetFilter + [[autodoc]] DatasetFilter +### ModelFilter + [[autodoc]] ModelFilter +### DatasetSearchArguments + [[autodoc]] DatasetSearchArguments +### ModelSearchArguments + [[autodoc]] ModelSearchArguments diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index a413530e0e..89a2e9399e 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -93,6 +93,7 @@ "try_to_load_from_cache", ], "hf_api": [ + "CommitInfo", "CommitOperation", "CommitOperationAdd", "CommitOperationDelete", @@ -306,6 +307,7 @@ def __dir__(): from .file_download import hf_hub_download # noqa: F401 from .file_download import hf_hub_url # noqa: F401 from .file_download import try_to_load_from_cache # noqa: F401 + from .hf_api import CommitInfo # noqa: F401 from .hf_api import CommitOperation # noqa: F401 from .hf_api import CommitOperationAdd # noqa: F401 from .hf_api import CommitOperationDelete # noqa: F401 diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 418a2680f1..7cb44c8546 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -16,6 +16,7 @@ import re import subprocess import warnings +from dataclasses import dataclass, field from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union from urllib.parse import quote @@ -171,6 +172,62 @@ class BlobLfsInfo(TypedDict, total=False): sha256: str +@dataclass +class CommitInfo: + """Data structure containing information about a newly created commit. + + Returned by [`create_commit`]. + + Args: + commit_url (`str`): + Url where to find the commit. + + commit_message (`str`): + The summary (first line) of the commit that has been created. + + commit_description (`str`): + Description of the commit that has been created. Can be empty. + + oid (`str`): + Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. + + pr_url (`str`, *optional*): + Url to the PR that has been created, if any. Populated when `create_pr=True` + is passed. + + pr_revision (`str`, *optional*): + Revision of the PR that has been created, if any. Populated when + `create_pr=True` is passed. Example: `"refs/pr/1"`. + + pr_num (`int`, *optional*): + Number of the PR discussion that has been created, if any. Populated when + `create_pr=True` is passed. Can be passed as `discussion_num` in + [`get_discussion_details`]. Example: `1`. + """ + + commit_url: str + commit_message: str + commit_description: str + oid: str + pr_url: Optional[str] = None + + # Computed from `pr_url` in `__post_init__` + pr_revision: Optional[str] = field(init=False) + pr_num: Optional[str] = field(init=False) + + def __post_init__(self): + """Populate pr-related fields after initialization. + + See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing. + """ + if self.pr_url is not None: + self.pr_revision = _parse_revision_from_pr_url(self.pr_url) + self.pr_num = int(self.pr_revision.split("/")[-1]) + else: + self.pr_revision = None + self.pr_num = None + + class RepoFile: """ Data structure that represents a public file inside a repo, accessible from @@ -1850,7 +1907,7 @@ def create_commit( create_pr: Optional[bool] = None, num_threads: int = 5, parent_commit: Optional[str] = None, - ) -> Optional[str]: + ) -> CommitInfo: """ Creates a commit in the given repo, deleting & uploading files as needed. @@ -1902,9 +1959,9 @@ def create_commit( if the repo is updated / committed to concurrently. Returns: - `str` or `None`: - If `create_pr` is `True`, returns the URL to the newly created Pull Request - on the Hub. Otherwise returns `None`. + [`CommitInfo`]: + Instance of [`CommitInfo`] containing information about the newly + created commit (commit hash, commit url, pr url, commit message,...). Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) @@ -2015,7 +2072,14 @@ def create_commit( params={"create_pr": "1"} if create_pr else None, ) hf_raise_for_status(commit_resp, endpoint_name="commit") - return commit_resp.json().get("pullRequestUrl", None) + commit_data = commit_resp.json() + return CommitInfo( + commit_url=commit_data["commitUrl"], + commit_message=commit_message, + commit_description=commit_description, + oid=commit_data["commitOid"], + pr_url=commit_data["pullRequestUrl"] if create_pr else None, + ) @validate_hf_hub_args def upload_file( @@ -2157,7 +2221,7 @@ def upload_file( path_in_repo=path_in_repo, ) - pr_url = self.create_commit( + commit_info = self.create_commit( repo_id=repo_id, repo_type=repo_type, operations=[operation], @@ -2169,8 +2233,8 @@ def upload_file( parent_commit=parent_commit, ) - if pr_url is not None: - revision = quote(_parse_revision_from_pr_url(pr_url), safe="") + if commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") if repo_type in REPO_TYPES_URL_PREFIXES: repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id revision = revision if revision is not None else DEFAULT_REVISION @@ -2317,7 +2381,7 @@ def upload_folder( ignore_patterns=ignore_patterns, ) - pr_url = self.create_commit( + commit_info = self.create_commit( repo_type=repo_type, repo_id=repo_id, operations=files_to_add, @@ -2329,8 +2393,8 @@ def upload_folder( parent_commit=parent_commit, ) - if pr_url is not None: - revision = quote(_parse_revision_from_pr_url(pr_url), safe="") + if commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") if repo_type in REPO_TYPES_URL_PREFIXES: repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id revision = revision if revision is not None else DEFAULT_REVISION @@ -2350,7 +2414,7 @@ def delete_file( commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, - ): + ) -> CommitInfo: """ Deletes a file in the given repo. diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index 9872bf7502..923451cec1 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -446,7 +446,7 @@ def push_to_hub_keras( allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) - pr_url = api.create_commit( + commit_info = api.create_commit( repo_type="model", repo_id=repo_id, operations=operations, @@ -458,8 +458,8 @@ def push_to_hub_keras( revision = branch if revision is None: revision = ( - quote(_parse_revision_from_pr_url(pr_url), safe="") - if pr_url is not None + quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") + if commit_info.pr_url is not None else DEFAULT_REVISION ) return f"{api.endpoint}/{repo_id}/tree/{revision}/" diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index 79435f6c8b..8dc0eba434 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -17,7 +17,7 @@ import re from functools import wraps from itertools import chain -from typing import Callable +from typing import TypeVar REPO_ID_REGEX = re.compile( @@ -40,7 +40,11 @@ class HFValidationError(ValueError): """ -def validate_hf_hub_args(fn: Callable) -> Callable: +# type hint meaning "function signature not changed by decorator" +CallableT = TypeVar("CallableT") # callable type + + +def validate_hf_hub_args(fn: CallableT) -> CallableT: """Validate values received as argument for any public method of `huggingface_hub`. The goal of this decorator is to harmonize validation of arguments reused diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 66c73006ea..43e5ff6223 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -41,6 +41,7 @@ from huggingface_hub.file_download import cached_download, hf_hub_download from huggingface_hub.hf_api import ( USERNAME_PLACEHOLDER, + CommitInfo, DatasetInfo, DatasetSearchArguments, HfApi, @@ -766,10 +767,26 @@ def test_create_commit_create_pr(self): token=self._token, create_pr=True, ) + + # Check commit info + self.assertIsInstance(resp, CommitInfo) + commit_id = resp.oid + self.assertIn("pr_revision='refs/pr/1'", str(resp)) + self.assertIsInstance(commit_id, str) + self.assertGreater(len(commit_id), 0) + self.assertEqual( + resp.commit_url, + f"{self._api.endpoint}/{USER}/{REPO_NAME}/commit/{commit_id}", + ) + self.assertEqual(resp.commit_message, "Test create_commit") + self.assertEqual(resp.commit_description, "") self.assertEqual( - resp, + resp.pr_url, f"{self._api.endpoint}/{USER}/{REPO_NAME}/discussions/1", ) + self.assertEqual(resp.pr_num, 1) + self.assertEqual(resp.pr_revision, "refs/pr/1") + with self.assertRaises(HTTPError) as ctx: # Should raise a 404 hf_hub_download( @@ -830,13 +847,17 @@ def test_create_commit(self): path_or_fileobj=self.tmp_file, ), ] - return_val = self._api.create_commit( + resp = self._api.create_commit( operations=operations, commit_message="Test create_commit", repo_id=f"{USER}/{REPO_NAME}", token=self._token, ) - self.assertIsNone(return_val) + # Check commit info + self.assertIsInstance(resp, CommitInfo) + self.assertIsNone(resp.pr_url) # No pr created + self.assertIsNone(resp.pr_num) + self.assertIsNone(resp.pr_revision) with self.assertRaises(HTTPError): # Should raise a 404 hf_hub_download( diff --git a/tests/test_init_lazy_loading.py b/tests/test_init_lazy_loading.py index 54b747e479..c66b9db0d5 100644 --- a/tests/test_init_lazy_loading.py +++ b/tests/test_init_lazy_loading.py @@ -31,7 +31,7 @@ def test_autocomplete_on_root_imports(self) -> None: self.assertTrue( signature_list[0] .docstring() - .startswith("create_commit(self, repo_id: str") + .startswith("create_commit(repo_id: str,") ) break else: