Skip to content

Commit

Permalink
Client class utilities (#2949)
Browse files Browse the repository at this point in the history
* Addition of client utilities, initially for parameter validation

* Corrected test

* update: type checks and few tests

* fix: docs, tests

---------

Co-authored-by: Hk669 <hrushi669@gmail.com>
  • Loading branch information
2 people authored and victordibia committed Jul 30, 2024
1 parent 51971fd commit d44edb6
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 0 deletions.
99 changes: 99 additions & 0 deletions autogen/oai/client_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Utilities for client classes"""

import warnings
from typing import Any, Dict, List, Optional, Tuple


def validate_parameter(
params: Dict[str, Any],
param_name: str,
allowed_types: Tuple,
allow_None: bool,
default_value: Any,
numerical_bound: Tuple,
allowed_values: list,
) -> Any:
"""
Validates a given config parameter, checking its type, values, and setting defaults
Parameters:
params (Dict[str, Any]): Dictionary containing parameters to validate.
param_name (str): The name of the parameter to validate.
allowed_types (Tuple): Tuple of acceptable types for the parameter.
allow_None (bool): Whether the parameter can be `None`.
default_value (Any): The default value to use if the parameter is invalid or missing.
numerical_bound (Optional[Tuple[Optional[float], Optional[float]]]):
A tuple specifying the lower and upper bounds for numerical parameters.
Each bound can be `None` if not applicable.
allowed_values (Optional[List[Any]]): A list of acceptable values for the parameter.
Can be `None` if no specific values are required.
Returns:
Any: The validated parameter value or the default value if validation fails.
Raises:
TypeError: If `allowed_values` is provided but is not a list.
Example Usage:
```python
# Validating a numerical parameter within specific bounds
params = {"temperature": 0.5, "safety_model": "Meta-Llama/Llama-Guard-7b"}
temperature = validate_parameter(params, "temperature", (int, float), True, 0.7, (0, 1), None)
# Result: 0.5
# Validating a parameter that can be one of a list of allowed values
model = validate_parameter(
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
)
# If "safety_model" is missing or invalid in params, defaults to "default"
```
"""

if allowed_values is not None and not isinstance(allowed_values, list):
raise TypeError(f"allowed_values should be a list or None, got {type(allowed_values).__name__}")

param_value = params.get(param_name, default_value)
warning = ""

if param_value is None and allow_None:
pass
elif param_value is None:
if not allow_None:
warning = "cannot be None"
elif not isinstance(param_value, allowed_types):
# Check types and list possible types if invalid
if isinstance(allowed_types, tuple):
formatted_types = "(" + ", ".join(f"{t.__name__}" for t in allowed_types) + ")"
else:
formatted_types = f"{allowed_types.__name__}"
warning = f"must be of type {formatted_types}{' or None' if allow_None else ''}"
elif numerical_bound:
# Check the value fits in possible bounds
lower_bound, upper_bound = numerical_bound
if (lower_bound is not None and param_value < lower_bound) or (
upper_bound is not None and param_value > upper_bound
):
warning = "has numerical bounds"
if lower_bound is not None:
warning += f", >= {str(lower_bound)}"
if upper_bound is not None:
if lower_bound is not None:
warning += " and"
warning += f" <= {str(upper_bound)}"
if allow_None:
warning += ", or can be None"

elif allowed_values:
# Check if the value matches any allowed values
if not (allow_None and param_value is None):
if param_value not in allowed_values:
warning = f"must be one of these values [{allowed_values}]{', or can be None' if allow_None else ''}"

# If we failed any checks, warn and set to default value
if warning:
warnings.warn(
f"Config error - {param_name} {warning}, defaulting to {default_value}.",
UserWarning,
)
param_value = default_value

return param_value
136 changes: 136 additions & 0 deletions test/oai/test_client_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3 -m pytest

import pytest

import autogen
from autogen.oai.client_utils import validate_parameter


def test_validate_parameter():
# Test valid parameters
params = {
"model": "Qwen/Qwen2-72B-Instruct",
"max_tokens": 1000,
"stream": False,
"temperature": 1,
"top_p": 0.8,
"top_k": 50,
"repetition_penalty": 0.5,
"presence_penalty": 1.5,
"frequency_penalty": 1.5,
"min_p": 0.2,
"safety_model": "Meta-Llama/Llama-Guard-7b",
}

# Should return the original value as they are valid
assert params["model"] == validate_parameter(params, "model", str, False, None, None, None)
assert params["max_tokens"] == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
assert params["stream"] == validate_parameter(params, "stream", bool, False, False, None, None)
assert params["temperature"] == validate_parameter(params, "temperature", (int, float), True, None, None, None)
assert params["top_k"] == validate_parameter(params, "top_k", int, True, None, None, None)
assert params["repetition_penalty"] == validate_parameter(
params, "repetition_penalty", float, True, None, None, None
)
assert params["presence_penalty"] == validate_parameter(
params, "presence_penalty", (int, float), True, None, (-2, 2), None
)
assert params["safety_model"] == validate_parameter(params, "safety_model", str, True, None, None, None)

# Test None allowed
params = {
"max_tokens": None,
}

# Should remain None
assert validate_parameter(params, "max_tokens", int, True, 512, (0, None), None) is None

# Test not None allowed
params = {
"max_tokens": None,
}

# Should return default
assert 512 == validate_parameter(params, "max_tokens", int, False, 512, (0, None), None)

# Test invalid parameters
params = {
"stream": "Yes",
"temperature": "0.5",
"top_p": "0.8",
"top_k": "50",
"repetition_penalty": "0.5",
"presence_penalty": "1.5",
"frequency_penalty": "1.5",
"min_p": "0.2",
"safety_model": False,
}

# Should all be set to defaults
assert validate_parameter(params, "stream", bool, False, False, None, None) is not None
assert validate_parameter(params, "temperature", (int, float), True, None, None, None) is None
assert validate_parameter(params, "top_p", (int, float), True, None, None, None) is None
assert validate_parameter(params, "top_k", int, True, None, None, None) is None
assert validate_parameter(params, "repetition_penalty", float, True, None, None, None) is None
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None
assert validate_parameter(params, "safety_model", str, True, None, None, None) is None

# Test parameters outside of bounds
params = {
"max_tokens": -200,
"presence_penalty": -5,
"frequency_penalty": 5,
"min_p": -0.5,
}

# Should all be set to defaults
assert 512 == validate_parameter(params, "max_tokens", int, True, 512, (0, None), None)
assert validate_parameter(params, "presence_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "frequency_penalty", (int, float), True, None, (-2, 2), None) is None
assert validate_parameter(params, "min_p", (int, float), True, None, (0, 1), None) is None

# Test valid list options
params = {
"safety_model": "Meta-Llama/Llama-Guard-7b",
}

# Should all be set to defaults
assert "Meta-Llama/Llama-Guard-7b" == validate_parameter(
params, "safety_model", str, True, None, None, ["Meta-Llama/Llama-Guard-7b", "Meta-Llama/Llama-Guard-13b"]
)

# Test invalid list options
params = {
"stream": True,
}

# Should all be set to defaults
assert not validate_parameter(params, "stream", bool, False, False, None, [False])

# test invalid type
params = {
"temperature": None,
}

# should be set to defaults
assert validate_parameter(params, "temperature", (int, float), False, 0.7, (0.0, 1.0), None) == 0.7

# test value out of bounds
params = {
"temperature": 23,
}

# should be set to defaults
assert validate_parameter(params, "temperature", (int, float), False, 1.0, (0.0, 1.0), None) == 1.0

# type error for the parameters
with pytest.raises(TypeError):
validate_parameter({}, "param", str, True, None, None, "not_a_list")

# passing empty params, which will set to defaults
assert validate_parameter({}, "max_tokens", int, True, 512, (0, None), None) == 512


if __name__ == "__main__":
test_validate_parameter()

0 comments on commit d44edb6

Please sign in to comment.