-
Notifications
You must be signed in to change notification settings - Fork 5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Addition of client utilities, initially for parameter validation * Corrected test * update: type checks and few tests * fix: docs, tests --------- Co-authored-by: Hk669 <hrushi669@gmail.com>
- Loading branch information
1 parent
51971fd
commit d44edb6
Showing
2 changed files
with
235 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
"""Utilities for client classes""" | ||
|
||
import warnings | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
|
||
def validate_parameter( | ||
params: Dict[str, Any], | ||
param_name: str, | ||
allowed_types: Tuple, | ||
allow_None: bool, | ||
default_value: Any, | ||
numerical_bound: Tuple, | ||
allowed_values: list, | ||
) -> Any: | ||
""" | ||
Validates a given config parameter, checking its type, values, and setting defaults | ||
Parameters: | ||
params (Dict[str, Any]): Dictionary containing parameters to validate. | ||
param_name (str): The name of the parameter to validate. | ||
allowed_types (Tuple): Tuple of acceptable types for the parameter. | ||
allow_None (bool): Whether the parameter can be `None`. | ||
default_value (Any): The default value to use if the parameter is invalid or missing. | ||
numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]): | ||
A tuple specifying the lower and upper bounds for numerical parameters. | ||
Each bound can be `None` if not applicable. | ||
allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter. | ||
Can be `None` if no specific values are required. | ||
Returns: | ||
Any: The validated parameter value or the default value if validation fails. | ||
Raises: | ||
TypeError: If `allowed_values` is provided but is not a list. | ||
Example Usage: | ||
```python | ||
# Validating a numerical parameter within specific bounds | ||
params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"} | ||
temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None) | ||
# Result: 0.5 | ||
# Validating a parameter that can be one of a list of allowed values | ||
model = validate_parameter( | ||
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"] | ||
) | ||
# If "safety_model" is missing or invalid in params, defaults to "default" | ||
``` | ||
""" | ||
|
||
if allowed_values is not None and not isinstance(allowed_values, list): | ||
raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}") | ||
|
||
param_value = params.get(param_name, default_value) | ||
warning = "" | ||
|
||
if param_value is None and allow_None: | ||
pass | ||
elif param_value is None: | ||
if not allow_None: | ||
warning = "cannot be None" | ||
elif not isinstance(param_value, allowed_types): | ||
# Check types and list possible types if invalid | ||
if isinstance(allowed_types, tuple): | ||
formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")" | ||
else: | ||
formatted_types = f"{allowed_types.__name__}" | ||
warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}" | ||
elif numerical_bound: | ||
# Check the value fits in possible bounds | ||
lower_bound, upper_bound = numerical_bound | ||
if (lower_bound is not None and param_value < lower_bound) or ( | ||
upper_bound is not None and param_value > upper_bound | ||
): | ||
warning = "has numerical bounds" | ||
if lower_bound is not None: | ||
warning += f", >= {str(lower_bound)}" | ||
if upper_bound is not None: | ||
if lower_bound is not None: | ||
warning += " and" | ||
warning += f" <= {str(upper_bound)}" | ||
if allow_None: | ||
warning += ", or can be None" | ||
|
||
elif allowed_values: | ||
# Check if the value matches any allowed values | ||
if not (allow_None and param_value is None): | ||
if param_value not in allowed_values: | ||
warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}" | ||
|
||
# If we failed any checks, warn and set to default value | ||
if warning: | ||
warnings.warn( | ||
f"Config error - {param_name} {warning}, defaulting to {default_value}.", | ||
UserWarning, | ||
) | ||
param_value = default_value | ||
|
||
return param_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
#!/usr/bin/env python3 -m pytest | ||
|
||
import pytest | ||
|
||
import autogen | ||
from autogen.oai.client_utils import validate_parameter | ||
|
||
|
||
def test_validate_parameter(): | ||
# Test valid parameters | ||
params = { | ||
"model": "Qwen/Qwen2-72B-Instruct", | ||
"max_tokens": 1000, | ||
"stream": False, | ||
"temperature": 1, | ||
"top_p": 0.8, | ||
"top_k": 50, | ||
"repetition_penalty": 0.5, | ||
"presence_penalty": 1.5, | ||
"frequency_penalty": 1.5, | ||
"min_p": 0.2, | ||
"safety_model": "Meta-Llama/Llama-Guard-7b", | ||
} | ||
|
||
# Should return the original value as they are valid | ||
assert params["model"] == validate_parameter(params, "model", str, False, None, None, None) | ||
assert params["max_tokens"] == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) | ||
assert params["stream"] == validate_parameter(params, "stream", bool, False, False, None, None) | ||
assert params["temperature"] == validate_parameter(params, "temperature", (int, float), True, None, None, None) | ||
assert params["top_k"] == validate_parameter(params, "top_k", int, True, None, None, None) | ||
assert params["repetition_penalty"] == validate_parameter( | ||
params, "repetition_penalty", float, True, None, None, None | ||
) | ||
assert params["presence_penalty"] == validate_parameter( | ||
params, "presence_penalty", (int, float), True, None, (-2, 2), None | ||
) | ||
assert params["safety_model"] == validate_parameter(params, "safety_model", str, True, None, None, None) | ||
|
||
# Test None allowed | ||
params = { | ||
"max_tokens": None, | ||
} | ||
|
||
# Should remain None | ||
assert validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) is None | ||
|
||
# Test not None allowed | ||
params = { | ||
"max_tokens": None, | ||
} | ||
|
||
# Should return default | ||
assert 512 == validate_parameter(params, "max_tokens", int, False, 512, (0, None), None) | ||
|
||
# Test invalid parameters | ||
params = { | ||
"stream": "Yes", | ||
"temperature": "0.5", | ||
"top_p": "0.8", | ||
"top_k": "50", | ||
"repetition_penalty": "0.5", | ||
"presence_penalty": "1.5", | ||
"frequency_penalty": "1.5", | ||
"min_p": "0.2", | ||
"safety_model": False, | ||
} | ||
|
||
# Should all be set to defaults | ||
assert validate_parameter(params, "stream", bool, False, False, None, None) is not None | ||
assert validate_parameter(params, "temperature", (int, float), True, None, None, None) is None | ||
assert validate_parameter(params, "top_p", (int, float), True, None, None, None) is None | ||
assert validate_parameter(params, "top_k", int, True, None, None, None) is None | ||
assert validate_parameter(params, "repetition_penalty", float, True, None, None, None) is None | ||
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None | ||
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None | ||
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None | ||
assert validate_parameter(params, "safety_model", str, True, None, None, None) is None | ||
|
||
# Test parameters outside of bounds | ||
params = { | ||
"max_tokens": -200, | ||
"presence_penalty": -5, | ||
"frequency_penalty": 5, | ||
"min_p": -0.5, | ||
} | ||
|
||
# Should all be set to defaults | ||
assert 512 == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) | ||
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None | ||
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None | ||
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None | ||
|
||
# Test valid list options | ||
params = { | ||
"safety_model": "Meta-Llama/Llama-Guard-7b", | ||
} | ||
|
||
# Should all be set to defaults | ||
assert "Meta-Llama/Llama-Guard-7b" == validate_parameter( | ||
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"] | ||
) | ||
|
||
# Test invalid list options | ||
params = { | ||
"stream": True, | ||
} | ||
|
||
# Should all be set to defaults | ||
assert not validate_parameter(params, "stream", bool, False, False, None, [False]) | ||
|
||
# test invalid type | ||
params = { | ||
"temperature": None, | ||
} | ||
|
||
# should be set to defaults | ||
assert validate_parameter(params, "temperature", (int, float), False, 0.7, (0.0, 1.0), None) == 0.7 | ||
|
||
# test value out of bounds | ||
params = { | ||
"temperature": 23, | ||
} | ||
|
||
# should be set to defaults | ||
assert validate_parameter(params, "temperature", (int, float), False, 1.0, (0.0, 1.0), None) == 1.0 | ||
|
||
# type error for the parameters | ||
with pytest.raises(TypeError): | ||
validate_parameter({}, "param", str, True, None, None, "not_a_list") | ||
|
||
# passing empty params, which will set to defaults | ||
assert validate_parameter({}, "max_tokens", int, True, 512, (0, None), None) == 512 | ||
|
||
|
||
if __name__ == "__main__": | ||
test_validate_parameter() |