Log applied guardrails on LLM API call (#8452)
* fix(litellm_logging.py): support saving applied guardrails in logging object

allows the list of applied guardrails to be logged for the proxy admin's knowledge

* feat(spend_tracking_utils.py): log applied guardrails to spend logs

makes it easy for the admin to know which guardrails were applied to a request

* ci(config.yml): uninstall posthog from ci/cd

* test: fix tests

* test: update test
krrishdholakia authored Feb 11, 2025
1 parent 8e32713 commit ce3ead6
Showing 8 changed files with 152 additions and 10 deletions.
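To see where the new data ends up, here is a minimal sketch (illustrative only; the row variable and the guardrail name are hypothetical) of reading the field from a spend-log record after this change. The metadata column is stored as a JSON string, and the updated test fixture further below shows the same shape.

import json

# Hypothetical spend-log row, shaped like the updated test fixture in this commit.
spend_log_row = {
    "metadata": '{"applied_guardrails": ["pii_masking"], "additional_usage_values": {}}'
}

metadata = json.loads(spend_log_row["metadata"])
print(metadata["applied_guardrails"])  # -> ['pii_masking']; [] when no guardrail ran
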
7 changes: 7 additions & 0 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -199,6 +199,7 @@ def __init__(
dynamic_async_failure_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None,
applied_guardrails: Optional[List[str]] = None,
kwargs: Optional[Dict] = None,
):
_input: Optional[str] = messages # save original value of messages
@@ -271,6 +272,7 @@ def __init__(
"litellm_call_id": litellm_call_id,
"input": _input,
"litellm_params": litellm_params,
"applied_guardrails": applied_guardrails,
}

def process_dynamic_callbacks(self):
@@ -2852,6 +2854,7 @@ def get_standard_logging_metadata(
metadata: Optional[Dict[str, Any]],
litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2866,6 +2869,7 @@
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
"""

prompt_management_metadata: Optional[
StandardLoggingPromptManagementMetadata
] = None
@@ -2895,6 +2899,7 @@ def get_standard_logging_metadata(
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@@ -3193,6 +3198,7 @@ def get_standard_logging_object_payload(
metadata=metadata,
litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
)

_request_body = proxy_server_request.get("body", {})
@@ -3328,6 +3334,7 @@ def get_standard_logging_metadata(
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=None,
applied_guardrails=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
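As a hedged illustration of how a proxy admin could consume the new field from a callback (this assumes the standard logging payload is exposed to callbacks via the standard_logging_object kwarg, as in litellm's other logging integrations; the logger class name is made up and not part of this diff):

from litellm.integrations.custom_logger import CustomLogger

class GuardrailAuditLogger(CustomLogger):  # hypothetical logger
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        metadata = payload.get("metadata") or {}
        # New with this commit: list of guardrails that ran on the request.
        print("applied_guardrails:", metadata.get("applied_guardrails"))
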
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
@@ -1794,6 +1794,7 @@ class SpendLogsMetadata(TypedDict):
dict
] # special param to log k,v pairs to spendlogs for a call
requester_ip_address: Optional[str]
applied_guardrails: Optional[List[str]]


class SpendLogsPayload(TypedDict):
20 changes: 16 additions & 4 deletions litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -3,7 +3,7 @@
from datetime import datetime
from datetime import datetime as dt
from datetime import timezone
from typing import Optional, cast
from typing import List, Optional, cast

from pydantic import BaseModel

@@ -32,7 +32,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
return False


def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
def _get_spend_logs_metadata(
metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
) -> SpendLogsMetadata:
if metadata is None:
return SpendLogsMetadata(
user_api_key=None,
@@ -44,8 +46,9 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
applied_guardrails=None,
)
verbose_proxy_logger.debug(
verbose_proxy_logger.info(
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
@@ -58,6 +61,8 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
if key in metadata
}
)
clean_metadata["applied_guardrails"] = applied_guardrails

return clean_metadata


@@ -130,7 +135,14 @@ def get_logging_payload( # noqa: PLR0915
_model_group = metadata.get("model_group", "")

# clean up litellm metadata
clean_metadata = _get_spend_logs_metadata(metadata)
clean_metadata = _get_spend_logs_metadata(
metadata,
applied_guardrails=(
standard_logging_payload["metadata"].get("applied_guardrails", None)
if standard_logging_payload is not None
else None
),
)

special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {}
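A minimal usage sketch of the new keyword argument (the metadata values and the guardrail name are invented; assumes the litellm repo is importable):

from litellm.proxy.spend_tracking.spend_tracking_utils import _get_spend_logs_metadata

clean_metadata = _get_spend_logs_metadata(
    metadata={"user_api_key": "example-hashed-key"},  # hypothetical value
    applied_guardrails=["content_filter"],            # hypothetical guardrail name
)
assert clean_metadata["applied_guardrails"] == ["content_filter"]
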
1 change: 1 addition & 0 deletions litellm/types/utils.py
@@ -1525,6 +1525,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
requester_ip_address: Optional[str]
requester_metadata: Optional[dict]
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
applied_guardrails: Optional[List[str]]


class StandardLoggingAdditionalHeaders(TypedDict, total=False):
34 changes: 34 additions & 0 deletions litellm/utils.py
@@ -60,6 +60,7 @@
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
map_finish_reason,
@@ -418,6 +419,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
)


def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
Get the request guardrails from the kwargs
"""
metadata = kwargs.get("metadata") or {}
requester_metadata = metadata.get("requester_metadata") or {}
applied_guardrails = requester_metadata.get("guardrails") or []
return applied_guardrails


def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
- Add 'default_on' guardrails to the list
- Add request guardrails to the list
"""

request_guardrails = get_request_guardrails(kwargs)
applied_guardrails = []
for callback in litellm.callbacks:
if callback is not None and isinstance(callback, CustomGuardrail):
if callback.guardrail_name is not None:
if callback.default_on is True:
applied_guardrails.append(callback.guardrail_name)
elif callback.guardrail_name in request_guardrails:
applied_guardrails.append(callback.guardrail_name)

return applied_guardrails


def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -436,6 +466,9 @@ def function_setup( # noqa: PLR0915
## CUSTOM LLM SETUP ##
custom_llm_setup()

## GET APPLIED GUARDRAILS
applied_guardrails = get_applied_guardrails(kwargs)

## LOGGING SETUP
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

@@ -677,6 +710,7 @@ def function_setup( # noqa: PLR0915
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
kwargs=kwargs,
applied_guardrails=applied_guardrails,
)

## check if metadata is passed in
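A short sketch of how the new helper resolves guardrails (guardrail names here are invented; it mirrors the parametrized test added below): guardrails registered with default_on=True always apply, while opt-in guardrails apply only when named in the request's requester_metadata.

import litellm
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails

litellm.callbacks = [
    CustomGuardrail(guardrail_name="pii_masking", default_on=True),        # always applied
    CustomGuardrail(guardrail_name="prompt_injection", default_on=False),  # applied only when requested
]

kwargs = {"metadata": {"requester_metadata": {"guardrails": ["prompt_injection"]}}}
print(get_applied_guardrails(kwargs))  # -> ['pii_masking', 'prompt_injection']
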
96 changes: 91 additions & 5 deletions tests/litellm_utils_tests/test_utils.py
@@ -864,17 +864,24 @@ def test_convert_model_response_object():
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
)


@pytest.mark.parametrize(
"content, expected_reasoning, expected_content",
"content, expected_reasoning, expected_content",
[
(None, None, None),
("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
(
"<think>I am thinking here</think>The sky is a canvas of blue",
"I am thinking here",
"The sky is a canvas of blue",
),
("I am a regular response", None, "I am a regular response"),
]
],
)
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
assert litellm.utils._parse_content_for_reasoning(content) == (
expected_reasoning,
expected_content,
)


@pytest.mark.parametrize(
@@ -1874,3 +1881,82 @@ def test_validate_user_messages_invalid_content_type():

assert "Invalid message" in str(e)
print(e)


from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails
from unittest.mock import Mock


@pytest.mark.parametrize(
"test_case",
[
{
"name": "default_on_guardrail",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=True)
],
"kwargs": {"metadata": {"requester_metadata": {"guardrails": []}}},
"expected": ["test_guardrail"],
},
{
"name": "request_specific_guardrail",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
],
"kwargs": {
"metadata": {"requester_metadata": {"guardrails": ["test_guardrail"]}}
},
"expected": ["test_guardrail"],
},
{
"name": "multiple_guardrails",
"callbacks": [
CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
],
"kwargs": {
"metadata": {
"requester_metadata": {"guardrails": ["request_guardrail"]}
}
},
"expected": ["default_guardrail", "request_guardrail"],
},
{
"name": "empty_metadata",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
],
"kwargs": {},
"expected": [],
},
{
"name": "none_callback",
"callbacks": [
None,
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
],
"kwargs": {},
"expected": ["test_guardrail"],
},
{
"name": "non_guardrail_callback",
"callbacks": [
Mock(),
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
],
"kwargs": {},
"expected": ["test_guardrail"],
},
],
)
def test_get_applied_guardrails(test_case):

# Setup
litellm.callbacks = test_case["callbacks"]

# Execute
result = get_applied_guardrails(test_case["kwargs"])

# Assert
assert sorted(result) == sorted(test_case["expected"])
2 changes: 1 addition & 1 deletion (spend logs test fixture; file path not shown)
@@ -9,7 +9,7 @@
"model": "gpt-4o",
"user": "",
"team_id": "",
"metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"cache_key": "Cache OFF",
"spend": 0.00022500000000000002,
"total_tokens": 30,
Expand Down
1 change: 1 addition & 0 deletions tests/logging_callback_tests/test_otel_logging.py
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
"metadata.user_api_key_user_id",
"metadata.user_api_key_org_id",
"metadata.user_api_key_end_user_id",
"metadata.applied_guardrails",
]

_all_attributes = set(
