feat: Add general and HF util methods (#6200)
* Add general and hf util methods
vblagoje authored Oct 31, 2023
1 parent 431902b commit c51aa1e
Showing 4 changed files with 178 additions and 0 deletions.
50 changes: 50 additions & 0 deletions haystack/preview/components/generators/hf_utils.py
@@ -0,0 +1,50 @@
import inspect
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient, HfApi
from huggingface_hub.utils import RepositoryNotFoundError


def check_generation_params(kwargs: Dict[str, Any], additional_accepted_params: Optional[List[str]] = None):
    """
    Check the provided generation parameters for validity.
    :param kwargs: A dictionary containing the generation parameters.
    :param additional_accepted_params: An optional list of strings representing additional accepted parameters.
    :raises ValueError: If any unknown text generation parameters are provided.
    """
    if kwargs:
        accepted_params = {
            param
            for param in inspect.signature(InferenceClient.text_generation).parameters.keys()
            if param not in ["self", "prompt"]
        }
        if additional_accepted_params:
            accepted_params.update(additional_accepted_params)
        unknown_params = set(kwargs.keys()) - accepted_params
        if unknown_params:
            raise ValueError(
                f"Unknown text generation parameters: {unknown_params}. The valid parameters are: {accepted_params}."
            )


def check_valid_model(model_id: str, token: Optional[str]) -> None:
    """
    Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
    Also check if the model is a text generation model.
    :param model_id: A string representing the HuggingFace model ID.
    :param token: An optional string representing the authentication token.
    :raises ValueError: If the model is not found or is not a text generation model.
    """
    api = HfApi()
    try:
        model_info = api.model_info(model_id, token=token)
    except RepositoryNotFoundError as e:
        raise ValueError(
            f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
        ) from e

    allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
    if not allowed_model:
        raise ValueError(f"Model {model_id} is not a text generation model. Please provide a text generation model.")
41 changes: 41 additions & 0 deletions haystack/preview/components/generators/utils.py
@@ -0,0 +1,41 @@
import inspect
import sys
from typing import Optional, Callable

from haystack.preview import DeserializationError
from haystack.preview.dataclasses import StreamingChunk


def serialize_callback_handler(streaming_callback: Callable[[StreamingChunk], None]) -> str:
    """
    Serializes the streaming callback handler.
    :param streaming_callback: The streaming callback handler function
    :return: The full path of the streaming callback handler function
    """
    module = inspect.getmodule(streaming_callback)

    # Get the full package path of the function
    if module is not None:
        full_path = f"{module.__name__}.{streaming_callback.__name__}"
    else:
        full_path = streaming_callback.__name__
    return full_path


def deserialize_callback_handler(callback_name: str) -> Optional[Callable[[StreamingChunk], None]]:
    """
    Deserializes the streaming callback handler.
    :param callback_name: The full path of the streaming callback handler function
    :return: The streaming callback handler function
    :raises DeserializationError: If the streaming callback handler function cannot be found
    """
    parts = callback_name.split(".")
    module_name = ".".join(parts[:-1])
    function_name = parts[-1]
    module = sys.modules.get(module_name, None)
    if not module:
        raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
    streaming_callback = getattr(module, function_name, None)
    if not streaming_callback:
        raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
    return streaming_callback
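A round-trip sketch, assuming the callback is defined at module level so that deserialize_callback_handler can find its module in sys.modules:

    from haystack.preview.dataclasses import StreamingChunk
    from haystack.preview.components.generators.utils import (
        serialize_callback_handler,
        deserialize_callback_handler,
    )

    def print_chunk(chunk: StreamingChunk) -> None:
        # Hypothetical callback; prints each streamed chunk's text as it arrives.
        print(chunk.content, end="", flush=True)

    name = serialize_callback_handler(print_chunk)  # e.g. "__main__.print_chunk"
    restored = deserialize_callback_handler(name)
    assert restored is print_chunk

Note that deserialization looks the module up in sys.modules rather than importing it, so the callback's module must already be imported when deserialize_callback_handler runs.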
50 changes: 50 additions & 0 deletions test/preview/components/generators/test_hf_utils.py
@@ -0,0 +1,50 @@
import pytest

from haystack.preview.components.generators.hf_utils import check_generation_params


@pytest.mark.unit
def test_empty_dictionary():
    # no exception raised
    check_generation_params({})


@pytest.mark.unit
def test_valid_generation_parameters():
    # these are valid parameters
    kwargs = {"max_new_tokens": 100, "temperature": 0.8}
    additional_accepted_params = None
    check_generation_params(kwargs, additional_accepted_params)


@pytest.mark.unit
def test_invalid_generation_parameters():
    # these are invalid parameters
    kwargs = {"invalid_param": "value"}
    additional_accepted_params = None
    with pytest.raises(ValueError):
        check_generation_params(kwargs, additional_accepted_params)


@pytest.mark.unit
def test_additional_accepted_params_empty_list():
    kwargs = {"temperature": 0.8}
    additional_accepted_params = []
    check_generation_params(kwargs, additional_accepted_params)


@pytest.mark.unit
def test_additional_accepted_params_known_parameter():
    # both are valid parameters
    kwargs = {"temperature": 0.8}
    additional_accepted_params = ["max_new_tokens"]
    check_generation_params(kwargs, additional_accepted_params)


@pytest.mark.unit
def test_additional_accepted_params_unknown_parameter():
    kwargs = {"strange_param": "value"}
    additional_accepted_params = ["strange_param"]
    # Although strange_param is not a generation param, check_generation_params
    # does not raise an exception because it is passed via additional_accepted_params
    check_generation_params(kwargs, additional_accepted_params)
37 changes: 37 additions & 0 deletions test/preview/components/generators/test_utils.py
@@ -0,0 +1,37 @@
import pytest

from haystack.preview.components.generators.openai.gpt import default_streaming_callback
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler


# streaming callback needs to be on module level
def streaming_callback(chunk):
    pass


@pytest.mark.unit
def test_callback_handler_serialization():
    result = serialize_callback_handler(streaming_callback)
    assert result == "test_utils.streaming_callback"


@pytest.mark.unit
def test_callback_handler_serialization_non_local():
    result = serialize_callback_handler(default_streaming_callback)
    assert result == "haystack.preview.components.generators.openai.gpt.default_streaming_callback"


@pytest.mark.unit
def test_callback_handler_deserialization():
    result = serialize_callback_handler(streaming_callback)
    fn = deserialize_callback_handler(result)

    assert fn is streaming_callback


@pytest.mark.unit
def test_callback_handler_deserialization_non_local():
    result = serialize_callback_handler(default_streaming_callback)
    fn = deserialize_callback_handler(result)

    assert fn is default_streaming_callback
