diff --git a/docs/source/quick-start.mdx b/docs/source/quick-start.mdx index 0ff56a41b3..b0653742b9 100644 --- a/docs/source/quick-start.mdx +++ b/docs/source/quick-start.mdx @@ -54,21 +54,16 @@ full-length hash instead of the shorter 7-character commit hash: For more details and options, see the API reference for [`hf_hub_download`]. -## Create a repository - -To create and share files to the Hub, you need to have a Hugging Face account. [Create -an account](https://hf.co/join) if you don't already have one, and then sign in to find -your [User Access Token](https://huggingface.co/docs/hub/security-tokens) in -your Settings. The User Access Token is used to authenticate your identity to the Hub. - - - -You can also provide your token to our functions and methods. This way you don't need to -store your token anywhere. +## Login - +In a lot of cases, you must be logged in with a Hugging Face account to interact with +the Hub: download private repos, upload files, create PRs,... +[Create an account](https://hf.co/join) if you don't already have one, and then sign in +to get your [User Access Token](https://huggingface.co/docs/hub/security-tokens) from +your [Settings page](https://huggingface.co/settings/tokens). The User Access Token is +used to authenticate your identity to the Hub. -1. Log in to your Hugging Face account with the following command: +Once you have your User Access Token, run the following command in your terminal: ```bash huggingface-cli login @@ -82,9 +77,25 @@ Or if you prefer to work from a Jupyter or Colaboratory notebook, then login wit >>> notebook_login() ``` -2. Enter your User Access Token to authenticate your identity to the Hub. + + +You can also provide your token to the functions and methods. This way you don't need to +store your token anywhere. + + + + + +Once you are logged in, all requests to the Hub will use your access token by default. +If you want to disable implicit use of your token, you should set the +`HF_HUB_DISABLE_IMPLICIT_TOKEN` environment variable. + + + +## Create a repository -After you've created an account, create a repository with the [`create_repo`] function: +Once you've registered and logged in, create a repository with the [`create_repo`] +function: ```py >>> from huggingface_hub import HfApi @@ -102,6 +113,14 @@ If you want your repository to be private, then: Private repositories will not be visible to anyone except yourself. + + +To create a repository or to push content to the Hub, you must provide a User Access +Token that has the `write` permission. You can choose the permission when creating the +token in your [Settings page](https://huggingface.co/settings/tokens). + + + ## Share files to the Hub Use the [`upload_file`] function to add a file to your newly created repository. You diff --git a/setup.py b/setup.py index 1d04fd9830..be30badf89 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,6 @@ def get_version() -> str: extras["tensorflow"] = ["tensorflow", "pydot", "graphviz"] extras["testing"] = extras["cli"] + [ - "datasets", "isort>=5.5.4", "jedi", "Jinja2", diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 78c51bd1e9..a413530e0e 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -98,7 +98,6 @@ "CommitOperationDelete", "DatasetSearchArguments", "HfApi", - "HfFolder", "ModelSearchArguments", "change_discussion_status", "comment_discussion", @@ -169,6 +168,7 @@ "CorruptedCacheException", "DeleteCacheStrategy", "HFCacheInfo", + "HfFolder", "logging", "scan_cache_dir", ], @@ -311,7 +311,6 @@ def __dir__(): from .hf_api import CommitOperationDelete # noqa: F401 from .hf_api import DatasetSearchArguments # noqa: F401 from .hf_api import HfApi # noqa: F401 - from .hf_api import HfFolder # noqa: F401 from .hf_api import ModelSearchArguments # noqa: F401 from .hf_api import change_discussion_status # noqa: F401 from .hf_api import comment_discussion # noqa: F401 @@ -368,6 +367,7 @@ def __dir__(): from .utils import CorruptedCacheException # noqa: F401 from .utils import DeleteCacheStrategy # noqa: F401 from .utils import HFCacheInfo # noqa: F401 + from .utils import HfFolder # noqa: F401 from .utils import logging # noqa: F401 from .utils import scan_cache_dir # noqa: F401 from .utils.endpoint_helpers import DatasetFilter # noqa: F401 diff --git a/src/huggingface_hub/_commit_api.py b/src/huggingface_hub/_commit_api.py index 36d5c82db6..9005e45d65 100644 --- a/src/huggingface_hub/_commit_api.py +++ b/src/huggingface_hub/_commit_api.py @@ -14,7 +14,7 @@ from .constants import ENDPOINT from .lfs import UploadInfo, _validate_batch_actions, lfs_upload, post_lfs_batch_info -from .utils import hf_raise_for_status, logging, validate_hf_hub_args +from .utils import build_hf_headers, hf_raise_for_status, logging, validate_hf_hub_args from .utils._typing import Literal @@ -355,7 +355,7 @@ def fetch_upload_modes( If the Hub API returned an HTTP 400 error (bad request) """ endpoint = endpoint if endpoint is not None else ENDPOINT - headers = {"authorization": f"Bearer {token}"} if token is not None else None + headers = build_hf_headers(use_auth_token=token) payload = { "files": [ { diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index c7e72580bb..ec74fb95f2 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -4,7 +4,7 @@ from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name -from .hf_api import HfApi, HfFolder +from .hf_api import HfApi from .utils import filter_repo_objects, logging, tqdm, validate_hf_hub_args from .utils._deprecation import _deprecate_arguments @@ -108,18 +108,6 @@ def snapshot_download( if isinstance(cache_dir, Path): cache_dir = str(cache_dir) - if isinstance(use_auth_token, str): - token = use_auth_token - elif use_auth_token: - token = HfFolder.get_token() - if token is None: - raise EnvironmentError( - "You specified use_auth_token=True, but a Hugging Face token was not" - " found." - ) - else: - token = None - if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: @@ -167,7 +155,10 @@ def snapshot_download( # if we have internet connection we retrieve the correct folder name from the huggingface api _api = HfApi() repo_info = _api.repo_info( - repo_id=repo_id, repo_type=repo_type, revision=revision, use_auth_token=token + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + use_auth_token=use_auth_token, ) filtered_repo_files = list( filter_repo_objects( diff --git a/src/huggingface_hub/commands/user.py b/src/huggingface_hub/commands/user.py index d5477263dd..e06b13672a 100644 --- a/src/huggingface_hub/commands/user.py +++ b/src/huggingface_hub/commands/user.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import subprocess from argparse import ArgumentParser from getpass import getpass @@ -24,10 +23,10 @@ REPO_TYPES_URL_PREFIXES, SPACES_SDK_TYPES, ) -from huggingface_hub.hf_api import HfApi, HfFolder +from huggingface_hub.hf_api import HfApi from requests.exceptions import HTTPError -from ..utils import run_subprocess +from ..utils import HfFolder, run_subprocess from ._cli_utils import ANSI @@ -308,13 +307,16 @@ def login_token_event(t): # Erase token and clear value to make sure it's not saved in the notebook. token_widget.value = "" clear_output() - _login(HfApi(), token=token) + _login(token=token) token_finish_button.on_click(login_token_event) -def _login(hf_api, token=None): - token, name = hf_api._validate_or_retrieve_token(token) +def _login(hf_api: HfApi, token: str) -> None: + if token.startswith("api_org"): + raise ValueError("You must use your personal account token.") + if not hf_api._is_valid_token(token=token): + raise ValueError("Invalid token passed!") hf_api.set_access_token(token) HfFolder.save_token(token) print("Login successful") diff --git a/src/huggingface_hub/fastai_utils.py b/src/huggingface_hub/fastai_utils.py index 26d02cdbea..0fd6ff686d 100644 --- a/src/huggingface_hub/fastai_utils.py +++ b/src/huggingface_hub/fastai_utils.py @@ -7,7 +7,7 @@ from packaging import version -from huggingface_hub import hf_api, snapshot_download +from huggingface_hub import snapshot_download from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.file_download import ( _PY_VERSION, @@ -437,7 +437,6 @@ def push_to_hub_fastai( """ _check_fastai_fastcore_versions() - token, _ = hf_api._validate_or_retrieve_token(token) api = HfApi(endpoint=api_endpoint) api.create_repo( repo_id=repo_id, diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index fc949d73fa..374c3c702a 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -33,10 +33,10 @@ REPO_TYPES, REPO_TYPES_URL_PREFIXES, ) -from .hf_api import HfFolder from .utils import ( EntryNotFoundError, LocalEntryNotFoundError, + build_hf_headers, hf_raise_for_status, http_backoff, logging, @@ -670,23 +670,12 @@ def cached_download( os.makedirs(cache_dir, exist_ok=True) - headers = { - "user-agent": http_user_agent( - library_name=library_name, - library_version=library_version, - user_agent=user_agent, - ) - } - if isinstance(use_auth_token, str): - headers["authorization"] = f"Bearer {use_auth_token}" - elif use_auth_token: - token = HfFolder.get_token() - if token is None: - raise EnvironmentError( - "You specified use_auth_token=True, but a huggingface token was not" - " found." - ) - headers["authorization"] = f"Bearer {token}" + headers = build_hf_headers(use_auth_token=use_auth_token) + headers["user-agent"] = http_user_agent( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) url_to_download = url etag = None @@ -1115,23 +1104,12 @@ def hf_hub_download( url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision) - headers = { - "user-agent": http_user_agent( - library_name=library_name, - library_version=library_version, - user_agent=user_agent, - ) - } - if isinstance(use_auth_token, str): - headers["authorization"] = f"Bearer {use_auth_token}" - elif use_auth_token: - token = HfFolder.get_token() - if token is None: - raise EnvironmentError( - "You specified use_auth_token=True, but a huggingface token was not" - " found." - ) - headers["authorization"] = f"Bearer {token}" + headers = build_hf_headers(use_auth_token=use_auth_token) + headers["user-agent"] = http_user_agent( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) url_to_download = url etag = None @@ -1433,18 +1411,7 @@ def get_hf_file_metadata( A [`HfFileMetadata`] object containing metadata such as location, etag and commit_hash. """ - # TODO: helper to get headers from `use_auth_token` (copy-pasted several times) - headers = {} - if isinstance(use_auth_token, str): - headers["authorization"] = f"Bearer {use_auth_token}" - elif use_auth_token: - token = HfFolder.get_token() - if token is None: - raise EnvironmentError( - "You specified use_auth_token=True, but a huggingface token was not" - " found." - ) - headers["authorization"] = f"Bearer {token}" + headers = build_hf_headers(use_auth_token=use_auth_token) # Retrieve metadata r = _request_wrapper( diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e0758c0665..418a2680f1 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -16,7 +16,6 @@ import re import subprocess import warnings -from os.path import expanduser from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union from urllib.parse import quote @@ -50,14 +49,16 @@ REPO_TYPES_URL_PREFIXES, SPACES_SDK_TYPES, ) +from .utils import HfFolder # noqa: F401 # imported for backward compatibility from .utils import ( + build_hf_headers, filter_repo_objects, hf_raise_for_status, logging, parse_datetime, validate_hf_hub_args, ) -from .utils._deprecation import _deprecate_positional_args +from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args from .utils._typing import Literal, TypedDict from .utils.endpoint_helpers import ( AttributeDictionary, @@ -613,15 +614,14 @@ def whoami(self, token: Optional[str] = None) -> Dict: Hugging Face token. Will default to the locally saved token if not provided. """ - if token is None: - token = HfFolder.get_token() - if token is None: - raise ValueError( - "You need to pass a valid `token` or login by using `huggingface-cli" - " login`" - ) - path = f"{self.endpoint}/api/whoami-v2" - r = requests.get(path, headers={"authorization": f"Bearer {token}"}) + r = requests.get( + f"{self.endpoint}/api/whoami-v2", + headers=build_hf_headers( + # If `token` is provided and not `None`, it will be used by default. + # Otherwise, the token must be retrieved from cache or env variable. + use_auth_token=(token or True), + ), + ) try: hf_raise_for_status(r) except HTTPError as e: @@ -632,7 +632,7 @@ def whoami(self, token: Optional[str] = None) -> Dict: ) from e return r.json() - def _is_valid_token(self, token: str): + def _is_valid_token(self, token: str) -> bool: """ Determines whether `token` is a valid token or not. @@ -649,74 +649,6 @@ def _is_valid_token(self, token: str): except HTTPError: return False - def _validate_or_retrieve_token( - self, - token: Optional[Union[str, bool]] = None, - name: Optional[str] = None, - function_name: Optional[str] = None, - ): - """ - Retrieves and validates stored token or validates passed token. - Args: - token (``str``, `optional`): - Hugging Face token. Will default to the locally saved token if not provided. - name (``str``, `optional`): - Name of the repository. This is deprecated in favor of repo_id and will be removed in v0.8. - function_name (``str``, `optional`): - If _validate_or_retrieve_token is called from a function, name of that function to be passed inside deprecation warning. - Returns: - Validated token and the name of the repository. - Raises: - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - If the token is not passed and there's no token saved locally. - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - If organization token or invalid token is passed. - """ - if token is None or token is True: - token = HfFolder.get_token() - if token is None: - raise EnvironmentError( - "You need to provide a `token` or be logged in to Hugging Face with" - " `huggingface-cli login`." - ) - if name is not None: - if self._is_valid_token(name): - # TODO(0.6) REMOVE - warnings.warn( - f"`{function_name}` now takes `token` as an optional positional" - " argument. Be sure to adapt your code!", - FutureWarning, - ) - token, name = name, token - if isinstance(token, str): - if token.startswith("api_org"): - raise ValueError("You must use your personal account token.") - if not self._is_valid_token(token): - raise ValueError("Invalid token passed!") - - return token, name - - def _build_auth_headers( - self, *, token: Optional[str], use_auth_token: Optional[Union[str, bool]] - ) -> Dict[str, str]: - """Helper to build Authorization header from kwargs. To be removed in 0.12.0 when `token` is deprecated.""" - if token is not None: - warnings.warn( - "`token` is deprecated and will be removed in 0.12.0. Use" - " `use_auth_token` instead.", - FutureWarning, - ) - - auth_token = None - if use_auth_token is None and token is None: - # To maintain backwards-compatibility. To be removed in 0.12.0 - auth_token = HfFolder.get_token() - elif use_auth_token: - auth_token, _ = self._validate_or_retrieve_token(use_auth_token) - else: - auth_token = token - return {"authorization": f"Bearer {auth_token}"} if auth_token else {} - @staticmethod def set_access_token(access_token: str): """ @@ -861,10 +793,7 @@ def list_models( ``` """ path = f"{self.endpoint}/api/models" - headers = {} - if use_auth_token: - token, name = self._validate_or_retrieve_token(use_auth_token) - headers["authorization"] = f"Bearer {token}" + headers = build_hf_headers(use_auth_token=use_auth_token) params = {} if filter is not None: if isinstance(filter, ModelFilter): @@ -1059,10 +988,7 @@ def list_datasets( ``` """ path = f"{self.endpoint}/api/datasets" - headers = {} - if use_auth_token: - token, name = self._validate_or_retrieve_token(use_auth_token) - headers["authorization"] = f"Bearer {token}" + headers = build_hf_headers(use_auth_token=use_auth_token) params = {} if filter is not None: if isinstance(filter, DatasetFilter): @@ -1199,10 +1125,7 @@ def list_spaces( `List[SpaceInfo]`: a list of [`huggingface_hub.hf_api.SpaceInfo`] objects """ path = f"{self.endpoint}/api/spaces" - headers = {} - if use_auth_token: - token, name = self._validate_or_retrieve_token(use_auth_token) - headers["authorization"] = f"Bearer {token}" + headers = build_hf_headers(use_auth_token=use_auth_token) params = {} if filter is not None: params.update({"filter": filter}) @@ -1232,6 +1155,7 @@ def list_spaces( return [SpaceInfo(**x) for x in d] @validate_hf_hub_args + @_deprecate_arguments(version="0.12", deprecated_args={"token"}) def model_info( self, repo_id: str, @@ -1286,7 +1210,7 @@ def model_info( """ - headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token) + headers = build_hf_headers(use_auth_token=token or use_auth_token) path = ( f"{self.endpoint}/api/models/{repo_id}" if revision is None @@ -1310,6 +1234,7 @@ def model_info( return ModelInfo(**d) @validate_hf_hub_args + @_deprecate_arguments(version="0.12", deprecated_args={"token"}) def dataset_info( self, repo_id: str, @@ -1360,8 +1285,7 @@ def dataset_info( """ - headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token) - + headers = build_hf_headers(use_auth_token=token or use_auth_token) path = ( f"{self.endpoint}/api/datasets/{repo_id}" if revision is None @@ -1379,6 +1303,7 @@ def dataset_info( return DatasetInfo(**d) @validate_hf_hub_args + @_deprecate_arguments(version="0.12", deprecated_args={"token"}) def space_info( self, repo_id: str, @@ -1429,7 +1354,7 @@ def space_info( """ - headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token) + headers = build_hf_headers(use_auth_token=token or use_auth_token) path = ( f"{self.endpoint}/api/spaces/{repo_id}" if revision is None @@ -1611,10 +1536,6 @@ def create_repo( path = f"{self.endpoint}/api/repos/create" - token, name = self._validate_or_retrieve_token( - token, name, function_name="create_repo" - ) - checked_name = repo_type_and_id_from_hf_id(name) if ( @@ -1671,11 +1592,8 @@ def create_repo( if getattr(self, "_lfsmultipartthresh", None): json["lfsmultipartthresh"] = self._lfsmultipartthresh - r = requests.post( - path, - headers={"authorization": f"Bearer {token}"}, - json=json, - ) + headers = build_hf_headers(use_auth_token=token, is_write_action=True) + r = requests.post(path, headers=headers, json=json) try: hf_raise_for_status(r) @@ -1736,10 +1654,6 @@ def delete_repo( path = f"{self.endpoint}/api/repos/delete" - token, name = self._validate_or_retrieve_token( - token, name, function_name="delete_repo" - ) - checked_name = repo_type_and_id_from_hf_id(name) if ( @@ -1779,11 +1693,8 @@ def delete_repo( if repo_type is not None: json["type"] = repo_type - r = requests.delete( - path, - headers={"authorization": f"Bearer {token}"}, - json=json, - ) + headers = build_hf_headers(use_auth_token=token, is_write_action=True) + r = requests.delete(path, headers=headers, json=json) hf_raise_for_status(r) @validate_hf_hub_args @@ -1837,10 +1748,6 @@ def update_repo_visibility( organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id) - token, name = self._validate_or_retrieve_token( - token, name, function_name="update_repo_visibility" - ) - if organization is None: namespace = self.whoami(token)["name"] else: @@ -1849,14 +1756,10 @@ def update_repo_visibility( if repo_type is None: repo_type = REPO_TYPE_MODEL # default repo type - path = f"{self.endpoint}/api/{repo_type}s/{namespace}/{name}/settings" - - json = {"private": private} - r = requests.put( - path, - headers={"authorization": f"Bearer {token}"}, - json=json, + url=f"{self.endpoint}/api/{repo_type}s/{namespace}/{name}/settings", + headers=build_hf_headers(use_auth_token=token, is_write_action=True), + json={"private": private}, ) hf_raise_for_status(r) return r.json() @@ -1900,9 +1803,6 @@ def move_repo( """ - - token, name = self._validate_or_retrieve_token(token) - if len(from_id.split("/")) != 2: raise ValueError( f"Invalid repo_id: {from_id}. It should have a namespace" @@ -1918,11 +1818,8 @@ def move_repo( json = {"fromRepo": from_id, "toRepo": to_id, "type": repo_type} path = f"{self.endpoint}/api/repos/move" - r = requests.post( - path, - headers={"authorization": f"Bearer {token}"}, - json=json, - ) + headers = build_hf_headers(use_auth_token=token, is_write_action=True) + r = requests.post(path, headers=headers, json=json) try: hf_raise_for_status(r) except HTTPError as e: @@ -2046,7 +1943,6 @@ def create_commit( repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL if repo_type not in REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") - token, name = self._validate_or_retrieve_token(token) revision = ( quote(revision, safe="") if revision is not None else DEFAULT_REVISION ) @@ -2111,9 +2007,10 @@ def create_commit( ) commit_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/commit/{revision}" + headers = build_hf_headers(use_auth_token=token, is_write_action=True) commit_resp = requests.post( url=commit_url, - headers={"Authorization": f"Bearer {token}"}, + headers=headers, json=commit_payload, params={"create_pr": "1"} if create_pr else None, ) @@ -2566,14 +2463,11 @@ def get_full_repo_name( FutureWarning, ) - if token is None and use_auth_token: - token, name = self._validate_or_retrieve_token(use_auth_token) - if organization is None: if "/" in model_id: username = model_id.split("/")[0] else: - username = self.whoami(token=token)["name"] + username = self.whoami(token=token or use_auth_token)["name"] return f"{username}/{model_id}" else: return f"{organization}/{model_id}" @@ -2623,16 +2517,14 @@ def get_repo_discussions( raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") if repo_type is None: repo_type = REPO_TYPE_MODEL - repo_id = f"{repo_type}s/{repo_id}" - if token is None: - token = HfFolder.get_token() + + headers = build_hf_headers(use_auth_token=token) def _fetch_discussion_page(page_index: int): - path = f"{self.endpoint}/api/{repo_id}/discussions?p={page_index}" - resp = requests.get( - path, - headers={"Authorization": f"Bearer {token}"} if token else None, + path = ( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}" ) + resp = requests.get(path, headers=headers) hf_raise_for_status(resp) paginated_discussions = resp.json() total = paginated_discussions["count"] @@ -2704,17 +2596,12 @@ def get_discussion_details( raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") if repo_type is None: repo_type = REPO_TYPE_MODEL - repo_id = f"{repo_type}s/{repo_id}" - if token is None: - token = HfFolder.get_token() - path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}" - - resp = requests.get( - path, - params={"diff": "1"}, - headers={"Authorization": f"Bearer {token}"} if token else None, + path = ( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" ) + headers = build_hf_headers(use_auth_token=token) + resp = requests.get(path, params={"diff": "1"}, headers=headers) hf_raise_for_status(resp) discussion_details = resp.json() @@ -2805,8 +2692,7 @@ def create_discussion( raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") if repo_type is None: repo_type = REPO_TYPE_MODEL - full_repo_id = f"{repo_type}s/{repo_id}" - token, _ = self._validate_or_retrieve_token(token=token) + if description is not None: description = description.strip() description = ( @@ -2819,14 +2705,15 @@ def create_discussion( ) ) + headers = build_hf_headers(use_auth_token=token, is_write_action=True) resp = requests.post( - f"{self.endpoint}/api/{full_repo_id}/discussions", + f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions", json={ "title": title.strip(), "description": description, "pullRequest": pull_request, }, - headers={"Authorization": f"Bearer {token}"}, + headers=headers, ) hf_raise_for_status(resp) num = resp.json()["num"] @@ -2913,15 +2800,11 @@ def _post_discussion_changes( if repo_type is None: repo_type = REPO_TYPE_MODEL repo_id = f"{repo_type}s/{repo_id}" - token, _ = self._validate_or_retrieve_token(token=token) path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}" - resp = requests.post( - path, - headers={"Authorization": f"Bearer {token}"}, - json=body, - ) + headers = build_hf_headers(use_auth_token=token, is_write_action=True) + resp = requests.post(path, headers=headers, json=body) hf_raise_for_status(resp) return resp @@ -3316,54 +3199,6 @@ def hide_discussion_comment( return deserialize_event(resp.json()["updatedComment"]) -class HfFolder: - path_token = expanduser("~/.huggingface/token") - - @classmethod - def save_token(cls, token): - """ - Save token, creating folder as needed. - - Args: - token (`str`): - The token to save to the [`HfFolder`] - """ - os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) - with open(cls.path_token, "w+") as f: - f.write(token) - - @classmethod - def get_token(cls) -> Optional[str]: - """ - Get token or None if not existent. - - Note that a token can be also provided using the - `HUGGING_FACE_HUB_TOKEN` environment variable. - - Returns: - `str` or `None`: The token, `None` if it doesn't exist. - - """ - token: Optional[str] = os.environ.get("HUGGING_FACE_HUB_TOKEN") - if token is None: - try: - with open(cls.path_token, "r") as f: - token = f.read() - except FileNotFoundError: - pass - return token - - @classmethod - def delete_token(cls): - """ - Deletes the token from storage. Does not fail if token does not exist. - """ - try: - os.remove(cls.path_token) - except FileNotFoundError: - pass - - def _prepare_upload_folder_commit( folder_path: str, path_in_repo: str, @@ -3464,5 +3299,3 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: edit_discussion_comment = api.edit_discussion_comment rename_discussion = api.rename_discussion merge_pull_request = api.merge_pull_request - -_validate_or_retrieve_token = api._validate_or_retrieve_token diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index ff0f292c5c..fd24fcd546 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -5,13 +5,12 @@ from typing import Dict, List, Optional, Union import requests -from huggingface_hub import hf_api from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from .file_download import hf_hub_download, is_torch_available -from .hf_api import HfApi, HfFolder +from .hf_api import HfApi from .repository import Repository -from .utils import logging, validate_hf_hub_args +from .utils import HfFolder, logging, validate_hf_hub_args from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args @@ -326,9 +325,7 @@ def push_to_hub( # If the repo id is set, it means we use the new version using HTTP endpoint # (introduced in v0.9). if repo_id is not None: - token, _ = hf_api._validate_or_retrieve_token(token) api = HfApi(endpoint=api_endpoint) - api.create_repo( repo_id=repo_id, repo_type="model", diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index dfcd90758b..9872bf7502 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -9,12 +9,7 @@ from urllib.parse import quote import yaml -from huggingface_hub import ( - CommitOperationDelete, - ModelHubMixin, - hf_api, - snapshot_download, -) +from huggingface_hub import CommitOperationDelete, ModelHubMixin, snapshot_download from huggingface_hub.file_download import ( get_tf_version, is_graphviz_available, @@ -23,14 +18,9 @@ ) from .constants import CONFIG_NAME, DEFAULT_REVISION -from .hf_api import ( - HfApi, - HfFolder, - _parse_revision_from_pr_url, - _prepare_upload_folder_commit, -) +from .hf_api import HfApi, _parse_revision_from_pr_url, _prepare_upload_folder_commit from .repository import Repository -from .utils import logging, validate_hf_hub_args +from .utils import HfFolder, logging, validate_hf_hub_args from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args @@ -409,9 +399,7 @@ def push_to_hub_keras( The url of the commit of your model in the given repository. """ if repo_id is not None: - token, _ = hf_api._validate_or_retrieve_token(token) api = HfApi(endpoint=api_endpoint) - api.create_repo( repo_id=repo_id, repo_type="model", diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py index b132915d71..f10b8c6417 100644 --- a/src/huggingface_hub/repository.py +++ b/src/huggingface_hub/repository.py @@ -15,9 +15,9 @@ from huggingface_hub.repocard import metadata_load, metadata_save from requests.exceptions import HTTPError -from .hf_api import HfApi, HfFolder, repo_type_and_id_from_hf_id +from .hf_api import HfApi, repo_type_and_id_from_hf_id from .lfs import LFS_MULTIPART_UPLOAD_COMMAND -from .utils import logging, run_subprocess, tqdm +from .utils import HfFolder, logging, run_subprocess, tqdm from .utils._deprecation import _deprecate_arguments, _deprecate_method diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index cf5176374c..5faa622778 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -35,6 +35,8 @@ RevisionNotFoundError, hf_raise_for_status, ) +from ._headers import build_hf_headers +from ._hf_folder import HfFolder from ._http import http_backoff from ._paths import filter_repo_objects from ._subprocess import run_subprocess diff --git a/src/huggingface_hub/utils/_deprecation.py b/src/huggingface_hub/utils/_deprecation.py index ab8a5e135e..3cb14e7243 100644 --- a/src/huggingface_hub/utils/_deprecation.py +++ b/src/huggingface_hub/utils/_deprecation.py @@ -76,8 +76,13 @@ def inner_f(*args, **kwargs): for _, parameter in zip(args, sig.parameters.values()): if parameter.name in deprecated_args: used_deprecated_args.append(parameter.name) - for kwarg_name in kwargs: - if kwarg_name in deprecated_args: + for kwarg_name, kwarg_value in kwargs.items(): + if ( + # If argument is deprecated but still used + kwarg_name in deprecated_args + # And then the value is not the default value + and kwarg_value != sig.parameters[kwarg_name].default + ): used_deprecated_args.append(kwarg_name) # Warn and proceed diff --git a/src/huggingface_hub/utils/_headers.py b/src/huggingface_hub/utils/_headers.py new file mode 100644 index 0000000000..d51035185f --- /dev/null +++ b/src/huggingface_hub/utils/_headers.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle headers to send in calls to Huggingface Hub.""" +import os +from typing import Dict, Optional, Union + +from ._hf_folder import HfFolder + + +def build_hf_headers( + *, use_auth_token: Optional[Union[bool, str]] = None, is_write_action: bool = False +) -> Dict[str, str]: + """ + Build headers dictionary to send in a HF Hub call. + + By default, authorization token is always provided either from argument (explicit + use) or retrieved from the cache (implicit use). To explicitly avoid sending the + token to the Hub, set `use_auth_token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN` + environment variable. + + In case of an API call that requires write access, an error is thrown if token is + `None` or token is an organization token (starting with `"api_org***"`). + + Args: + use_auth_token (`str`, `bool`, *optional*): + The token to be sent in authorization header for the Hub call: + - if a string, it is used as the Hugging Face token + - if `True`, the token is read from the machine (cache or env variable) + - if `False`, authorization header is not set + - if `None`, the token is read from the machine only except if + `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set. + is_write_action (`bool`, default to `False`): + Set to True if the API call requires a write access. If `True`, the token + will be validated (cannot be `None`, cannot start by `"api_org***"`). + + Returns: + A `Dict` of headers to pass in your API call. + + Example: + ```py + >>> build_hf_headers(use_auth_token="hf_***") # explicit token + {"authorization": "Bearer hf_***"} + + >>> build_hf_headers(use_auth_token=True) # explicitly use cached token + {"authorization": "Bearer hf_***"} + + >>> build_hf_headers(use_auth_token=False) # explicitly don't use cached token + {} + + >>> build_hf_headers() # implicit use of the cached token + {"authorization": "Bearer hf_***"} + + # HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable + >>> build_hf_headers() # token is not sent + {} + + >>> build_hf_headers(use_auth_token="api_org_***", is_write_action=True) + ValueError: You must use your personal account token for write-access methods. + ``` + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If organization token is passed and "write" access is required. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If "write" access is required but token is not passed and not saved locally. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `use_auth_token=True` but token is not saved locally. + """ + # Get auth token to send + token_to_send = _get_token_to_send(use_auth_token) + _validate_token_to_send(token_to_send, is_write_action=is_write_action) + + # Combine headers + headers = {} + if token_to_send is not None: + headers["authorization"] = f"Bearer {token_to_send}" + # TODO: add user agent in headers + return headers + + +def _get_token_to_send(use_auth_token: Optional[Union[bool, str]]) -> Optional[str]: + """Select the token to send from either `use_auth_token` or the cache.""" + # Case token is explicitly provided + if isinstance(use_auth_token, str): + return use_auth_token + + # Case token is explicitly forbidden + if use_auth_token is False: + return None + + # Token is not provided: we get it from local cache + cached_token = HfFolder().get_token() + + # Case token is explicitly required + if use_auth_token is True: + if cached_token is None: + raise EnvironmentError( + "Token is required (`use_auth_token=True`), but no token found. You" + " need to provide a token or be logged in to Hugging Face with" + " `huggingface-cli login` or `notebook_login`. See" + " https://huggingface.co/settings/tokens." + ) + return cached_token + + # Case implicit use of the token is forbidden by env variable + if os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"): + return None + + # Otherwise: we use the cached token as the user has not explicitly forbidden it + return cached_token + + +def _validate_token_to_send(token: Optional[str], is_write_action: bool) -> None: + if is_write_action: + if token is None: + raise ValueError( + "Token is required (write-access action) but no token found. You need" + " to provide a token or be logged in to Hugging Face with" + " `huggingface-cli login` or `notebook_login`. See" + " https://huggingface.co/settings/tokens." + ) + if token.startswith("api_org"): + raise ValueError( + "You must use your personal account token for write-access methods. To" + " generate a write-access token, go to" + " https://huggingface.co/settings/tokens" + ) diff --git a/src/huggingface_hub/utils/_hf_folder.py b/src/huggingface_hub/utils/_hf_folder.py new file mode 100644 index 0000000000..c9182ada0a --- /dev/null +++ b/src/huggingface_hub/utils/_hf_folder.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contain helper class to retrieve/store token from/to local cache.""" +import os +from pathlib import Path +from typing import Optional + + +class HfFolder: + path_token = Path("~/.huggingface/token").expanduser() + + @classmethod + def save_token(cls, token: str) -> None: + """ + Save token, creating folder as needed. + + Args: + token (`str`): + The token to save to the [`HfFolder`] + """ + cls.path_token.parent.mkdir(exist_ok=True) + with cls.path_token.open("w+") as f: + f.write(token) + + @classmethod + def get_token(cls) -> Optional[str]: + """ + Get token or None if not existent. + + Note that a token can be also provided using the + `HUGGING_FACE_HUB_TOKEN` environment variable. + + Returns: + `str` or `None`: The token, `None` if it doesn't exist. + """ + token: Optional[str] = os.environ.get("HUGGING_FACE_HUB_TOKEN") + if token is None: + try: + return cls.path_token.read_text() + except FileNotFoundError: + pass + return token + + @classmethod + def delete_token(cls) -> None: + """ + Deletes the token from storage. Does not fail if token does not exist. + """ + try: + cls.path_token.unlink() + except FileNotFoundError: + pass diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..3c871da76c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +from typing import Generator + +import pytest + +from huggingface_hub import HfFolder + + +@pytest.fixture(autouse=True, scope="session") +def clean_hf_folder_token_for_tests() -> Generator: + """Clean token stored on machine before all tests and reset it back at the end. + + Useful to avoid token deletion when running tests locally. + """ + # Remove registered token + token = HfFolder().get_token() + HfFolder().delete_token() + + yield # Run all tests + + # Set back token once all tests have passed + if token is not None: + HfFolder().save_token(token) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 3bd41e2a99..66c73006ea 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -23,6 +23,7 @@ from functools import partial from io import BytesIO from typing import List +from unittest.mock import Mock, patch from urllib.parse import quote import pytest @@ -43,7 +44,6 @@ DatasetInfo, DatasetSearchArguments, HfApi, - HfFolder, MetricInfo, ModelInfo, ModelSearchArguments, @@ -53,7 +53,7 @@ read_from_credential_store, repo_type_and_id_from_hf_id, ) -from huggingface_hub.utils import RepositoryNotFoundError, logging +from huggingface_hub.utils import HfFolder, RepositoryNotFoundError, logging from huggingface_hub.utils.endpoint_helpers import ( DatasetFilter, ModelFilter, @@ -181,7 +181,7 @@ def test_repo_id_no_warning(): class HfApiEndpointsTest(HfApiCommonTestWithLogin): - def test_whoami(self): + def test_whoami_with_passing_token(self): info = self._api.whoami(token=self._token) self.assertEqual(info["name"], USER) self.assertEqual(info["fullname"], FULL_NAME) @@ -189,6 +189,13 @@ def test_whoami(self): valid_org = [org for org in info["orgs"] if org["name"] == "valid_org"][0] self.assertIsInstance(valid_org["apiToken"], str) + @patch("huggingface_hub.utils._headers.HfFolder") + def test_whoami_with_implicit_token_from_login(self, mock_HfFolder: Mock) -> None: + """Test using `whoami` after a `huggingface-cli login`.""" + mock_HfFolder().get_token.return_value = self._token + info = self._api.whoami() + self.assertEqual(info["name"], USER) + @retry_endpoint def test_delete_repo_error_message(self): # test for #751 diff --git a/tests/test_inference_api.py b/tests/test_inference_api.py index d92aacbd92..49019ce955 100644 --- a/tests/test_inference_api.py +++ b/tests/test_inference_api.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import unittest -import datasets - +from huggingface_hub import hf_hub_download from huggingface_hub.inference_api import InferenceApi from .testing_utils import with_production_testing @@ -67,13 +65,12 @@ def test_inference_with_dict_inputs(self): @with_production_testing def test_inference_with_audio(self): api = InferenceApi("facebook/wav2vec2-base-960h") - with self.assertWarns(FutureWarning): - dataset = datasets.load_dataset( - "patrickvonplaten/librispeech_asr_dummy", - "clean", - split="validation", - ) - data = self.read(dataset["file"][0]) + file = hf_hub_download( + repo_id="hf-internal-testing/dummy-flac-single-example", + repo_type="dataset", + filename="example.flac", + ) + data = self.read(file) result = api(data=data) self.assertIsInstance(result, dict) self.assertTrue("text" in result, f"We received {result} instead") @@ -81,13 +78,10 @@ def test_inference_with_audio(self): @with_production_testing def test_inference_with_image(self): api = InferenceApi("google/vit-base-patch16-224") - with self.assertWarns(FutureWarning): - dataset = datasets.load_dataset( - "Narsil/image_dummy", - "image", - split="test", - ) - data = self.read(dataset["file"][0]) + file = hf_hub_download( + repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png" + ) + data = self.read(file) result = api(data=data) self.assertIsInstance(result, list) for classification in result: diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py index 63069482fd..d6f6342d5c 100644 --- a/tests/test_snapshot_download.py +++ b/tests/test_snapshot_download.py @@ -8,8 +8,7 @@ import requests from huggingface_hub import HfApi, Repository, snapshot_download -from huggingface_hub.hf_api import HfFolder -from huggingface_hub.utils import logging +from huggingface_hub.utils import HfFolder, logging from .testing_constants import ENDPOINT_STAGING, TOKEN, USER from .testing_utils import ( diff --git a/tests/test_utils_deprecation.py b/tests/test_utils_deprecation.py index dfd4a28717..ccaf6f78f0 100644 --- a/tests/test_utils_deprecation.py +++ b/tests/test_utils_deprecation.py @@ -52,6 +52,9 @@ def dummy_b_c_deprecated(a, b="b", c="c"): dummy_b_c_deprecated("A") + dummy_b_c_deprecated("A", b="b") + dummy_b_c_deprecated("A", b="b", c="c") + with pytest.warns(FutureWarning): dummy_c_deprecated("A", "B", "C") @@ -79,7 +82,7 @@ def dummy_deprecated_default_message(a: str = "a") -> None: # Default message with pytest.warns(FutureWarning) as record: - dummy_deprecated_default_message(a="a") + dummy_deprecated_default_message(a="A") self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], @@ -100,7 +103,7 @@ def dummy_deprecated_custom_message(a: str = "a") -> None: # Custom message with pytest.warns(FutureWarning) as record: - dummy_deprecated_custom_message(a="a") + dummy_deprecated_custom_message(a="A") self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], diff --git a/tests/test_utils_headers.py b/tests/test_utils_headers.py new file mode 100644 index 0000000000..b8cb145757 --- /dev/null +++ b/tests/test_utils_headers.py @@ -0,0 +1,64 @@ +import unittest +from unittest.mock import Mock, patch + +from huggingface_hub.utils._headers import build_hf_headers + +from .testing_utils import handle_injection + + +FAKE_TOKEN = "123456789" +FAKE_TOKEN_ORG = "api_org_123456789" +FAKE_TOKEN_HEADER = {"authorization": f"Bearer {FAKE_TOKEN}"} + + +@patch("huggingface_hub.utils._headers.HfFolder") +@handle_injection +class TestAuthHeadersUtil(unittest.TestCase): + def test_use_auth_token_str(self) -> None: + self.assertEqual(build_hf_headers(use_auth_token=FAKE_TOKEN), FAKE_TOKEN_HEADER) + + def test_use_auth_token_true_no_cached_token(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = None + with self.assertRaises(EnvironmentError): + build_hf_headers(use_auth_token=True) + + def test_use_auth_token_true_has_cached_token(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = FAKE_TOKEN + self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER) + + def test_use_auth_token_false(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = FAKE_TOKEN + self.assertEqual(build_hf_headers(use_auth_token=False), {}) + + def test_use_auth_token_none_no_cached_token(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = None + self.assertEqual(build_hf_headers(), {}) + + def test_use_auth_token_none_has_cached_token(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = FAKE_TOKEN + self.assertEqual(build_hf_headers(), FAKE_TOKEN_HEADER) + + def test_write_action_org_token(self) -> None: + with self.assertRaises(ValueError): + build_hf_headers(use_auth_token=FAKE_TOKEN_ORG, is_write_action=True) + + def test_write_action_none_token(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = None + with self.assertRaises(ValueError): + build_hf_headers(is_write_action=True) + + def test_write_action_use_auth_token_false(self) -> None: + with self.assertRaises(ValueError): + build_hf_headers(use_auth_token=False, is_write_action=True) + + @patch.dict("os.environ", {"HF_HUB_DISABLE_IMPLICIT_TOKEN": "1"}) + def test_implicit_use_disabled(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = FAKE_TOKEN + self.assertEqual(build_hf_headers(), {}) # token is not sent + + @patch.dict("os.environ", {"HF_HUB_DISABLE_IMPLICIT_TOKEN": "1"}) + def test_implicit_use_disabled_but_explicit_use(self, mock_HfFolder: Mock) -> None: + mock_HfFolder().get_token.return_value = FAKE_TOKEN + + # This is not an implicit use so we still send it + self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index a1fb9ea15c..db698f8796 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -369,49 +369,63 @@ def test_hello_both(self, mock_foo: Mock, mock_bar: Mock) -> None: NOTE: this decorator is inspired from the fixture system from pytest. """ - - def _test_decorator(fn: Callable) -> Callable: - signature = inspect.signature(fn) - parameters = signature.parameters - - @wraps(fn) - def _inner(*args, **kwargs): - assert kwargs == {} - - # Initialize new dict at least with `self`. - assert len(args) > 0 - assert len(parameters) > 0 - new_kwargs = {"self": args[0]} - - # Check which mocks have been injected - mocks = {} - for value in args[1:]: - assert isinstance(value, Mock) - mock_name = "mock_" + value._extract_mock_name() - mocks[mock_name] = value - - # Check which mocks are expected - for name, parameter in parameters.items(): - if name == "self": - continue - assert parameter.annotation is Mock - assert name in mocks, ( - f"Mock `{name}` not found for test `{fn.__name__}`. Available:" - f" {', '.join(sorted(mocks.keys()))}" - ) - new_kwargs[name] = mocks[name] - - # Run test only with a subset of mocks - return fn(**new_kwargs) - - return _inner - # Iterate over class functions and decorate tests # Taken from https://stackoverflow.com/a/3467879 # and https://stackoverflow.com/a/30764825 for name, fn in inspect.getmembers(cls): if name.startswith("test_"): - setattr(cls, name, _test_decorator(fn)) + setattr(cls, name, handle_injection_in_test(fn)) # Return decorated class return cls + + +def handle_injection_in_test(fn: Callable) -> Callable: + """ + Handle injections at a test level. See `handle_injection` for more details. + + Example: + ```py + def TestHelloWorld(unittest.TestCase): + + @patch("something.foo") + @patch("something_else.foo.bar") # order doesn't matter + @handle_injection_in_test # after @patch calls + def test_hello_foo(self, mock_foo: Mock) -> None: + (...) + ``` + """ + signature = inspect.signature(fn) + parameters = signature.parameters + + @wraps(fn) + def _inner(*args, **kwargs): + assert kwargs == {} + + # Initialize new dict at least with `self`. + assert len(args) > 0 + assert len(parameters) > 0 + new_kwargs = {"self": args[0]} + + # Check which mocks have been injected + mocks = {} + for value in args[1:]: + assert isinstance(value, Mock) + mock_name = "mock_" + value._extract_mock_name() + mocks[mock_name] = value + + # Check which mocks are expected + for name, parameter in parameters.items(): + if name == "self": + continue + assert parameter.annotation is Mock + assert name in mocks, ( + f"Mock `{name}` not found for test `{fn.__name__}`. Available:" + f" {', '.join(sorted(mocks.keys()))}" + ) + new_kwargs[name] = mocks[name] + + # Run test only with a subset of mocks + return fn(**new_kwargs) + + return _inner