diff --git a/autogen/oai/client_utils.py b/autogen/oai/client_utils.py new file mode 100644 index 00000000000..4e3870878b4 --- /dev/null +++ b/autogen/oai/client_utils.py @@ -0,0 +1,99 @@ +"""Utilities for client classes""" + +import warnings +from typing import Any, Dict, List, Optional, Tuple + + +def validate_parameter( + params: Dict[str, Any], + param_name: str, + allowed_types: Tuple, + allow_None: bool, + default_value: Any, + numerical_bound: Tuple, + allowed_values: list, +) -> Any: + """ + Validates a given config parameter, checking its type, values, and setting defaults + Parameters: + params (Dict[str, Any]): Dictionary containing parameters to validate. + param_name (str): The name of the parameter to validate. + allowed_types (Tuple): Tuple of acceptable types for the parameter. + allow_None (bool): Whether the parameter can be `None`. + default_value (Any): The default value to use if the parameter is invalid or missing. + numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]): + A tuple specifying the lower and upper bounds for numerical parameters. + Each bound can be `None` if not applicable. + allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter. + Can be `None` if no specific values are required. + + Returns: + Any: The validated parameter value or the default value if validation fails. + + Raises: + TypeError: If `allowed_values` is provided but is not a list. + + Example Usage: + ```python + # Validating a numerical parameter within specific bounds + params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"} + temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None) + # Result: 0.5 + + # Validating a parameter that can be one of a list of allowed values + model = validate_parameter( + params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"] + ) + # If "safety_model" is missing or invalid in params, defaults to "default" + ``` + """ + + if allowed_values is not None and not isinstance(allowed_values, list): + raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}") + + param_value = params.get(param_name, default_value) + warning = "" + + if param_value is None and allow_None: + pass + elif param_value is None: + if not allow_None: + warning = "cannot be None" + elif not isinstance(param_value, allowed_types): + # Check types and list possible types if invalid + if isinstance(allowed_types, tuple): + formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")" + else: + formatted_types = f"{allowed_types.__name__}" + warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}" + elif numerical_bound: + # Check the value fits in possible bounds + lower_bound, upper_bound = numerical_bound + if (lower_bound is not None and param_value < lower_bound) or ( + upper_bound is not None and param_value > upper_bound + ): + warning = "has numerical bounds" + if lower_bound is not None: + warning += f", >= {str(lower_bound)}" + if upper_bound is not None: + if lower_bound is not None: + warning += " and" + warning += f" <= {str(upper_bound)}" + if allow_None: + warning += ", or can be None" + + elif allowed_values: + # Check if the value matches any allowed values + if not (allow_None and param_value is None): + if param_value not in allowed_values: + warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}" + + # If we failed any checks, warn and set to default value + if warning: + warnings.warn( + f"Config error - {param_name} {warning}, defaulting to {default_value}.", + UserWarning, + ) + param_value = default_value + + return param_value diff --git a/test/oai/test_client_utils.py b/test/oai/test_client_utils.py new file mode 100644 index 00000000000..9a060a0f2dc --- /dev/null +++ b/test/oai/test_client_utils.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 -m pytest + +import pytest + +import autogen +from autogen.oai.client_utils import validate_parameter + + +def test_validate_parameter(): + # Test valid parameters + params = { + "model": "Qwen/Qwen2-72B-Instruct", + "max_tokens": 1000, + "stream": False, + "temperature": 1, + "top_p": 0.8, + "top_k": 50, + "repetition_penalty": 0.5, + "presence_penalty": 1.5, + "frequency_penalty": 1.5, + "min_p": 0.2, + "safety_model": "Meta-Llama/Llama-Guard-7b", + } + + # Should return the original value as they are valid + assert params["model"] == validate_parameter(params, "model", str, False, None, None, None) + assert params["max_tokens"] == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) + assert params["stream"] == validate_parameter(params, "stream", bool, False, False, None, None) + assert params["temperature"] == validate_parameter(params, "temperature", (int, float), True, None, None, None) + assert params["top_k"] == validate_parameter(params, "top_k", int, True, None, None, None) + assert params["repetition_penalty"] == validate_parameter( + params, "repetition_penalty", float, True, None, None, None + ) + assert params["presence_penalty"] == validate_parameter( + params, "presence_penalty", (int, float), True, None, (-2, 2), None + ) + assert params["safety_model"] == validate_parameter(params, "safety_model", str, True, None, None, None) + + # Test None allowed + params = { + "max_tokens": None, + } + + # Should remain None + assert validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) is None + + # Test not None allowed + params = { + "max_tokens": None, + } + + # Should return default + assert 512 == validate_parameter(params, "max_tokens", int, False, 512, (0, None), None) + + # Test invalid parameters + params = { + "stream": "Yes", + "temperature": "0.5", + "top_p": "0.8", + "top_k": "50", + "repetition_penalty": "0.5", + "presence_penalty": "1.5", + "frequency_penalty": "1.5", + "min_p": "0.2", + "safety_model": False, + } + + # Should all be set to defaults + assert validate_parameter(params, "stream", bool, False, False, None, None) is not None + assert validate_parameter(params, "temperature", (int, float), True, None, None, None) is None + assert validate_parameter(params, "top_p", (int, float), True, None, None, None) is None + assert validate_parameter(params, "top_k", int, True, None, None, None) is None + assert validate_parameter(params, "repetition_penalty", float, True, None, None, None) is None + assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None + assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None + assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None + assert validate_parameter(params, "safety_model", str, True, None, None, None) is None + + # Test parameters outside of bounds + params = { + "max_tokens": -200, + "presence_penalty": -5, + "frequency_penalty": 5, + "min_p": -0.5, + } + + # Should all be set to defaults + assert 512 == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) + assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None + assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None + assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None + + # Test valid list options + params = { + "safety_model": "Meta-Llama/Llama-Guard-7b", + } + + # Should all be set to defaults + assert "Meta-Llama/Llama-Guard-7b" == validate_parameter( + params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"] + ) + + # Test invalid list options + params = { + "stream": True, + } + + # Should all be set to defaults + assert not validate_parameter(params, "stream", bool, False, False, None, [False]) + + # test invalid type + params = { + "temperature": None, + } + + # should be set to defaults + assert validate_parameter(params, "temperature", (int, float), False, 0.7, (0.0, 1.0), None) == 0.7 + + # test value out of bounds + params = { + "temperature": 23, + } + + # should be set to defaults + assert validate_parameter(params, "temperature", (int, float), False, 1.0, (0.0, 1.0), None) == 1.0 + + # type error for the parameters + with pytest.raises(TypeError): + validate_parameter({}, "param", str, True, None, None, "not_a_list") + + # passing empty params, which will set to defaults + assert validate_parameter({}, "max_tokens", int, True, 512, (0, None), None) == 512 + + +if __name__ == "__main__": + test_validate_parameter()