diff --git a/docs/source/quick-start.mdx b/docs/source/quick-start.mdx
index 0ff56a41b3..b0653742b9 100644
--- a/docs/source/quick-start.mdx
+++ b/docs/source/quick-start.mdx
@@ -54,21 +54,16 @@ full-length hash instead of the shorter 7-character commit hash:
For more details and options, see the API reference for [`hf_hub_download`].
-## Create a repository
-
-To create and share files to the Hub, you need to have a Hugging Face account. [Create
-an account](https://hf.co/join) if you don't already have one, and then sign in to find
-your [User Access Token](https://huggingface.co/docs/hub/security-tokens) in
-your Settings. The User Access Token is used to authenticate your identity to the Hub.
-
-
-
-You can also provide your token to our functions and methods. This way you don't need to
-store your token anywhere.
+## Login
-
+In a lot of cases, you must be logged in with a Hugging Face account to interact with
+the Hub: download private repos, upload files, create PRs,...
+[Create an account](https://hf.co/join) if you don't already have one, and then sign in
+to get your [User Access Token](https://huggingface.co/docs/hub/security-tokens) from
+your [Settings page](https://huggingface.co/settings/tokens). The User Access Token is
+used to authenticate your identity to the Hub.
-1. Log in to your Hugging Face account with the following command:
+Once you have your User Access Token, run the following command in your terminal:
```bash
huggingface-cli login
@@ -82,9 +77,25 @@ Or if you prefer to work from a Jupyter or Colaboratory notebook, then login wit
>>> notebook_login()
```
-2. Enter your User Access Token to authenticate your identity to the Hub.
+
+
+You can also provide your token to the functions and methods. This way you don't need to
+store your token anywhere.
+
+
+
+
+
+Once you are logged in, all requests to the Hub will use your access token by default.
+If you want to disable implicit use of your token, you should set the
+`HF_HUB_DISABLE_IMPLICIT_TOKEN` environment variable.
+
+
+
+## Create a repository
-After you've created an account, create a repository with the [`create_repo`] function:
+Once you've registered and logged in, create a repository with the [`create_repo`]
+function:
```py
>>> from huggingface_hub import HfApi
@@ -102,6 +113,14 @@ If you want your repository to be private, then:
Private repositories will not be visible to anyone except yourself.
+
+
+To create a repository or to push content to the Hub, you must provide a User Access
+Token that has the `write` permission. You can choose the permission when creating the
+token in your [Settings page](https://huggingface.co/settings/tokens).
+
+
+
## Share files to the Hub
Use the [`upload_file`] function to add a file to your newly created repository. You
diff --git a/setup.py b/setup.py
index 1d04fd9830..be30badf89 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,6 @@ def get_version() -> str:
extras["tensorflow"] = ["tensorflow", "pydot", "graphviz"]
extras["testing"] = extras["cli"] + [
- "datasets",
"isort>=5.5.4",
"jedi",
"Jinja2",
diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index 78c51bd1e9..a413530e0e 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -98,7 +98,6 @@
"CommitOperationDelete",
"DatasetSearchArguments",
"HfApi",
- "HfFolder",
"ModelSearchArguments",
"change_discussion_status",
"comment_discussion",
@@ -169,6 +168,7 @@
"CorruptedCacheException",
"DeleteCacheStrategy",
"HFCacheInfo",
+ "HfFolder",
"logging",
"scan_cache_dir",
],
@@ -311,7 +311,6 @@ def __dir__():
from .hf_api import CommitOperationDelete # noqa: F401
from .hf_api import DatasetSearchArguments # noqa: F401
from .hf_api import HfApi # noqa: F401
- from .hf_api import HfFolder # noqa: F401
from .hf_api import ModelSearchArguments # noqa: F401
from .hf_api import change_discussion_status # noqa: F401
from .hf_api import comment_discussion # noqa: F401
@@ -368,6 +367,7 @@ def __dir__():
from .utils import CorruptedCacheException # noqa: F401
from .utils import DeleteCacheStrategy # noqa: F401
from .utils import HFCacheInfo # noqa: F401
+ from .utils import HfFolder # noqa: F401
from .utils import logging # noqa: F401
from .utils import scan_cache_dir # noqa: F401
from .utils.endpoint_helpers import DatasetFilter # noqa: F401
diff --git a/src/huggingface_hub/_commit_api.py b/src/huggingface_hub/_commit_api.py
index 36d5c82db6..9005e45d65 100644
--- a/src/huggingface_hub/_commit_api.py
+++ b/src/huggingface_hub/_commit_api.py
@@ -14,7 +14,7 @@
from .constants import ENDPOINT
from .lfs import UploadInfo, _validate_batch_actions, lfs_upload, post_lfs_batch_info
-from .utils import hf_raise_for_status, logging, validate_hf_hub_args
+from .utils import build_hf_headers, hf_raise_for_status, logging, validate_hf_hub_args
from .utils._typing import Literal
@@ -355,7 +355,7 @@ def fetch_upload_modes(
If the Hub API returned an HTTP 400 error (bad request)
"""
endpoint = endpoint if endpoint is not None else ENDPOINT
- headers = {"authorization": f"Bearer {token}"} if token is not None else None
+ headers = build_hf_headers(use_auth_token=token)
payload = {
"files": [
{
diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py
index c7e72580bb..ec74fb95f2 100644
--- a/src/huggingface_hub/_snapshot_download.py
+++ b/src/huggingface_hub/_snapshot_download.py
@@ -4,7 +4,7 @@
from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
-from .hf_api import HfApi, HfFolder
+from .hf_api import HfApi
from .utils import filter_repo_objects, logging, tqdm, validate_hf_hub_args
from .utils._deprecation import _deprecate_arguments
@@ -108,18 +108,6 @@ def snapshot_download(
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
- if isinstance(use_auth_token, str):
- token = use_auth_token
- elif use_auth_token:
- token = HfFolder.get_token()
- if token is None:
- raise EnvironmentError(
- "You specified use_auth_token=True, but a Hugging Face token was not"
- " found."
- )
- else:
- token = None
-
if repo_type is None:
repo_type = "model"
if repo_type not in REPO_TYPES:
@@ -167,7 +155,10 @@ def snapshot_download(
# if we have internet connection we retrieve the correct folder name from the huggingface api
_api = HfApi()
repo_info = _api.repo_info(
- repo_id=repo_id, repo_type=repo_type, revision=revision, use_auth_token=token
+ repo_id=repo_id,
+ repo_type=repo_type,
+ revision=revision,
+ use_auth_token=use_auth_token,
)
filtered_repo_files = list(
filter_repo_objects(
diff --git a/src/huggingface_hub/commands/user.py b/src/huggingface_hub/commands/user.py
index d5477263dd..e06b13672a 100644
--- a/src/huggingface_hub/commands/user.py
+++ b/src/huggingface_hub/commands/user.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import subprocess
from argparse import ArgumentParser
from getpass import getpass
@@ -24,10 +23,10 @@
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
)
-from huggingface_hub.hf_api import HfApi, HfFolder
+from huggingface_hub.hf_api import HfApi
from requests.exceptions import HTTPError
-from ..utils import run_subprocess
+from ..utils import HfFolder, run_subprocess
from ._cli_utils import ANSI
@@ -308,13 +307,16 @@ def login_token_event(t):
# Erase token and clear value to make sure it's not saved in the notebook.
token_widget.value = ""
clear_output()
- _login(HfApi(), token=token)
+ _login(token=token)
token_finish_button.on_click(login_token_event)
-def _login(hf_api, token=None):
- token, name = hf_api._validate_or_retrieve_token(token)
+def _login(hf_api: HfApi, token: str) -> None:
+ if token.startswith("api_org"):
+ raise ValueError("You must use your personal account token.")
+ if not hf_api._is_valid_token(token=token):
+ raise ValueError("Invalid token passed!")
hf_api.set_access_token(token)
HfFolder.save_token(token)
print("Login successful")
diff --git a/src/huggingface_hub/fastai_utils.py b/src/huggingface_hub/fastai_utils.py
index 26d02cdbea..0fd6ff686d 100644
--- a/src/huggingface_hub/fastai_utils.py
+++ b/src/huggingface_hub/fastai_utils.py
@@ -7,7 +7,7 @@
from packaging import version
-from huggingface_hub import hf_api, snapshot_download
+from huggingface_hub import snapshot_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.file_download import (
_PY_VERSION,
@@ -437,7 +437,6 @@ def push_to_hub_fastai(
"""
_check_fastai_fastcore_versions()
- token, _ = hf_api._validate_or_retrieve_token(token)
api = HfApi(endpoint=api_endpoint)
api.create_repo(
repo_id=repo_id,
diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index fc949d73fa..374c3c702a 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -33,10 +33,10 @@
REPO_TYPES,
REPO_TYPES_URL_PREFIXES,
)
-from .hf_api import HfFolder
from .utils import (
EntryNotFoundError,
LocalEntryNotFoundError,
+ build_hf_headers,
hf_raise_for_status,
http_backoff,
logging,
@@ -670,23 +670,12 @@ def cached_download(
os.makedirs(cache_dir, exist_ok=True)
- headers = {
- "user-agent": http_user_agent(
- library_name=library_name,
- library_version=library_version,
- user_agent=user_agent,
- )
- }
- if isinstance(use_auth_token, str):
- headers["authorization"] = f"Bearer {use_auth_token}"
- elif use_auth_token:
- token = HfFolder.get_token()
- if token is None:
- raise EnvironmentError(
- "You specified use_auth_token=True, but a huggingface token was not"
- " found."
- )
- headers["authorization"] = f"Bearer {token}"
+ headers = build_hf_headers(use_auth_token=use_auth_token)
+ headers["user-agent"] = http_user_agent(
+ library_name=library_name,
+ library_version=library_version,
+ user_agent=user_agent,
+ )
url_to_download = url
etag = None
@@ -1115,23 +1104,12 @@ def hf_hub_download(
url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision)
- headers = {
- "user-agent": http_user_agent(
- library_name=library_name,
- library_version=library_version,
- user_agent=user_agent,
- )
- }
- if isinstance(use_auth_token, str):
- headers["authorization"] = f"Bearer {use_auth_token}"
- elif use_auth_token:
- token = HfFolder.get_token()
- if token is None:
- raise EnvironmentError(
- "You specified use_auth_token=True, but a huggingface token was not"
- " found."
- )
- headers["authorization"] = f"Bearer {token}"
+ headers = build_hf_headers(use_auth_token=use_auth_token)
+ headers["user-agent"] = http_user_agent(
+ library_name=library_name,
+ library_version=library_version,
+ user_agent=user_agent,
+ )
url_to_download = url
etag = None
@@ -1433,18 +1411,7 @@ def get_hf_file_metadata(
A [`HfFileMetadata`] object containing metadata such as location, etag and
commit_hash.
"""
- # TODO: helper to get headers from `use_auth_token` (copy-pasted several times)
- headers = {}
- if isinstance(use_auth_token, str):
- headers["authorization"] = f"Bearer {use_auth_token}"
- elif use_auth_token:
- token = HfFolder.get_token()
- if token is None:
- raise EnvironmentError(
- "You specified use_auth_token=True, but a huggingface token was not"
- " found."
- )
- headers["authorization"] = f"Bearer {token}"
+ headers = build_hf_headers(use_auth_token=use_auth_token)
# Retrieve metadata
r = _request_wrapper(
diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py
index e0758c0665..418a2680f1 100644
--- a/src/huggingface_hub/hf_api.py
+++ b/src/huggingface_hub/hf_api.py
@@ -16,7 +16,6 @@
import re
import subprocess
import warnings
-from os.path import expanduser
from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from urllib.parse import quote
@@ -50,14 +49,16 @@
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
)
+from .utils import HfFolder # noqa: F401 # imported for backward compatibility
from .utils import (
+ build_hf_headers,
filter_repo_objects,
hf_raise_for_status,
logging,
parse_datetime,
validate_hf_hub_args,
)
-from .utils._deprecation import _deprecate_positional_args
+from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args
from .utils._typing import Literal, TypedDict
from .utils.endpoint_helpers import (
AttributeDictionary,
@@ -613,15 +614,14 @@ def whoami(self, token: Optional[str] = None) -> Dict:
Hugging Face token. Will default to the locally saved token if
not provided.
"""
- if token is None:
- token = HfFolder.get_token()
- if token is None:
- raise ValueError(
- "You need to pass a valid `token` or login by using `huggingface-cli"
- " login`"
- )
- path = f"{self.endpoint}/api/whoami-v2"
- r = requests.get(path, headers={"authorization": f"Bearer {token}"})
+ r = requests.get(
+ f"{self.endpoint}/api/whoami-v2",
+ headers=build_hf_headers(
+ # If `token` is provided and not `None`, it will be used by default.
+ # Otherwise, the token must be retrieved from cache or env variable.
+ use_auth_token=(token or True),
+ ),
+ )
try:
hf_raise_for_status(r)
except HTTPError as e:
@@ -632,7 +632,7 @@ def whoami(self, token: Optional[str] = None) -> Dict:
) from e
return r.json()
- def _is_valid_token(self, token: str):
+ def _is_valid_token(self, token: str) -> bool:
"""
Determines whether `token` is a valid token or not.
@@ -649,74 +649,6 @@ def _is_valid_token(self, token: str):
except HTTPError:
return False
- def _validate_or_retrieve_token(
- self,
- token: Optional[Union[str, bool]] = None,
- name: Optional[str] = None,
- function_name: Optional[str] = None,
- ):
- """
- Retrieves and validates stored token or validates passed token.
- Args:
- token (``str``, `optional`):
- Hugging Face token. Will default to the locally saved token if not provided.
- name (``str``, `optional`):
- Name of the repository. This is deprecated in favor of repo_id and will be removed in v0.8.
- function_name (``str``, `optional`):
- If _validate_or_retrieve_token is called from a function, name of that function to be passed inside deprecation warning.
- Returns:
- Validated token and the name of the repository.
- Raises:
- [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
- If the token is not passed and there's no token saved locally.
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
- If organization token or invalid token is passed.
- """
- if token is None or token is True:
- token = HfFolder.get_token()
- if token is None:
- raise EnvironmentError(
- "You need to provide a `token` or be logged in to Hugging Face with"
- " `huggingface-cli login`."
- )
- if name is not None:
- if self._is_valid_token(name):
- # TODO(0.6) REMOVE
- warnings.warn(
- f"`{function_name}` now takes `token` as an optional positional"
- " argument. Be sure to adapt your code!",
- FutureWarning,
- )
- token, name = name, token
- if isinstance(token, str):
- if token.startswith("api_org"):
- raise ValueError("You must use your personal account token.")
- if not self._is_valid_token(token):
- raise ValueError("Invalid token passed!")
-
- return token, name
-
- def _build_auth_headers(
- self, *, token: Optional[str], use_auth_token: Optional[Union[str, bool]]
- ) -> Dict[str, str]:
- """Helper to build Authorization header from kwargs. To be removed in 0.12.0 when `token` is deprecated."""
- if token is not None:
- warnings.warn(
- "`token` is deprecated and will be removed in 0.12.0. Use"
- " `use_auth_token` instead.",
- FutureWarning,
- )
-
- auth_token = None
- if use_auth_token is None and token is None:
- # To maintain backwards-compatibility. To be removed in 0.12.0
- auth_token = HfFolder.get_token()
- elif use_auth_token:
- auth_token, _ = self._validate_or_retrieve_token(use_auth_token)
- else:
- auth_token = token
- return {"authorization": f"Bearer {auth_token}"} if auth_token else {}
-
@staticmethod
def set_access_token(access_token: str):
"""
@@ -861,10 +793,7 @@ def list_models(
```
"""
path = f"{self.endpoint}/api/models"
- headers = {}
- if use_auth_token:
- token, name = self._validate_or_retrieve_token(use_auth_token)
- headers["authorization"] = f"Bearer {token}"
+ headers = build_hf_headers(use_auth_token=use_auth_token)
params = {}
if filter is not None:
if isinstance(filter, ModelFilter):
@@ -1059,10 +988,7 @@ def list_datasets(
```
"""
path = f"{self.endpoint}/api/datasets"
- headers = {}
- if use_auth_token:
- token, name = self._validate_or_retrieve_token(use_auth_token)
- headers["authorization"] = f"Bearer {token}"
+ headers = build_hf_headers(use_auth_token=use_auth_token)
params = {}
if filter is not None:
if isinstance(filter, DatasetFilter):
@@ -1199,10 +1125,7 @@ def list_spaces(
`List[SpaceInfo]`: a list of [`huggingface_hub.hf_api.SpaceInfo`] objects
"""
path = f"{self.endpoint}/api/spaces"
- headers = {}
- if use_auth_token:
- token, name = self._validate_or_retrieve_token(use_auth_token)
- headers["authorization"] = f"Bearer {token}"
+ headers = build_hf_headers(use_auth_token=use_auth_token)
params = {}
if filter is not None:
params.update({"filter": filter})
@@ -1232,6 +1155,7 @@ def list_spaces(
return [SpaceInfo(**x) for x in d]
@validate_hf_hub_args
+ @_deprecate_arguments(version="0.12", deprecated_args={"token"})
def model_info(
self,
repo_id: str,
@@ -1286,7 +1210,7 @@ def model_info(
"""
- headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token)
+ headers = build_hf_headers(use_auth_token=token or use_auth_token)
path = (
f"{self.endpoint}/api/models/{repo_id}"
if revision is None
@@ -1310,6 +1234,7 @@ def model_info(
return ModelInfo(**d)
@validate_hf_hub_args
+ @_deprecate_arguments(version="0.12", deprecated_args={"token"})
def dataset_info(
self,
repo_id: str,
@@ -1360,8 +1285,7 @@ def dataset_info(
"""
- headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token)
-
+ headers = build_hf_headers(use_auth_token=token or use_auth_token)
path = (
f"{self.endpoint}/api/datasets/{repo_id}"
if revision is None
@@ -1379,6 +1303,7 @@ def dataset_info(
return DatasetInfo(**d)
@validate_hf_hub_args
+ @_deprecate_arguments(version="0.12", deprecated_args={"token"})
def space_info(
self,
repo_id: str,
@@ -1429,7 +1354,7 @@ def space_info(
"""
- headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token)
+ headers = build_hf_headers(use_auth_token=token or use_auth_token)
path = (
f"{self.endpoint}/api/spaces/{repo_id}"
if revision is None
@@ -1611,10 +1536,6 @@ def create_repo(
path = f"{self.endpoint}/api/repos/create"
- token, name = self._validate_or_retrieve_token(
- token, name, function_name="create_repo"
- )
-
checked_name = repo_type_and_id_from_hf_id(name)
if (
@@ -1671,11 +1592,8 @@ def create_repo(
if getattr(self, "_lfsmultipartthresh", None):
json["lfsmultipartthresh"] = self._lfsmultipartthresh
- r = requests.post(
- path,
- headers={"authorization": f"Bearer {token}"},
- json=json,
- )
+ headers = build_hf_headers(use_auth_token=token, is_write_action=True)
+ r = requests.post(path, headers=headers, json=json)
try:
hf_raise_for_status(r)
@@ -1736,10 +1654,6 @@ def delete_repo(
path = f"{self.endpoint}/api/repos/delete"
- token, name = self._validate_or_retrieve_token(
- token, name, function_name="delete_repo"
- )
-
checked_name = repo_type_and_id_from_hf_id(name)
if (
@@ -1779,11 +1693,8 @@ def delete_repo(
if repo_type is not None:
json["type"] = repo_type
- r = requests.delete(
- path,
- headers={"authorization": f"Bearer {token}"},
- json=json,
- )
+ headers = build_hf_headers(use_auth_token=token, is_write_action=True)
+ r = requests.delete(path, headers=headers, json=json)
hf_raise_for_status(r)
@validate_hf_hub_args
@@ -1837,10 +1748,6 @@ def update_repo_visibility(
organization, name = repo_id.split("/") if "/" in repo_id else (None, repo_id)
- token, name = self._validate_or_retrieve_token(
- token, name, function_name="update_repo_visibility"
- )
-
if organization is None:
namespace = self.whoami(token)["name"]
else:
@@ -1849,14 +1756,10 @@ def update_repo_visibility(
if repo_type is None:
repo_type = REPO_TYPE_MODEL # default repo type
- path = f"{self.endpoint}/api/{repo_type}s/{namespace}/{name}/settings"
-
- json = {"private": private}
-
r = requests.put(
- path,
- headers={"authorization": f"Bearer {token}"},
- json=json,
+ url=f"{self.endpoint}/api/{repo_type}s/{namespace}/{name}/settings",
+ headers=build_hf_headers(use_auth_token=token, is_write_action=True),
+ json={"private": private},
)
hf_raise_for_status(r)
return r.json()
@@ -1900,9 +1803,6 @@ def move_repo(
"""
-
- token, name = self._validate_or_retrieve_token(token)
-
if len(from_id.split("/")) != 2:
raise ValueError(
f"Invalid repo_id: {from_id}. It should have a namespace"
@@ -1918,11 +1818,8 @@ def move_repo(
json = {"fromRepo": from_id, "toRepo": to_id, "type": repo_type}
path = f"{self.endpoint}/api/repos/move"
- r = requests.post(
- path,
- headers={"authorization": f"Bearer {token}"},
- json=json,
- )
+ headers = build_hf_headers(use_auth_token=token, is_write_action=True)
+ r = requests.post(path, headers=headers, json=json)
try:
hf_raise_for_status(r)
except HTTPError as e:
@@ -2046,7 +1943,6 @@ def create_commit(
repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL
if repo_type not in REPO_TYPES:
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
- token, name = self._validate_or_retrieve_token(token)
revision = (
quote(revision, safe="") if revision is not None else DEFAULT_REVISION
)
@@ -2111,9 +2007,10 @@ def create_commit(
)
commit_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/commit/{revision}"
+ headers = build_hf_headers(use_auth_token=token, is_write_action=True)
commit_resp = requests.post(
url=commit_url,
- headers={"Authorization": f"Bearer {token}"},
+ headers=headers,
json=commit_payload,
params={"create_pr": "1"} if create_pr else None,
)
@@ -2566,14 +2463,11 @@ def get_full_repo_name(
FutureWarning,
)
- if token is None and use_auth_token:
- token, name = self._validate_or_retrieve_token(use_auth_token)
-
if organization is None:
if "/" in model_id:
username = model_id.split("/")[0]
else:
- username = self.whoami(token=token)["name"]
+ username = self.whoami(token=token or use_auth_token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
@@ -2623,16 +2517,14 @@ def get_repo_discussions(
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
if repo_type is None:
repo_type = REPO_TYPE_MODEL
- repo_id = f"{repo_type}s/{repo_id}"
- if token is None:
- token = HfFolder.get_token()
+
+ headers = build_hf_headers(use_auth_token=token)
def _fetch_discussion_page(page_index: int):
- path = f"{self.endpoint}/api/{repo_id}/discussions?p={page_index}"
- resp = requests.get(
- path,
- headers={"Authorization": f"Bearer {token}"} if token else None,
+ path = (
+ f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}"
)
+ resp = requests.get(path, headers=headers)
hf_raise_for_status(resp)
paginated_discussions = resp.json()
total = paginated_discussions["count"]
@@ -2704,17 +2596,12 @@ def get_discussion_details(
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
if repo_type is None:
repo_type = REPO_TYPE_MODEL
- repo_id = f"{repo_type}s/{repo_id}"
- if token is None:
- token = HfFolder.get_token()
- path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}"
-
- resp = requests.get(
- path,
- params={"diff": "1"},
- headers={"Authorization": f"Bearer {token}"} if token else None,
+ path = (
+ f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}"
)
+ headers = build_hf_headers(use_auth_token=token)
+ resp = requests.get(path, params={"diff": "1"}, headers=headers)
hf_raise_for_status(resp)
discussion_details = resp.json()
@@ -2805,8 +2692,7 @@ def create_discussion(
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
if repo_type is None:
repo_type = REPO_TYPE_MODEL
- full_repo_id = f"{repo_type}s/{repo_id}"
- token, _ = self._validate_or_retrieve_token(token=token)
+
if description is not None:
description = description.strip()
description = (
@@ -2819,14 +2705,15 @@ def create_discussion(
)
)
+ headers = build_hf_headers(use_auth_token=token, is_write_action=True)
resp = requests.post(
- f"{self.endpoint}/api/{full_repo_id}/discussions",
+ f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions",
json={
"title": title.strip(),
"description": description,
"pullRequest": pull_request,
},
- headers={"Authorization": f"Bearer {token}"},
+ headers=headers,
)
hf_raise_for_status(resp)
num = resp.json()["num"]
@@ -2913,15 +2800,11 @@ def _post_discussion_changes(
if repo_type is None:
repo_type = REPO_TYPE_MODEL
repo_id = f"{repo_type}s/{repo_id}"
- token, _ = self._validate_or_retrieve_token(token=token)
path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}"
- resp = requests.post(
- path,
- headers={"Authorization": f"Bearer {token}"},
- json=body,
- )
+ headers = build_hf_headers(use_auth_token=token, is_write_action=True)
+ resp = requests.post(path, headers=headers, json=body)
hf_raise_for_status(resp)
return resp
@@ -3316,54 +3199,6 @@ def hide_discussion_comment(
return deserialize_event(resp.json()["updatedComment"])
-class HfFolder:
- path_token = expanduser("~/.huggingface/token")
-
- @classmethod
- def save_token(cls, token):
- """
- Save token, creating folder as needed.
-
- Args:
- token (`str`):
- The token to save to the [`HfFolder`]
- """
- os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
- with open(cls.path_token, "w+") as f:
- f.write(token)
-
- @classmethod
- def get_token(cls) -> Optional[str]:
- """
- Get token or None if not existent.
-
- Note that a token can be also provided using the
- `HUGGING_FACE_HUB_TOKEN` environment variable.
-
- Returns:
- `str` or `None`: The token, `None` if it doesn't exist.
-
- """
- token: Optional[str] = os.environ.get("HUGGING_FACE_HUB_TOKEN")
- if token is None:
- try:
- with open(cls.path_token, "r") as f:
- token = f.read()
- except FileNotFoundError:
- pass
- return token
-
- @classmethod
- def delete_token(cls):
- """
- Deletes the token from storage. Does not fail if token does not exist.
- """
- try:
- os.remove(cls.path_token)
- except FileNotFoundError:
- pass
-
-
def _prepare_upload_folder_commit(
folder_path: str,
path_in_repo: str,
@@ -3464,5 +3299,3 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:
edit_discussion_comment = api.edit_discussion_comment
rename_discussion = api.rename_discussion
merge_pull_request = api.merge_pull_request
-
-_validate_or_retrieve_token = api._validate_or_retrieve_token
diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py
index ff0f292c5c..fd24fcd546 100644
--- a/src/huggingface_hub/hub_mixin.py
+++ b/src/huggingface_hub/hub_mixin.py
@@ -5,13 +5,12 @@
from typing import Dict, List, Optional, Union
import requests
-from huggingface_hub import hf_api
from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from .file_download import hf_hub_download, is_torch_available
-from .hf_api import HfApi, HfFolder
+from .hf_api import HfApi
from .repository import Repository
-from .utils import logging, validate_hf_hub_args
+from .utils import HfFolder, logging, validate_hf_hub_args
from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args
@@ -326,9 +325,7 @@ def push_to_hub(
# If the repo id is set, it means we use the new version using HTTP endpoint
# (introduced in v0.9).
if repo_id is not None:
- token, _ = hf_api._validate_or_retrieve_token(token)
api = HfApi(endpoint=api_endpoint)
-
api.create_repo(
repo_id=repo_id,
repo_type="model",
diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py
index dfcd90758b..9872bf7502 100644
--- a/src/huggingface_hub/keras_mixin.py
+++ b/src/huggingface_hub/keras_mixin.py
@@ -9,12 +9,7 @@
from urllib.parse import quote
import yaml
-from huggingface_hub import (
- CommitOperationDelete,
- ModelHubMixin,
- hf_api,
- snapshot_download,
-)
+from huggingface_hub import CommitOperationDelete, ModelHubMixin, snapshot_download
from huggingface_hub.file_download import (
get_tf_version,
is_graphviz_available,
@@ -23,14 +18,9 @@
)
from .constants import CONFIG_NAME, DEFAULT_REVISION
-from .hf_api import (
- HfApi,
- HfFolder,
- _parse_revision_from_pr_url,
- _prepare_upload_folder_commit,
-)
+from .hf_api import HfApi, _parse_revision_from_pr_url, _prepare_upload_folder_commit
from .repository import Repository
-from .utils import logging, validate_hf_hub_args
+from .utils import HfFolder, logging, validate_hf_hub_args
from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args
@@ -409,9 +399,7 @@ def push_to_hub_keras(
The url of the commit of your model in the given repository.
"""
if repo_id is not None:
- token, _ = hf_api._validate_or_retrieve_token(token)
api = HfApi(endpoint=api_endpoint)
-
api.create_repo(
repo_id=repo_id,
repo_type="model",
diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py
index b132915d71..f10b8c6417 100644
--- a/src/huggingface_hub/repository.py
+++ b/src/huggingface_hub/repository.py
@@ -15,9 +15,9 @@
from huggingface_hub.repocard import metadata_load, metadata_save
from requests.exceptions import HTTPError
-from .hf_api import HfApi, HfFolder, repo_type_and_id_from_hf_id
+from .hf_api import HfApi, repo_type_and_id_from_hf_id
from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
-from .utils import logging, run_subprocess, tqdm
+from .utils import HfFolder, logging, run_subprocess, tqdm
from .utils._deprecation import _deprecate_arguments, _deprecate_method
diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py
index cf5176374c..5faa622778 100644
--- a/src/huggingface_hub/utils/__init__.py
+++ b/src/huggingface_hub/utils/__init__.py
@@ -35,6 +35,8 @@
RevisionNotFoundError,
hf_raise_for_status,
)
+from ._headers import build_hf_headers
+from ._hf_folder import HfFolder
from ._http import http_backoff
from ._paths import filter_repo_objects
from ._subprocess import run_subprocess
diff --git a/src/huggingface_hub/utils/_deprecation.py b/src/huggingface_hub/utils/_deprecation.py
index ab8a5e135e..3cb14e7243 100644
--- a/src/huggingface_hub/utils/_deprecation.py
+++ b/src/huggingface_hub/utils/_deprecation.py
@@ -76,8 +76,13 @@ def inner_f(*args, **kwargs):
for _, parameter in zip(args, sig.parameters.values()):
if parameter.name in deprecated_args:
used_deprecated_args.append(parameter.name)
- for kwarg_name in kwargs:
- if kwarg_name in deprecated_args:
+ for kwarg_name, kwarg_value in kwargs.items():
+ if (
+ # If argument is deprecated but still used
+ kwarg_name in deprecated_args
+ # And then the value is not the default value
+ and kwarg_value != sig.parameters[kwarg_name].default
+ ):
used_deprecated_args.append(kwarg_name)
# Warn and proceed
diff --git a/src/huggingface_hub/utils/_headers.py b/src/huggingface_hub/utils/_headers.py
new file mode 100644
index 0000000000..d51035185f
--- /dev/null
+++ b/src/huggingface_hub/utils/_headers.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2022-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains utilities to handle headers to send in calls to Huggingface Hub."""
+import os
+from typing import Dict, Optional, Union
+
+from ._hf_folder import HfFolder
+
+
+def build_hf_headers(
+ *, use_auth_token: Optional[Union[bool, str]] = None, is_write_action: bool = False
+) -> Dict[str, str]:
+ """
+ Build headers dictionary to send in a HF Hub call.
+
+ By default, authorization token is always provided either from argument (explicit
+ use) or retrieved from the cache (implicit use). To explicitly avoid sending the
+ token to the Hub, set `use_auth_token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN`
+ environment variable.
+
+ In case of an API call that requires write access, an error is thrown if token is
+ `None` or token is an organization token (starting with `"api_org***"`).
+
+ Args:
+ use_auth_token (`str`, `bool`, *optional*):
+ The token to be sent in authorization header for the Hub call:
+ - if a string, it is used as the Hugging Face token
+ - if `True`, the token is read from the machine (cache or env variable)
+ - if `False`, authorization header is not set
+ - if `None`, the token is read from the machine only except if
+ `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set.
+ is_write_action (`bool`, default to `False`):
+ Set to True if the API call requires a write access. If `True`, the token
+ will be validated (cannot be `None`, cannot start by `"api_org***"`).
+
+ Returns:
+ A `Dict` of headers to pass in your API call.
+
+ Example:
+ ```py
+ >>> build_hf_headers(use_auth_token="hf_***") # explicit token
+ {"authorization": "Bearer hf_***"}
+
+ >>> build_hf_headers(use_auth_token=True) # explicitly use cached token
+ {"authorization": "Bearer hf_***"}
+
+ >>> build_hf_headers(use_auth_token=False) # explicitly don't use cached token
+ {}
+
+ >>> build_hf_headers() # implicit use of the cached token
+ {"authorization": "Bearer hf_***"}
+
+ # HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable
+ >>> build_hf_headers() # token is not sent
+ {}
+
+ >>> build_hf_headers(use_auth_token="api_org_***", is_write_action=True)
+ ValueError: You must use your personal account token for write-access methods.
+ ```
+
+ Raises:
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ If organization token is passed and "write" access is required.
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ If "write" access is required but token is not passed and not saved locally.
+ [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
+ If `use_auth_token=True` but token is not saved locally.
+ """
+ # Get auth token to send
+ token_to_send = _get_token_to_send(use_auth_token)
+ _validate_token_to_send(token_to_send, is_write_action=is_write_action)
+
+ # Combine headers
+ headers = {}
+ if token_to_send is not None:
+ headers["authorization"] = f"Bearer {token_to_send}"
+ # TODO: add user agent in headers
+ return headers
+
+
+def _get_token_to_send(use_auth_token: Optional[Union[bool, str]]) -> Optional[str]:
+ """Select the token to send from either `use_auth_token` or the cache."""
+ # Case token is explicitly provided
+ if isinstance(use_auth_token, str):
+ return use_auth_token
+
+ # Case token is explicitly forbidden
+ if use_auth_token is False:
+ return None
+
+ # Token is not provided: we get it from local cache
+ cached_token = HfFolder().get_token()
+
+ # Case token is explicitly required
+ if use_auth_token is True:
+ if cached_token is None:
+ raise EnvironmentError(
+ "Token is required (`use_auth_token=True`), but no token found. You"
+ " need to provide a token or be logged in to Hugging Face with"
+ " `huggingface-cli login` or `notebook_login`. See"
+ " https://huggingface.co/settings/tokens."
+ )
+ return cached_token
+
+ # Case implicit use of the token is forbidden by env variable
+ if os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"):
+ return None
+
+ # Otherwise: we use the cached token as the user has not explicitly forbidden it
+ return cached_token
+
+
+def _validate_token_to_send(token: Optional[str], is_write_action: bool) -> None:
+ if is_write_action:
+ if token is None:
+ raise ValueError(
+ "Token is required (write-access action) but no token found. You need"
+ " to provide a token or be logged in to Hugging Face with"
+ " `huggingface-cli login` or `notebook_login`. See"
+ " https://huggingface.co/settings/tokens."
+ )
+ if token.startswith("api_org"):
+ raise ValueError(
+ "You must use your personal account token for write-access methods. To"
+ " generate a write-access token, go to"
+ " https://huggingface.co/settings/tokens"
+ )
diff --git a/src/huggingface_hub/utils/_hf_folder.py b/src/huggingface_hub/utils/_hf_folder.py
new file mode 100644
index 0000000000..c9182ada0a
--- /dev/null
+++ b/src/huggingface_hub/utils/_hf_folder.py
@@ -0,0 +1,64 @@
+# coding=utf-8
+# Copyright 2022-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contain helper class to retrieve/store token from/to local cache."""
+import os
+from pathlib import Path
+from typing import Optional
+
+
+class HfFolder:
+ path_token = Path("~/.huggingface/token").expanduser()
+
+ @classmethod
+ def save_token(cls, token: str) -> None:
+ """
+ Save token, creating folder as needed.
+
+ Args:
+ token (`str`):
+ The token to save to the [`HfFolder`]
+ """
+ cls.path_token.parent.mkdir(exist_ok=True)
+ with cls.path_token.open("w+") as f:
+ f.write(token)
+
+ @classmethod
+ def get_token(cls) -> Optional[str]:
+ """
+ Get token or None if not existent.
+
+ Note that a token can be also provided using the
+ `HUGGING_FACE_HUB_TOKEN` environment variable.
+
+ Returns:
+ `str` or `None`: The token, `None` if it doesn't exist.
+ """
+ token: Optional[str] = os.environ.get("HUGGING_FACE_HUB_TOKEN")
+ if token is None:
+ try:
+ return cls.path_token.read_text()
+ except FileNotFoundError:
+ pass
+ return token
+
+ @classmethod
+ def delete_token(cls) -> None:
+ """
+ Deletes the token from storage. Does not fail if token does not exist.
+ """
+ try:
+ cls.path_token.unlink()
+ except FileNotFoundError:
+ pass
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000000..3c871da76c
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,22 @@
+from typing import Generator
+
+import pytest
+
+from huggingface_hub import HfFolder
+
+
+@pytest.fixture(autouse=True, scope="session")
+def clean_hf_folder_token_for_tests() -> Generator:
+ """Clean token stored on machine before all tests and reset it back at the end.
+
+ Useful to avoid token deletion when running tests locally.
+ """
+ # Remove registered token
+ token = HfFolder().get_token()
+ HfFolder().delete_token()
+
+ yield # Run all tests
+
+ # Set back token once all tests have passed
+ if token is not None:
+ HfFolder().save_token(token)
diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py
index 3bd41e2a99..66c73006ea 100644
--- a/tests/test_hf_api.py
+++ b/tests/test_hf_api.py
@@ -23,6 +23,7 @@
from functools import partial
from io import BytesIO
from typing import List
+from unittest.mock import Mock, patch
from urllib.parse import quote
import pytest
@@ -43,7 +44,6 @@
DatasetInfo,
DatasetSearchArguments,
HfApi,
- HfFolder,
MetricInfo,
ModelInfo,
ModelSearchArguments,
@@ -53,7 +53,7 @@
read_from_credential_store,
repo_type_and_id_from_hf_id,
)
-from huggingface_hub.utils import RepositoryNotFoundError, logging
+from huggingface_hub.utils import HfFolder, RepositoryNotFoundError, logging
from huggingface_hub.utils.endpoint_helpers import (
DatasetFilter,
ModelFilter,
@@ -181,7 +181,7 @@ def test_repo_id_no_warning():
class HfApiEndpointsTest(HfApiCommonTestWithLogin):
- def test_whoami(self):
+ def test_whoami_with_passing_token(self):
info = self._api.whoami(token=self._token)
self.assertEqual(info["name"], USER)
self.assertEqual(info["fullname"], FULL_NAME)
@@ -189,6 +189,13 @@ def test_whoami(self):
valid_org = [org for org in info["orgs"] if org["name"] == "valid_org"][0]
self.assertIsInstance(valid_org["apiToken"], str)
+ @patch("huggingface_hub.utils._headers.HfFolder")
+ def test_whoami_with_implicit_token_from_login(self, mock_HfFolder: Mock) -> None:
+ """Test using `whoami` after a `huggingface-cli login`."""
+ mock_HfFolder().get_token.return_value = self._token
+ info = self._api.whoami()
+ self.assertEqual(info["name"], USER)
+
@retry_endpoint
def test_delete_repo_error_message(self):
# test for #751
diff --git a/tests/test_inference_api.py b/tests/test_inference_api.py
index d92aacbd92..49019ce955 100644
--- a/tests/test_inference_api.py
+++ b/tests/test_inference_api.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
-import datasets
-
+from huggingface_hub import hf_hub_download
from huggingface_hub.inference_api import InferenceApi
from .testing_utils import with_production_testing
@@ -67,13 +65,12 @@ def test_inference_with_dict_inputs(self):
@with_production_testing
def test_inference_with_audio(self):
api = InferenceApi("facebook/wav2vec2-base-960h")
- with self.assertWarns(FutureWarning):
- dataset = datasets.load_dataset(
- "patrickvonplaten/librispeech_asr_dummy",
- "clean",
- split="validation",
- )
- data = self.read(dataset["file"][0])
+ file = hf_hub_download(
+ repo_id="hf-internal-testing/dummy-flac-single-example",
+ repo_type="dataset",
+ filename="example.flac",
+ )
+ data = self.read(file)
result = api(data=data)
self.assertIsInstance(result, dict)
self.assertTrue("text" in result, f"We received {result} instead")
@@ -81,13 +78,10 @@ def test_inference_with_audio(self):
@with_production_testing
def test_inference_with_image(self):
api = InferenceApi("google/vit-base-patch16-224")
- with self.assertWarns(FutureWarning):
- dataset = datasets.load_dataset(
- "Narsil/image_dummy",
- "image",
- split="test",
- )
- data = self.read(dataset["file"][0])
+ file = hf_hub_download(
+ repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png"
+ )
+ data = self.read(file)
result = api(data=data)
self.assertIsInstance(result, list)
for classification in result:
diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py
index 63069482fd..d6f6342d5c 100644
--- a/tests/test_snapshot_download.py
+++ b/tests/test_snapshot_download.py
@@ -8,8 +8,7 @@
import requests
from huggingface_hub import HfApi, Repository, snapshot_download
-from huggingface_hub.hf_api import HfFolder
-from huggingface_hub.utils import logging
+from huggingface_hub.utils import HfFolder, logging
from .testing_constants import ENDPOINT_STAGING, TOKEN, USER
from .testing_utils import (
diff --git a/tests/test_utils_deprecation.py b/tests/test_utils_deprecation.py
index dfd4a28717..ccaf6f78f0 100644
--- a/tests/test_utils_deprecation.py
+++ b/tests/test_utils_deprecation.py
@@ -52,6 +52,9 @@ def dummy_b_c_deprecated(a, b="b", c="c"):
dummy_b_c_deprecated("A")
+ dummy_b_c_deprecated("A", b="b")
+ dummy_b_c_deprecated("A", b="b", c="c")
+
with pytest.warns(FutureWarning):
dummy_c_deprecated("A", "B", "C")
@@ -79,7 +82,7 @@ def dummy_deprecated_default_message(a: str = "a") -> None:
# Default message
with pytest.warns(FutureWarning) as record:
- dummy_deprecated_default_message(a="a")
+ dummy_deprecated_default_message(a="A")
self.assertEqual(len(record), 1)
self.assertEqual(
record[0].message.args[0],
@@ -100,7 +103,7 @@ def dummy_deprecated_custom_message(a: str = "a") -> None:
# Custom message
with pytest.warns(FutureWarning) as record:
- dummy_deprecated_custom_message(a="a")
+ dummy_deprecated_custom_message(a="A")
self.assertEqual(len(record), 1)
self.assertEqual(
record[0].message.args[0],
diff --git a/tests/test_utils_headers.py b/tests/test_utils_headers.py
new file mode 100644
index 0000000000..b8cb145757
--- /dev/null
+++ b/tests/test_utils_headers.py
@@ -0,0 +1,64 @@
+import unittest
+from unittest.mock import Mock, patch
+
+from huggingface_hub.utils._headers import build_hf_headers
+
+from .testing_utils import handle_injection
+
+
+FAKE_TOKEN = "123456789"
+FAKE_TOKEN_ORG = "api_org_123456789"
+FAKE_TOKEN_HEADER = {"authorization": f"Bearer {FAKE_TOKEN}"}
+
+
+@patch("huggingface_hub.utils._headers.HfFolder")
+@handle_injection
+class TestAuthHeadersUtil(unittest.TestCase):
+ def test_use_auth_token_str(self) -> None:
+ self.assertEqual(build_hf_headers(use_auth_token=FAKE_TOKEN), FAKE_TOKEN_HEADER)
+
+ def test_use_auth_token_true_no_cached_token(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = None
+ with self.assertRaises(EnvironmentError):
+ build_hf_headers(use_auth_token=True)
+
+ def test_use_auth_token_true_has_cached_token(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = FAKE_TOKEN
+ self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER)
+
+ def test_use_auth_token_false(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = FAKE_TOKEN
+ self.assertEqual(build_hf_headers(use_auth_token=False), {})
+
+ def test_use_auth_token_none_no_cached_token(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = None
+ self.assertEqual(build_hf_headers(), {})
+
+ def test_use_auth_token_none_has_cached_token(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = FAKE_TOKEN
+ self.assertEqual(build_hf_headers(), FAKE_TOKEN_HEADER)
+
+ def test_write_action_org_token(self) -> None:
+ with self.assertRaises(ValueError):
+ build_hf_headers(use_auth_token=FAKE_TOKEN_ORG, is_write_action=True)
+
+ def test_write_action_none_token(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = None
+ with self.assertRaises(ValueError):
+ build_hf_headers(is_write_action=True)
+
+ def test_write_action_use_auth_token_false(self) -> None:
+ with self.assertRaises(ValueError):
+ build_hf_headers(use_auth_token=False, is_write_action=True)
+
+ @patch.dict("os.environ", {"HF_HUB_DISABLE_IMPLICIT_TOKEN": "1"})
+ def test_implicit_use_disabled(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = FAKE_TOKEN
+ self.assertEqual(build_hf_headers(), {}) # token is not sent
+
+ @patch.dict("os.environ", {"HF_HUB_DISABLE_IMPLICIT_TOKEN": "1"})
+ def test_implicit_use_disabled_but_explicit_use(self, mock_HfFolder: Mock) -> None:
+ mock_HfFolder().get_token.return_value = FAKE_TOKEN
+
+ # This is not an implicit use so we still send it
+ self.assertEqual(build_hf_headers(use_auth_token=True), FAKE_TOKEN_HEADER)
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index a1fb9ea15c..db698f8796 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -369,49 +369,63 @@ def test_hello_both(self, mock_foo: Mock, mock_bar: Mock) -> None:
NOTE: this decorator is inspired from the fixture system from pytest.
"""
-
- def _test_decorator(fn: Callable) -> Callable:
- signature = inspect.signature(fn)
- parameters = signature.parameters
-
- @wraps(fn)
- def _inner(*args, **kwargs):
- assert kwargs == {}
-
- # Initialize new dict at least with `self`.
- assert len(args) > 0
- assert len(parameters) > 0
- new_kwargs = {"self": args[0]}
-
- # Check which mocks have been injected
- mocks = {}
- for value in args[1:]:
- assert isinstance(value, Mock)
- mock_name = "mock_" + value._extract_mock_name()
- mocks[mock_name] = value
-
- # Check which mocks are expected
- for name, parameter in parameters.items():
- if name == "self":
- continue
- assert parameter.annotation is Mock
- assert name in mocks, (
- f"Mock `{name}` not found for test `{fn.__name__}`. Available:"
- f" {', '.join(sorted(mocks.keys()))}"
- )
- new_kwargs[name] = mocks[name]
-
- # Run test only with a subset of mocks
- return fn(**new_kwargs)
-
- return _inner
-
# Iterate over class functions and decorate tests
# Taken from https://stackoverflow.com/a/3467879
# and https://stackoverflow.com/a/30764825
for name, fn in inspect.getmembers(cls):
if name.startswith("test_"):
- setattr(cls, name, _test_decorator(fn))
+ setattr(cls, name, handle_injection_in_test(fn))
# Return decorated class
return cls
+
+
+def handle_injection_in_test(fn: Callable) -> Callable:
+ """
+ Handle injections at a test level. See `handle_injection` for more details.
+
+ Example:
+ ```py
+ def TestHelloWorld(unittest.TestCase):
+
+ @patch("something.foo")
+ @patch("something_else.foo.bar") # order doesn't matter
+ @handle_injection_in_test # after @patch calls
+ def test_hello_foo(self, mock_foo: Mock) -> None:
+ (...)
+ ```
+ """
+ signature = inspect.signature(fn)
+ parameters = signature.parameters
+
+ @wraps(fn)
+ def _inner(*args, **kwargs):
+ assert kwargs == {}
+
+ # Initialize new dict at least with `self`.
+ assert len(args) > 0
+ assert len(parameters) > 0
+ new_kwargs = {"self": args[0]}
+
+ # Check which mocks have been injected
+ mocks = {}
+ for value in args[1:]:
+ assert isinstance(value, Mock)
+ mock_name = "mock_" + value._extract_mock_name()
+ mocks[mock_name] = value
+
+ # Check which mocks are expected
+ for name, parameter in parameters.items():
+ if name == "self":
+ continue
+ assert parameter.annotation is Mock
+ assert name in mocks, (
+ f"Mock `{name}` not found for test `{fn.__name__}`. Available:"
+ f" {', '.join(sorted(mocks.keys()))}"
+ )
+ new_kwargs[name] = mocks[name]
+
+ # Run test only with a subset of mocks
+ return fn(**new_kwargs)
+
+ return _inner