From 6d6e4715d9ae1cab29ae6c943d4a019b10f60d0c Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 20 Sep 2022 11:29:02 +0200 Subject: [PATCH 1/8] Return more information in create_commit output --- docs/source/package_reference/hf_api.mdx | 22 +++++++++ src/huggingface_hub/__init__.py | 1 + src/huggingface_hub/hf_api.py | 63 +++++++++++++++++++----- src/huggingface_hub/keras_mixin.py | 6 +-- src/huggingface_hub/utils/_validators.py | 8 ++- tests/test_hf_api.py | 21 ++++++-- 6 files changed, 101 insertions(+), 20 deletions(-) diff --git a/docs/source/package_reference/hf_api.mdx b/docs/source/package_reference/hf_api.mdx index 0549e6c693..0e8dfabb38 100644 --- a/docs/source/package_reference/hf_api.mdx +++ b/docs/source/package_reference/hf_api.mdx @@ -25,16 +25,30 @@ models = hf_api.list_models() Using the `HfApi` class directly enables you to set a different endpoint to that of the Hugging Face's Hub. +### HfApi + [[autodoc]] HfApi +### ModelInfo + [[autodoc]] huggingface_hub.hf_api.ModelInfo +### DatasetInfo + [[autodoc]] huggingface_hub.hf_api.DatasetInfo +### SpaceInfo + [[autodoc]] huggingface_hub.hf_api.SpaceInfo +### RepoFile + [[autodoc]] huggingface_hub.hf_api.RepoFile +### CommitInfo + +[[autodoc]] huggingface_hub.hf_api.CommitInfo + ## `create_commit` API Below are the supported values for [`CommitOperation`]: @@ -56,10 +70,18 @@ It does this using the [`HfFolder`] utility, which saves data at the root of the Some helpers to filter repositories on the Hub are available in the `huggingface_hub` package. +### DatasetFilter + [[autodoc]] DatasetFilter +### ModelFilter + [[autodoc]] ModelFilter +### DatasetSearchArguments + [[autodoc]] DatasetSearchArguments +### ModelSearchArguments + [[autodoc]] ModelSearchArguments diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 78c51bd1e9..0d2463aad4 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -103,6 +103,7 @@ "change_discussion_status", "comment_discussion", "create_commit", + "CommitInfo", "create_discussion", "create_pull_request", "create_repo", diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e0758c0665..3a183860a0 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -16,6 +16,7 @@ import re import subprocess import warnings +from dataclasses import dataclass from os.path import expanduser from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union from urllib.parse import quote @@ -170,6 +171,37 @@ class BlobLfsInfo(TypedDict, total=False): sha256: str +@dataclass +class CommitInfo: + """Data structure containing information about a newly created commit. + + Returned by `create_commit(...)`. + + Args: + commit_url (str): + Url where to find the commit. + + commit_message (str): + The summary (first line) of the commit that has been created. + + commit_description (str): + Description of the commit that has been created. Can be empty. + + oid (str): + Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. + + pr_url (str, *optional*): + Url to the PR that has been created, if any. Populated when `create_pr=True` + is passed. + """ + + commit_url: str + commit_message: str + commit_description: str + oid: str + pr_url: Optional[str] = None + + class RepoFile: """ Data structure that represents a public file inside a repo, accessible from @@ -1953,7 +1985,7 @@ def create_commit( create_pr: Optional[bool] = None, num_threads: int = 5, parent_commit: Optional[str] = None, - ) -> Optional[str]: + ) -> CommitInfo: """ Creates a commit in the given repo, deleting & uploading files as needed. @@ -2005,9 +2037,9 @@ def create_commit( if the repo is updated / committed to concurrently. Returns: - `str` or `None`: - If `create_pr` is `True`, returns the URL to the newly created Pull Request - on the Hub. Otherwise returns `None`. + [`CommitInfo`]: + Instance of [`CommitInfo`] containing information about the newly + created commit (commit hash, commit url, pr url, commit message,...). Raises: [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) @@ -2118,7 +2150,14 @@ def create_commit( params={"create_pr": "1"} if create_pr else None, ) hf_raise_for_status(commit_resp, endpoint_name="commit") - return commit_resp.json().get("pullRequestUrl", None) + commit_data = commit_resp.json() + return CommitInfo( + commit_url=commit_data["commitUrl"], + commit_message=commit_message, + commit_description=commit_description, + oid=commit_data["commitOid"], + pr_url=commit_data["pullRequestUrl"] if create_pr else None, + ) @validate_hf_hub_args def upload_file( @@ -2260,7 +2299,7 @@ def upload_file( path_in_repo=path_in_repo, ) - pr_url = self.create_commit( + commit_info = self.create_commit( repo_id=repo_id, repo_type=repo_type, operations=[operation], @@ -2272,8 +2311,8 @@ def upload_file( parent_commit=parent_commit, ) - if pr_url is not None: - revision = quote(_parse_revision_from_pr_url(pr_url), safe="") + if commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") if repo_type in REPO_TYPES_URL_PREFIXES: repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id revision = revision if revision is not None else DEFAULT_REVISION @@ -2420,7 +2459,7 @@ def upload_folder( ignore_patterns=ignore_patterns, ) - pr_url = self.create_commit( + commit_info = self.create_commit( repo_type=repo_type, repo_id=repo_id, operations=files_to_add, @@ -2432,8 +2471,8 @@ def upload_folder( parent_commit=parent_commit, ) - if pr_url is not None: - revision = quote(_parse_revision_from_pr_url(pr_url), safe="") + if commit_info.pr_url is not None: + revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") if repo_type in REPO_TYPES_URL_PREFIXES: repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id revision = revision if revision is not None else DEFAULT_REVISION @@ -2453,7 +2492,7 @@ def delete_file( commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, - ): + ) -> CommitInfo: """ Deletes a file in the given repo. diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index dfcd90758b..6bfe5121df 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -458,7 +458,7 @@ def push_to_hub_keras( allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, ) - pr_url = api.create_commit( + commit_info = api.create_commit( repo_type="model", repo_id=repo_id, operations=operations, @@ -470,8 +470,8 @@ def push_to_hub_keras( revision = branch if revision is None: revision = ( - quote(_parse_revision_from_pr_url(pr_url), safe="") - if pr_url is not None + quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") + if commit_info.pr_url is not None else DEFAULT_REVISION ) return f"{api.endpoint}/{repo_id}/tree/{revision}/" diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index 79435f6c8b..e34b8b59b2 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -17,7 +17,7 @@ import re from functools import wraps from itertools import chain -from typing import Callable +from typing import Callable, TypeVar REPO_ID_REGEX = re.compile( @@ -40,7 +40,11 @@ class HFValidationError(ValueError): """ -def validate_hf_hub_args(fn: Callable) -> Callable: +# type hint meaning "function signature not changed by decorator" +CT = TypeVar("CT") # callable type + + +def validate_hf_hub_args(fn: CT) -> CT: """Validate values received as argument for any public method of `huggingface_hub`. The goal of this decorator is to harmonize validation of arguments reused diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 3bd41e2a99..c421447532 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -40,6 +40,7 @@ from huggingface_hub.file_download import cached_download, hf_hub_download from huggingface_hub.hf_api import ( USERNAME_PLACEHOLDER, + CommitInfo, DatasetInfo, DatasetSearchArguments, HfApi, @@ -759,10 +760,22 @@ def test_create_commit_create_pr(self): token=self._token, create_pr=True, ) + + # Check commit info + self.assertIsInstance(resp, CommitInfo) + commit_id = resp.oid + self.assertGreater(len(resp.oid), 0) + self.assertEqual( + resp.commit_url, + f"{self._api.endpoint}/{USER}/{REPO_NAME}/commit/{commit_id}", + ) + self.assertEqual(resp.commit_message, "Test create_commit") + self.assertEqual(resp.commit_description, "") self.assertEqual( - resp, + resp.pr_url, f"{self._api.endpoint}/{USER}/{REPO_NAME}/discussions/1", ) + with self.assertRaises(HTTPError) as ctx: # Should raise a 404 hf_hub_download( @@ -823,13 +836,15 @@ def test_create_commit(self): path_or_fileobj=self.tmp_file, ), ] - return_val = self._api.create_commit( + resp = self._api.create_commit( operations=operations, commit_message="Test create_commit", repo_id=f"{USER}/{REPO_NAME}", token=self._token, ) - self.assertIsNone(return_val) + # Check commit info + self.assertIsInstance(resp, CommitInfo) + self.assertIsNone(resp.pr_url) # No pr created with self.assertRaises(HTTPError): # Should raise a 404 hf_hub_download( From 4d6712459daaf806b2c71736c8d0452bb1824ff9 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 20 Sep 2022 11:41:00 +0200 Subject: [PATCH 2/8] flake8 --- src/huggingface_hub/__init__.py | 3 ++- src/huggingface_hub/utils/_validators.py | 2 +- tests/test_hf_api.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 0d2463aad4..7f154f75f6 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -93,6 +93,7 @@ "try_to_load_from_cache", ], "hf_api": [ + "CommitInfo", "CommitOperation", "CommitOperationAdd", "CommitOperationDelete", @@ -103,7 +104,6 @@ "change_discussion_status", "comment_discussion", "create_commit", - "CommitInfo", "create_discussion", "create_pull_request", "create_repo", @@ -307,6 +307,7 @@ def __dir__(): from .file_download import hf_hub_download # noqa: F401 from .file_download import hf_hub_url # noqa: F401 from .file_download import try_to_load_from_cache # noqa: F401 + from .hf_api import CommitInfo # noqa: F401 from .hf_api import CommitOperation # noqa: F401 from .hf_api import CommitOperationAdd # noqa: F401 from .hf_api import CommitOperationDelete # noqa: F401 diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index e34b8b59b2..dd904b7f68 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -17,7 +17,7 @@ import re from functools import wraps from itertools import chain -from typing import Callable, TypeVar +from typing import TypeVar REPO_ID_REGEX = re.compile( diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index c421447532..74395f8f48 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -844,7 +844,7 @@ def test_create_commit(self): ) # Check commit info self.assertIsInstance(resp, CommitInfo) - self.assertIsNone(resp.pr_url) # No pr created + self.assertIsNone(resp.pr_url) # No pr created with self.assertRaises(HTTPError): # Should raise a 404 hf_hub_download( From 3e9d8650544fc197f0bcc525682916281b257cce Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 20 Sep 2022 15:05:29 +0200 Subject: [PATCH 3/8] requested changes --- src/huggingface_hub/hf_api.py | 12 ++++++------ src/huggingface_hub/utils/_validators.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 3a183860a0..e4c1cb6a30 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -175,22 +175,22 @@ class BlobLfsInfo(TypedDict, total=False): class CommitInfo: """Data structure containing information about a newly created commit. - Returned by `create_commit(...)`. + Returned by [`create_commit`]. Args: - commit_url (str): + commit_url (`str`): Url where to find the commit. - commit_message (str): + commit_message (`str`): The summary (first line) of the commit that has been created. - commit_description (str): + commit_description (`str`): Description of the commit that has been created. Can be empty. - oid (str): + oid (`str`): Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`. - pr_url (str, *optional*): + pr_url (`str`, *optional*): Url to the PR that has been created, if any. Populated when `create_pr=True` is passed. """ diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index dd904b7f68..8dc0eba434 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -41,10 +41,10 @@ class HFValidationError(ValueError): # type hint meaning "function signature not changed by decorator" -CT = TypeVar("CT") # callable type +CallableT = TypeVar("CallableT") # callable type -def validate_hf_hub_args(fn: CT) -> CT: +def validate_hf_hub_args(fn: CallableT) -> CallableT: """Validate values received as argument for any public method of `huggingface_hub`. The goal of this decorator is to harmonize validation of arguments reused From 1f3d8549226f0a9b4a7a7ca6dd38d9fb3ab1e7af Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 20 Sep 2022 15:50:20 +0200 Subject: [PATCH 4/8] fix autocomplete test --- tests/test_init_lazy_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_init_lazy_loading.py b/tests/test_init_lazy_loading.py index 54b747e479..c66b9db0d5 100644 --- a/tests/test_init_lazy_loading.py +++ b/tests/test_init_lazy_loading.py @@ -31,7 +31,7 @@ def test_autocomplete_on_root_imports(self) -> None: self.assertTrue( signature_list[0] .docstring() - .startswith("create_commit(self, repo_id: str") + .startswith("create_commit(repo_id: str,") ) break else: From f606c58a818310bf38eab0b1d4f1f4e39f705171 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 20 Sep 2022 16:09:29 +0200 Subject: [PATCH 5/8] Add pr_revision and pr_url to CommitInfo --- src/huggingface_hub/hf_api.py | 20 ++++++++++++++++++++ tests/test_hf_api.py | 4 ++++ 2 files changed, 24 insertions(+) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e4c1cb6a30..6fe1027dac 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -193,6 +193,15 @@ class CommitInfo: pr_url (`str`, *optional*): Url to the PR that has been created, if any. Populated when `create_pr=True` is passed. + + pr_revision (`str`, *optional*): + Revision of the PR that has been created, if any. Populated when + `create_pr=True` is passed. Example: `"refs/pr/1"`. + + pr_num (`int`, *optional*): + Number of the PR discussion that has been created, if any. Populated when + `create_pr=True` is passed. Can be passed as `discussion_num` in + [`get_discussion_details`]. Example: `1`. """ commit_url: str @@ -201,6 +210,17 @@ class CommitInfo: oid: str pr_url: Optional[str] = None + @property + def pr_revision(self) -> Optional[str]: + if self.pr_url is not None: + return _parse_revision_from_pr_url(self.pr_url) + + @property + def pr_num(self) -> Optional[int]: + revision = self.pr_revision + if revision is not None: + return int(revision.split("/")[-1]) + class RepoFile: """ diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 74395f8f48..d04b0bd73b 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -775,6 +775,8 @@ def test_create_commit_create_pr(self): resp.pr_url, f"{self._api.endpoint}/{USER}/{REPO_NAME}/discussions/1", ) + self.assertEqual(resp.pr_num, 1) + self.assertEqual(resp.pr_revision, "refs/pr/1") with self.assertRaises(HTTPError) as ctx: # Should raise a 404 @@ -845,6 +847,8 @@ def test_create_commit(self): # Check commit info self.assertIsInstance(resp, CommitInfo) self.assertIsNone(resp.pr_url) # No pr created + self.assertIsNone(resp.pr_num) + self.assertIsNone(resp.pr_revision) with self.assertRaises(HTTPError): # Should raise a 404 hf_hub_download( From 6cc1b994db73a2b5788bcf326689f4e0ad73ad6d Mon Sep 17 00:00:00 2001 From: Lucain Date: Thu, 22 Sep 2022 11:03:34 +0200 Subject: [PATCH 6/8] Update tests/test_hf_api.py Co-authored-by: Omar Sanseviero --- tests/test_hf_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index d04b0bd73b..113b19f668 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -764,7 +764,7 @@ def test_create_commit_create_pr(self): # Check commit info self.assertIsInstance(resp, CommitInfo) commit_id = resp.oid - self.assertGreater(len(resp.oid), 0) + self.assertGreater(len(commit_id), 0) self.assertEqual( resp.commit_url, f"{self._api.endpoint}/{USER}/{REPO_NAME}/commit/{commit_id}", From 475336f4cca15c3b2096d5be0307c6cfade1eb33 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Thu, 22 Sep 2022 11:36:12 +0200 Subject: [PATCH 7/8] nicely handle properties in dataclass --- src/huggingface_hub/hf_api.py | 25 +++++++++++++++---------- tests/test_hf_api.py | 2 ++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 9b347dcebb..c7e9fd4056 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -16,7 +16,7 @@ import re import subprocess import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union from urllib.parse import quote @@ -211,16 +211,21 @@ class CommitInfo: oid: str pr_url: Optional[str] = None - @property - def pr_revision(self) -> Optional[str]: - if self.pr_url is not None: - return _parse_revision_from_pr_url(self.pr_url) + # Computed from `pr_url` in `__post_init__` + pr_revision: Optional[str] = field(init=False) + pr_num: Optional[str] = field(init=False) + + def __post_init__(self): + """Populate pr-related fields after initialization. - @property - def pr_num(self) -> Optional[int]: - revision = self.pr_revision - if revision is not None: - return int(revision.split("/")[-1]) + See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing. + """ + if self.pr_url is not None: + self.pr_revision = _parse_revision_from_pr_url(self.pr_url) + self.pr_num = int(self.pr_revision.split("/")[-1]) + else: + self.pr_revision = None + self.pr_num = None class RepoFile: diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 8151457835..43e5ff6223 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -771,6 +771,8 @@ def test_create_commit_create_pr(self): # Check commit info self.assertIsInstance(resp, CommitInfo) commit_id = resp.oid + self.assertIn("pr_revision='refs/pr/1'", str(resp)) + self.assertIsInstance(commit_id, str) self.assertGreater(len(commit_id), 0) self.assertEqual( resp.commit_url, From 2d459ec7631f2c6b5c0ea2f99ae9fab1db543b70 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Thu, 22 Sep 2022 11:39:32 +0200 Subject: [PATCH 8/8] make style --- src/huggingface_hub/hf_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index c7e9fd4056..7cb44c8546 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -212,7 +212,7 @@ class CommitInfo: pr_url: Optional[str] = None # Computed from `pr_url` in `__post_init__` - pr_revision: Optional[str] = field(init=False) + pr_revision: Optional[str] = field(init=False) pr_num: Optional[str] = field(init=False) def __post_init__(self):