-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add general and HF util methods (#6200)
* Add general and hf util methods
- Loading branch information
Showing
4 changed files
with
178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import inspect | ||
from typing import Any, Dict, List, Optional | ||
|
||
from huggingface_hub import InferenceClient, HfApi | ||
from huggingface_hub.utils import RepositoryNotFoundError | ||
|
||
|
||
def check_generation_params(kwargs: Dict[str, Any], additional_accepted_params: Optional[List[str]] = None):
    """
    Validate generation parameters against the signature of `InferenceClient.text_generation`.

    :param kwargs: The generation parameters to validate, keyed by parameter name.
    :param additional_accepted_params: Optional extra parameter names to accept on top of the
        ones found in the `InferenceClient.text_generation` signature.
    :raises ValueError: If `kwargs` contains a name that is not accepted.
    """
    # Nothing to validate, so the signature inspection is skipped entirely.
    if not kwargs:
        return

    # Every parameter of text_generation is accepted except the bound instance and the prompt.
    allowed = {
        name for name in inspect.signature(InferenceClient.text_generation).parameters if name not in ("self", "prompt")
    }
    if additional_accepted_params:
        allowed.update(additional_accepted_params)

    unexpected = set(kwargs) - allowed
    if unexpected:
        raise ValueError(f"Unknown text generation parameters: {unexpected}. The valid parameters are: {allowed}.")
|
||
|
||
def check_valid_model(model_id: str, token: Optional[str]) -> None:
    """
    Verify that `model_id` exists on the HuggingFace Hub and is tagged as a text generation model.

    The model metadata is fetched with `HfApi.model_info`, and its `pipeline_tag` must be
    either "text-generation" or "text2text-generation".

    :param model_id: A string representing the HuggingFace model ID.
    :param token: The authentication token forwarded to `HfApi.model_info`; may be None.
    :raises ValueError: If the model is not found or is not a text generation model.
    """
    text_generation_tags = ("text-generation", "text2text-generation")

    hub = HfApi()
    try:
        info = hub.model_info(model_id, token=token)
    except RepositoryNotFoundError as e:
        # Surface the hub-specific error as a ValueError while keeping the original cause.
        raise ValueError(
            f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
        ) from e

    if info.pipeline_tag not in text_generation_tags:
        raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import inspect | ||
import sys | ||
from typing import Optional, Callable | ||
|
||
from haystack.preview import DeserializationError | ||
from haystack.preview.dataclasses import StreamingChunk | ||
|
||
|
||
def serialize_callback_handler(streaming_callback: Callable[[StreamingChunk], None]) -> str:
    """
    Serializes the streaming callback handler to a dotted path string.

    :param streaming_callback: The streaming callback handler function
    :return: `<module name>.<function name>`, or only the function name when
        `inspect.getmodule` cannot determine the defining module
    """
    owner = inspect.getmodule(streaming_callback)
    if owner is None:
        # No module could be resolved; fall back to the bare function name.
        return streaming_callback.__name__
    return f"{owner.__name__}.{streaming_callback.__name__}"
|
||
|
||
def deserialize_callback_handler(callback_name: str) -> Optional[Callable[[StreamingChunk], None]]:
    """
    Deserializes the streaming callback handler.

    `callback_name` is split at its last dot into a module path and a function name.
    The module is taken from `sys.modules` when it is already loaded and imported
    otherwise, so a callback can be restored in a process that has not imported
    its module yet.

    :param callback_name: The full path of the streaming callback handler function
    :return: The streaming callback handler function
    :raises DeserializationError: If the module cannot be located or imported, or if it
        has no attribute with the given function name
    """
    # Function-scope import: only this function needs importlib.
    import importlib

    module_name, _, function_name = callback_name.rpartition(".")

    # An empty or relative module path cannot be resolved to a module.
    if not module_name or module_name.startswith("."):
        raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")

    module = sys.modules.get(module_name)
    if module is None:
        # The module has not been loaded in this process yet; import it on demand.
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            raise DeserializationError(
                f"Could not locate the module of the streaming callback: {module_name}"
            ) from e

    streaming_callback = getattr(module, function_name, None)
    if streaming_callback is None:
        raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
    return streaming_callback
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import pytest | ||
|
||
from haystack.preview.components.generators.hf_utils import check_generation_params | ||
|
||
|
||
@pytest.mark.unit
def test_empty_dictionary():
    """An empty parameter dictionary must pass validation without raising."""
    no_params = {}
    check_generation_params(no_params)
|
||
|
||
@pytest.mark.unit
def test_valid_generation_parameters():
    """Parameters expected to be valid must pass when no additional names are supplied."""
    check_generation_params({"max_new_tokens": 100, "temperature": 0.8}, None)
|
||
|
||
@pytest.mark.unit
def test_invalid_generation_parameters():
    """An unrecognized parameter name must be rejected with a ValueError."""
    unknown = {"invalid_param": "value"}
    with pytest.raises(ValueError):
        check_generation_params(unknown, None)
|
||
|
||
@pytest.mark.unit
def test_additional_accepted_params_empty_list():
    """An empty list of additional names must not cause a raise for this input."""
    check_generation_params({"temperature": 0.8}, [])
|
||
|
||
@pytest.mark.unit
def test_additional_accepted_params_known_parameter():
    """Passing a name expected to be valid already as an additional name must not raise."""
    extra_names = ["max_new_tokens"]
    check_generation_params({"temperature": 0.8}, extra_names)
|
||
|
||
@pytest.mark.unit
def test_additional_accepted_params_unknown_parameter():
    """A non-generation parameter must be accepted once it is listed as an additional name."""
    # strange_param is not a generation parameter, but listing it in
    # additional_accepted_params means check_generation_params must not raise.
    check_generation_params({"strange_param": "value"}, ["strange_param"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import pytest | ||
|
||
from haystack.preview.components.generators.openai.gpt import default_streaming_callback | ||
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler | ||
|
||
|
||
# streaming callback needs to be on module level | ||
def streaming_callback(chunk):
    """No-op streaming callback used as a fixture by the serialization tests in this module."""
|
||
|
||
@pytest.mark.unit
def test_callback_handler_serialization():
    """A callback defined in this test module serializes to `<module>.<function>`."""
    assert serialize_callback_handler(streaming_callback) == "test_utils.streaming_callback"
|
||
|
||
@pytest.mark.unit
def test_callback_handler_serialization_non_local():
    """A callback imported from another module serializes to its full dotted path."""
    expected = "haystack.preview.components.generators.openai.gpt.default_streaming_callback"
    assert serialize_callback_handler(default_streaming_callback) == expected
|
||
|
||
@pytest.mark.unit
def test_callback_handler_deserialization():
    """Serializing then deserializing a local callback returns the very same function object."""
    serialized = serialize_callback_handler(streaming_callback)
    restored = deserialize_callback_handler(serialized)

    assert restored is streaming_callback
|
||
|
||
@pytest.mark.unit
def test_callback_handler_deserialization_non_local():
    """Serializing then deserializing an imported callback returns the very same function object."""
    serialized = serialize_callback_handler(default_streaming_callback)
    restored = deserialize_callback_handler(serialized)

    assert restored is default_streaming_callback