Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return more information in create_commit output #1066

Merged
merged 10 commits into from
Sep 23, 2022
22 changes: 22 additions & 0 deletions docs/source/package_reference/hf_api.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,30 @@ models = hf_api.list_models()

Using the `HfApi` class directly enables you to set a different endpoint from that of the Hugging Face Hub.

### HfApi

[[autodoc]] HfApi

### ModelInfo

[[autodoc]] huggingface_hub.hf_api.ModelInfo

### DatasetInfo

[[autodoc]] huggingface_hub.hf_api.DatasetInfo

### SpaceInfo

[[autodoc]] huggingface_hub.hf_api.SpaceInfo

### RepoFile

[[autodoc]] huggingface_hub.hf_api.RepoFile

### CommitInfo

[[autodoc]] huggingface_hub.hf_api.CommitInfo

## `create_commit` API

Below are the supported values for [`CommitOperation`]:
Expand All @@ -56,10 +70,18 @@ It does this using the [`HfFolder`] utility, which saves data at the root of the

Some helpers to filter repositories on the Hub are available in the `huggingface_hub` package.

### DatasetFilter

[[autodoc]] DatasetFilter

### ModelFilter

[[autodoc]] ModelFilter

### DatasetSearchArguments

[[autodoc]] DatasetSearchArguments

### ModelSearchArguments

[[autodoc]] ModelSearchArguments
2 changes: 2 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"try_to_load_from_cache",
],
"hf_api": [
"CommitInfo",
"CommitOperation",
"CommitOperationAdd",
"CommitOperationDelete",
Expand Down Expand Up @@ -306,6 +307,7 @@ def __dir__():
from .file_download import hf_hub_download # noqa: F401
from .file_download import hf_hub_url # noqa: F401
from .file_download import try_to_load_from_cache # noqa: F401
from .hf_api import CommitInfo # noqa: F401
from .hf_api import CommitOperation # noqa: F401
from .hf_api import CommitOperationAdd # noqa: F401
from .hf_api import CommitOperationDelete # noqa: F401
Expand Down
83 changes: 71 additions & 12 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import subprocess
import warnings
from dataclasses import dataclass
from os.path import expanduser
from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from urllib.parse import quote
Expand Down Expand Up @@ -170,6 +171,57 @@ class BlobLfsInfo(TypedDict, total=False):
sha256: str


@dataclass
class CommitInfo:
    """Data structure containing information about a newly created commit.

    Returned by [`create_commit`].

    Args:
        commit_url (`str`):
            URL where the commit can be found.

        commit_message (`str`):
            The summary (first line) of the commit that has been created.

        commit_description (`str`):
            Description of the commit that has been created. Can be empty.

        oid (`str`):
            Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.

        pr_url (`str`, *optional*):
            URL of the PR that has been created, if any. Populated when
            `create_pr=True` is passed.

    Attributes:
        pr_revision (`str`, *optional*):
            Read-only property. Revision of the PR that has been created, if any,
            parsed from `pr_url`. Example: `"refs/pr/1"`.

        pr_num (`int`, *optional*):
            Read-only property. Number of the PR discussion that has been created,
            if any, taken from the last segment of `pr_revision`. Can be passed as
            `discussion_num` in [`get_discussion_details`]. Example: `1`.
    """

    commit_url: str
    commit_message: str
    commit_description: str
    oid: str
    pr_url: Optional[str] = None

    @property
    def pr_revision(self) -> Optional[str]:
        """Revision of the created PR, or `None` when no PR url is set."""
        if self.pr_url is None:
            return None
        return _parse_revision_from_pr_url(self.pr_url)

    @property
    def pr_num(self) -> Optional[int]:
        """Number of the created PR, or `None` when no PR url is set."""
        revision = self.pr_revision
        if revision is None:
            return None
        # The PR number is the last "/"-separated segment, e.g. "refs/pr/1" -> 1.
        return int(revision.split("/")[-1])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Maybe consider adding a `__str__` method for a nice string representation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@osanseviero Thanks for the comment. I discovered that a good way to handle properties in dataclasses is to use the `__post_init__` method and `field(init=False)` in the attribute definition. This way we don't use `@property`, and the dataclass is fully aware of the computed attributes as if they were normal attributes.

See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing and 475336f.


class RepoFile:
"""
Data structure that represents a public file inside a repo, accessible from
Expand Down Expand Up @@ -1953,7 +2005,7 @@ def create_commit(
create_pr: Optional[bool] = None,
num_threads: int = 5,
parent_commit: Optional[str] = None,
) -> Optional[str]:
) -> CommitInfo:
"""
Creates a commit in the given repo, deleting & uploading files as needed.

Expand Down Expand Up @@ -2005,9 +2057,9 @@ def create_commit(
if the repo is updated / committed to concurrently.

Returns:
`str` or `None`:
If `create_pr` is `True`, returns the URL to the newly created Pull Request
on the Hub. Otherwise returns `None`.
[`CommitInfo`]:
Instance of [`CommitInfo`] containing information about the newly
created commit (commit hash, commit url, pr url, commit message,...).

Raises:
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
Expand Down Expand Up @@ -2118,7 +2170,14 @@ def create_commit(
params={"create_pr": "1"} if create_pr else None,
)
hf_raise_for_status(commit_resp, endpoint_name="commit")
return commit_resp.json().get("pullRequestUrl", None)
commit_data = commit_resp.json()
return CommitInfo(
commit_url=commit_data["commitUrl"],
commit_message=commit_message,
commit_description=commit_description,
oid=commit_data["commitOid"],
pr_url=commit_data["pullRequestUrl"] if create_pr else None,
)

@validate_hf_hub_args
def upload_file(
Expand Down Expand Up @@ -2260,7 +2319,7 @@ def upload_file(
path_in_repo=path_in_repo,
)

pr_url = self.create_commit(
commit_info = self.create_commit(
repo_id=repo_id,
repo_type=repo_type,
operations=[operation],
Expand All @@ -2272,8 +2331,8 @@ def upload_file(
parent_commit=parent_commit,
)

if pr_url is not None:
revision = quote(_parse_revision_from_pr_url(pr_url), safe="")
if commit_info.pr_url is not None:
revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
if repo_type in REPO_TYPES_URL_PREFIXES:
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
revision = revision if revision is not None else DEFAULT_REVISION
Expand Down Expand Up @@ -2420,7 +2479,7 @@ def upload_folder(
ignore_patterns=ignore_patterns,
)

pr_url = self.create_commit(
commit_info = self.create_commit(
repo_type=repo_type,
repo_id=repo_id,
operations=files_to_add,
Expand All @@ -2432,8 +2491,8 @@ def upload_folder(
parent_commit=parent_commit,
)

if pr_url is not None:
revision = quote(_parse_revision_from_pr_url(pr_url), safe="")
if commit_info.pr_url is not None:
revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
if repo_type in REPO_TYPES_URL_PREFIXES:
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
revision = revision if revision is not None else DEFAULT_REVISION
Expand All @@ -2453,7 +2512,7 @@ def delete_file(
commit_description: Optional[str] = None,
create_pr: Optional[bool] = None,
parent_commit: Optional[str] = None,
):
) -> CommitInfo:
"""
Deletes a file in the given repo.

Expand Down
6 changes: 3 additions & 3 deletions src/huggingface_hub/keras_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def push_to_hub_keras(
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
pr_url = api.create_commit(
commit_info = api.create_commit(
repo_type="model",
repo_id=repo_id,
operations=operations,
Expand All @@ -470,8 +470,8 @@ def push_to_hub_keras(
revision = branch
if revision is None:
revision = (
quote(_parse_revision_from_pr_url(pr_url), safe="")
if pr_url is not None
quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
if commit_info.pr_url is not None
else DEFAULT_REVISION
)
return f"{api.endpoint}/{repo_id}/tree/{revision}/"
Expand Down
8 changes: 6 additions & 2 deletions src/huggingface_hub/utils/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
from functools import wraps
from itertools import chain
from typing import Callable
from typing import TypeVar


REPO_ID_REGEX = re.compile(
Expand All @@ -40,7 +40,11 @@ class HFValidationError(ValueError):
"""


def validate_hf_hub_args(fn: Callable) -> Callable:
# type hint meaning "function signature not changed by decorator"
CallableT = TypeVar("CallableT") # callable type


def validate_hf_hub_args(fn: CallableT) -> CallableT:
"""Validate values received as argument for any public method of `huggingface_hub`.

The goal of this decorator is to harmonize validation of arguments reused
Expand Down
25 changes: 22 additions & 3 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from huggingface_hub.file_download import cached_download, hf_hub_download
from huggingface_hub.hf_api import (
USERNAME_PLACEHOLDER,
CommitInfo,
DatasetInfo,
DatasetSearchArguments,
HfApi,
Expand Down Expand Up @@ -759,10 +760,24 @@ def test_create_commit_create_pr(self):
token=self._token,
create_pr=True,
)

# Check commit info
self.assertIsInstance(resp, CommitInfo)
commit_id = resp.oid
self.assertGreater(len(resp.oid), 0)
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(
resp.commit_url,
f"{self._api.endpoint}/{USER}/{REPO_NAME}/commit/{commit_id}",
)
self.assertEqual(resp.commit_message, "Test create_commit")
self.assertEqual(resp.commit_description, "")
self.assertEqual(
resp,
resp.pr_url,
f"{self._api.endpoint}/{USER}/{REPO_NAME}/discussions/1",
)
self.assertEqual(resp.pr_num, 1)
self.assertEqual(resp.pr_revision, "refs/pr/1")

with self.assertRaises(HTTPError) as ctx:
# Should raise a 404
hf_hub_download(
Expand Down Expand Up @@ -823,13 +838,17 @@ def test_create_commit(self):
path_or_fileobj=self.tmp_file,
),
]
return_val = self._api.create_commit(
resp = self._api.create_commit(
operations=operations,
commit_message="Test create_commit",
repo_id=f"{USER}/{REPO_NAME}",
token=self._token,
)
self.assertIsNone(return_val)
# Check commit info
self.assertIsInstance(resp, CommitInfo)
self.assertIsNone(resp.pr_url) # No pr created
self.assertIsNone(resp.pr_num)
self.assertIsNone(resp.pr_revision)
with self.assertRaises(HTTPError):
# Should raise a 404
hf_hub_download(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_init_lazy_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_autocomplete_on_root_imports(self) -> None:
self.assertTrue(
signature_list[0]
.docstring()
.startswith("create_commit(self, repo_id: str")
.startswith("create_commit(repo_id: str,")
)
break
else:
Expand Down