Fix Inference Client VCR tests (#2858)

hanouticelina authored Feb 14, 2025
1 parent 7553646 commit cd85541

Showing 36 changed files with 51,831 additions and 101,305 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/python-tests.yml
@@ -26,6 +26,7 @@ jobs:
         [
           "Repository only",
           "Everything else",
+          "Inference only"
 
         ]
         include:
@@ -64,7 +65,7 @@ jobs:
           case "${{ matrix.test_name }}" in
-            "Repository only" | "Everything else")
+            "Repository only" | "Everything else" | "Inference only")
               sudo apt update
               sudo apt install -y libsndfile1-dev
               ;;
@@ -112,8 +113,15 @@ jobs:
             eval $PYTEST
             ;;
+          "Inference only")
+            # Run inference tests concurrently
+            PYTEST="$PYTEST ../tests -k 'test_inference' -n 4"
+            echo $PYTEST
+            eval $PYTEST
+            ;;
           "Everything else")
-            PYTEST="$PYTEST ../tests -k 'not TestRepository' -n 4"
+            PYTEST="$PYTEST ../tests -k 'not TestRepository and not test_inference' -n 4"
            echo $PYTEST
            eval $PYTEST
            ;;
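The matrix now runs inference tests as their own job, and the "Everything else" filter excludes them so nothing runs twice. A rough model of how the two `-k` expressions partition the suite (test ids below are invented; real `-k` matching is case-insensitive substring matching over test node ids):

```python
# Hypothetical test ids, for illustration only.
tests = [
    "test_inference_client.py::test_inference_text_to_image",
    "test_repository.py::TestRepository::test_clone",
    "test_hf_api.py::test_whoami",
]

def inference_only(test_id: str) -> bool:
    # approximates: pytest -k 'test_inference'
    return "test_inference" in test_id

def everything_else(test_id: str) -> bool:
    # approximates: pytest -k 'not TestRepository and not test_inference'
    return "TestRepository" not in test_id and "test_inference" not in test_id

for t in tests:
    groups = [name for name, pred in (("Inference only", inference_only),
                                      ("Everything else", everything_else)) if pred(t)]
    print(t, "->", groups)
```

Each test id lands in at most one of the two groups; `TestRepository` tests are covered by the separate "Repository only" job.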
12 changes: 6 additions & 6 deletions src/huggingface_hub/inference/_client.py
@@ -40,7 +40,7 @@
 
 from requests import HTTPError
 
-from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+from huggingface_hub import constants
 from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
@@ -3300,9 +3300,9 @@ def list_deployed_models(
 
         # Resolve which frameworks to check
         if frameworks is None:
-            frameworks = MAIN_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
         elif frameworks == "all":
-            frameworks = ALL_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
         elif isinstance(frameworks, str):
             frameworks = [frameworks]
         frameworks = list(set(frameworks))
@@ -3322,7 +3322,7 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:
 
         for framework in frameworks:
             response = get_session().get(
-                f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
+                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
             )
             hf_raise_for_status(response)
             _unpack_response(framework, response.json())
@@ -3384,7 +3384,7 @@ def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
         if model.startswith(("http://", "https://")):
             url = model.rstrip("/") + "/info"
         else:
-            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+            url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"
 
         response = get_session().get(url, headers=build_hf_headers(token=self.token))
         hf_raise_for_status(response)
@@ -3472,7 +3472,7 @@ def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
             raise ValueError("Model id not provided.")
         if model.startswith("https://"):
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{INFERENCE_ENDPOINT}/status/{model}"
+        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
 
         response = get_session().get(url, headers=build_hf_headers(token=self.token))
         hf_raise_for_status(response)
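All of the endpoint constants are now read off the `constants` module at call time instead of being imported by value. That is what makes them patchable in tests: a fixture can redirect one attribute and every request URL follows. A minimal sketch of the pattern, assuming pytest's `monkeypatch` fixture (this test is illustrative, not part of the commit):

```python
from huggingface_hub import constants

def build_status_url(model: str) -> str:
    # Module-attribute lookup, resolved at call time like the client code above.
    return f"{constants.INFERENCE_ENDPOINT}/status/{model}"

def test_status_url_respects_patched_endpoint(monkeypatch):
    # Patching the module attribute is visible to every call-time lookup.
    monkeypatch.setattr(constants, "INFERENCE_ENDPOINT", "http://localhost:8080")
    assert build_status_url("gpt2") == "http://localhost:8080/status/gpt2"
```

With the old `from huggingface_hub.constants import INFERENCE_ENDPOINT`, the client module would keep its own binding to the original string and never see the patch.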
12 changes: 6 additions & 6 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -25,7 +25,7 @@
 import warnings
 from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
 
-from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
+from huggingface_hub import constants
 from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
     TASKS_EXPECTING_IMAGES,
@@ -3365,9 +3365,9 @@ async def list_deployed_models(
 
         # Resolve which frameworks to check
         if frameworks is None:
-            frameworks = MAIN_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS
         elif frameworks == "all":
-            frameworks = ALL_INFERENCE_API_FRAMEWORKS
+            frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS
         elif isinstance(frameworks, str):
             frameworks = [frameworks]
         frameworks = list(set(frameworks))
@@ -3387,7 +3387,7 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:
 
         for framework in frameworks:
             response = get_session().get(
-                f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
+                f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token)
             )
             hf_raise_for_status(response)
             _unpack_response(framework, response.json())
@@ -3491,7 +3491,7 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
         if model.startswith(("http://", "https://")):
             url = model.rstrip("/") + "/info"
         else:
-            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+            url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info"
 
         async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
             response = await client.get(url, proxy=self.proxies)
@@ -3583,7 +3583,7 @@ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
             raise ValueError("Model id not provided.")
         if model.startswith("https://"):
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
-        url = f"{INFERENCE_ENDPOINT}/status/{model}"
+        url = f"{constants.INFERENCE_ENDPOINT}/status/{model}"
 
         async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client:
             response = await client.get(url, proxy=self.proxies)
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/__init__.py
@@ -46,9 +46,9 @@
     "image-classification": HFInferenceBinaryInputTask("image-classification"),
     "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"),
     "document-question-answering": HFInferenceTask("document-question-answering"),
-    "image-to-text": HFInferenceTask("image-to-text"),
+    "image-to-text": HFInferenceBinaryInputTask("image-to-text"),
     "object-detection": HFInferenceBinaryInputTask("object-detection"),
-    "audio-to-audio": HFInferenceTask("audio-to-audio"),
+    "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"),
     "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"),
     "zero-shot-classification": HFInferenceTask("zero-shot-classification"),
     "image-to-image": HFInferenceBinaryInputTask("image-to-image"),
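`image-to-text` and `audio-to-audio` accept a file or raw bytes, so they are remapped to `HFInferenceBinaryInputTask`, which sends the content as the request body rather than inside a JSON payload. An illustrative sketch of the two input paths (not the library's exact internals):

```python
from pathlib import Path
from typing import Any, Dict, Optional

def payload_as_dict(inputs: Any, parameters: Dict) -> Optional[Dict]:
    # JSON tasks: text-like inputs travel inside a JSON payload.
    return {"inputs": inputs, "parameters": parameters}

def payload_as_bytes(inputs: Any) -> bytes:
    # Binary tasks: image/audio content becomes the raw request body.
    if isinstance(inputs, Path):
        return inputs.read_bytes()
    if isinstance(inputs, (bytes, bytearray)):
        return bytes(inputs)
    raise ValueError("binary-input task expects bytes or a file path")
```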
17 changes: 9 additions & 8 deletions src/huggingface_hub/inference/_providers/_common.py
@@ -18,6 +18,7 @@
     # Example:
     # "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
     "fal-ai": {},
+    "fireworks-ai": {},
     "hf-inference": {},
     "replicate": {},
     "sambanova": {},
@@ -65,12 +66,12 @@ def prepare_request(
         url = self._prepare_url(api_key, mapped_model)
 
         # prepare payload (to customize in subclasses)
-        payload = self._prepare_payload(inputs, parameters, mapped_model=mapped_model)
+        payload = self._prepare_payload_as_dict(inputs, parameters, mapped_model=mapped_model)
         if payload is not None:
             payload = recursive_merge(payload, extra_payload or {})
 
         # body data (to customize in subclasses)
-        data = self._prepare_body(inputs, parameters, mapped_model, extra_payload)
+        data = self._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)
 
         # check if both payload and data are set and return
         if payload is not None and data is not None:
@@ -159,21 +160,21 @@ def _prepare_route(self, mapped_model: str) -> str:
         """
         return ""
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         """Return the payload to use for the request, as a dict.
         Override this method in subclasses for customized payloads.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
         """
         return None
 
-    def _prepare_body(
+    def _prepare_payload_as_bytes(
         self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
     ) -> Optional[bytes]:
         """Return the body to use for the request, as bytes.
         Override this method in subclasses for customized body data.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
         """
         return None
 
@@ -183,9 +184,9 @@ def _fetch_inference_provider_mapping(model: str) -> Dict:
     """
     Fetch provider mappings for a model from the Hub.
    """
-    from huggingface_hub.hf_api import model_info
+    from huggingface_hub.hf_api import HfApi
 
-    info = model_info(model, expand=["inferenceProviderMapping"])
+    info = HfApi().model_info(model, expand=["inferenceProviderMapping"])
     provider_mapping = info.inference_provider_mapping
     if provider_mapping is None:
         raise ValueError(f"No provider mapping found for model {model}")
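The renames make the contract explicit: a task helper produces its payload either as a dict (JSON body) or as bytes (raw body), never both, and `prepare_request` checks for that. A toy subclass using only the renamed hooks from this diff (class and payload shape are invented):

```python
from typing import Any, Dict, Optional

class EchoChatTask:
    """Hypothetical JSON-only task: implements the dict hook, leaves the bytes hook unset."""

    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
        return {"messages": inputs, **parameters, "model": mapped_model}

    def _prepare_payload_as_bytes(
        self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
    ) -> Optional[bytes]:
        return None  # exactly one of the two hooks returns a value
```

Routing `model_info` through an `HfApi()` instance serves the same test-friendliness goal: the VCR suite can stub or record a single client object instead of a module-level function.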
8 changes: 4 additions & 4 deletions src/huggingface_hub/inference/_providers/fal_ai.py
@@ -25,7 +25,7 @@ class FalAIAutomaticSpeechRecognitionTask(FalAITask):
     def __init__(self):
         super().__init__("automatic-speech-recognition")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
             # If input is a URL, pass it directly
             audio_url = inputs
@@ -52,7 +52,7 @@ class FalAITextToImageTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-image")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         parameters = filter_none(parameters)
         if "width" in parameters and "height" in parameters:
             parameters["image_size"] = {
@@ -70,7 +70,7 @@ class FalAITextToSpeechTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-speech")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"lyrics": inputs, **filter_none(parameters)}
 
     def get_response(self, response: Union[bytes, Dict]) -> Any:
@@ -82,7 +82,7 @@ class FalAITextToVideoTask(FalAITask):
     def __init__(self):
         super().__init__("text-to-video")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"prompt": inputs, **filter_none(parameters)}
 
     def get_response(self, response: Union[bytes, Dict]) -> Any:
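Fal AI expects image dimensions as a nested `image_size` object, so the text-to-image task folds the client's flat `width`/`height` parameters into one. A small sketch of that translation (values are arbitrary; the `pop`-based completion of the truncated lines above is an assumption):

```python
parameters = {"width": 512, "height": 768, "num_inference_steps": 20}
if "width" in parameters and "height" in parameters:
    # Fold flat width/height into the nested object Fal AI expects.
    parameters["image_size"] = {
        "width": parameters.pop("width"),
        "height": parameters.pop("height"),
    }
print(parameters)
# {'num_inference_steps': 20, 'image_size': {'width': 512, 'height': 768}}
```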
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/fireworks_ai.py
@@ -10,5 +10,5 @@ def __init__(self):
     def _prepare_route(self, mapped_model: str) -> str:
         return "/v1/chat/completions"
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
9 changes: 6 additions & 3 deletions src/huggingface_hub/inference/_providers/hf_inference.py
@@ -46,7 +46,7 @@ def _prepare_url(self, api_key: str, mapped_model: str) -> str:
             else f"{self.base_url}/models/{mapped_model}"
         )
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         if isinstance(inputs, bytes):
             raise ValueError(f"Unexpected binary input for task {self.task}.")
         if isinstance(inputs, Path):
@@ -55,7 +55,10 @@ def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
 
 
 class HFInferenceBinaryInputTask(HFInferenceTask):
-    def _prepare_body(
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        return None
+
+    def _prepare_payload_as_bytes(
         self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
     ) -> Optional[bytes]:
         parameters = filter_none({k: v for k, v in parameters.items() if v is not None})
@@ -80,7 +83,7 @@ class HFInferenceConversational(HFInferenceTask):
     def __init__(self):
         super().__init__("text-generation")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         payload_model = "tgi" if mapped_model.startswith(("http://", "https://")) else mapped_model
         return {**filter_none(parameters), "model": payload_model, "messages": inputs}
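When the mapped model is already a full URL there is no model id worth sending, so the conversational task substitutes the `"tgi"` placeholder. A runnable sketch of that payload rule (endpoint URL and messages are examples):

```python
from typing import Any, Dict, List

def conversational_payload(mapped_model: str, messages: List[Dict[str, Any]], **parameters: Any) -> Dict[str, Any]:
    # Same rule as HFInferenceConversational above: URL targets get "tgi".
    payload_model = "tgi" if mapped_model.startswith(("http://", "https://")) else mapped_model
    return {**parameters, "model": payload_model, "messages": messages}

print(conversational_payload("https://my-endpoint.example/v1", [{"role": "user", "content": "Hi"}]))
# {'model': 'tgi', 'messages': [{'role': 'user', 'content': 'Hi'}]}
```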
15 changes: 8 additions & 7 deletions src/huggingface_hub/inference/_providers/new_provider.md
@@ -6,7 +6,7 @@ Before adding a new provider to the `huggingface_hub` library, make sure it has
 
 Create a new file under `src/huggingface_hub/inference/_providers/{provider_name}.py` and copy-paste the following snippet.
 
-Implement the methods that require custom handling. Check out the base implementation to check default behavior. If you don't need to override a method, just remove it. At least one of `_prepare_payload` or `_prepare_body` must be overwritten.
+Implement the methods that require custom handling. Check out the base implementation to check default behavior. If you don't need to override a method, just remove it. At least one of `_prepare_payload_as_dict` or `_prepare_payload_as_bytes` must be overwritten.
 
 If the provider supports multiple tasks that require different implementations, create dedicated subclasses for each task, following the pattern shown in `fal_ai.py`.
 
@@ -42,23 +42,24 @@ class MyNewProviderTaskProviderHelper(TaskProviderHelper):
         """
         return super()._prepare_route(mapped_model)
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         """Return the payload to use for the request, as a dict.
         Override this method in subclasses for customized payloads.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
         """
-        return super()._prepare_payload(inputs, parameters, mapped_model)
+        return super()._prepare_payload_as_dict(inputs, parameters, mapped_model)
 
-    def _prepare_body(
+    def _prepare_payload_as_bytes(
         self, inputs: Any, parameters: Dict, mapped_model: str, extra_payload: Optional[Dict]
     ) -> Optional[bytes]:
         """Return the body to use for the request, as bytes.
         Override this method in subclasses for customized body data.
-        Only one of `_prepare_payload` and `_prepare_body` should return a value.
+        Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value.
         """
-        return super()._prepare_body(inputs, parameters, mapped_model, extra_payload)
+        return super()._prepare_payload_as_bytes(inputs, parameters, mapped_model, extra_payload)
 
 ```
 
 ### 2. Register the provider helper in `__init__.py`
6 changes: 3 additions & 3 deletions src/huggingface_hub/inference/_providers/replicate.py
@@ -19,7 +19,7 @@ def _prepare_route(self, mapped_model: str) -> str:
             return "/v1/predictions"
         return f"/v1/models/{mapped_model}/predictions"
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}}
         if ":" in mapped_model:
             version = mapped_model.split(":", 1)[1]
@@ -43,7 +43,7 @@ class ReplicateTextToSpeechTask(ReplicateTask):
     def __init__(self):
         super().__init__("text-to-speech")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        payload: Dict = super()._prepare_payload(inputs, parameters, mapped_model)  # type: ignore[assignment]
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+        payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, mapped_model)  # type: ignore[assignment]
         payload["input"]["text"] = payload["input"].pop("prompt")  # rename "prompt" to "text" for TTS
         return payload
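Replicate model ids may carry a version hash after a colon; when one is present, the versionless route is used and the hash travels in its own field. A sketch with a made-up identifier (the `payload["version"] = version` completion is inferred from the truncated hunk above):

```python
mapped_model = "owner/some-model:abc123"  # hypothetical Replicate id
payload = {"input": {"prompt": "a photo of a cat"}}
if ":" in mapped_model:
    # Split off the version hash and send it as its own field.
    payload["version"] = mapped_model.split(":", 1)[1]
print(payload)
# {'input': {'prompt': 'a photo of a cat'}, 'version': 'abc123'}
```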
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/sambanova.py
@@ -10,5 +10,5 @@ def __init__(self):
     def _prepare_route(self, mapped_model: str) -> str:
         return "/v1/chat/completions"
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_providers/together.py
@@ -24,15 +24,15 @@ def _prepare_route(self, mapped_model: str) -> str:
 
 class TogetherTextGenerationTask(TogetherTask):
     # Handle both "text-generation" and "conversational"
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         return {"messages": inputs, **filter_none(parameters), "model": mapped_model}
 
 
 class TogetherTextToImageTask(TogetherTask):
     def __init__(self):
         super().__init__("text-to-image")
 
-    def _prepare_payload(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
+    def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
         parameters = filter_none(parameters)
         if "num_inference_steps" in parameters:
             parameters["steps"] = parameters.pop("num_inference_steps")
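End to end, these helpers are exercised through `InferenceClient`, which is how the VCR tests drive them. A hypothetical usage sketch (model id, prompt, and provider availability are assumptions, and a valid Hugging Face token is required):

```python
from huggingface_hub import InferenceClient

# Routed through TogetherTextGenerationTask above via the "together" provider.
client = InferenceClient(provider="together")
completion = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct",  # example model id
    messages=[{"role": "user", "content": "Say hello."}],
)
print(completion.choices[0].message.content)
```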

Large diffs are not rendered by default.