From 3677d56e9e1ed5e0b28c1efcde919ab5f79cae1a Mon Sep 17 00:00:00 2001 From: Vince Loewe Date: Fri, 3 May 2024 17:42:50 +0100 Subject: [PATCH] Lunary: Fix tool calling --- litellm/integrations/lunary.py | 36 ++++++++++++++++++++++++------ litellm/tests/test_lunary.py | 40 ++++++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 9 deletions(-) diff --git a/litellm/integrations/lunary.py b/litellm/integrations/lunary.py index 6ddf2ca59923..6b23f098755d 100644 --- a/litellm/integrations/lunary.py +++ b/litellm/integrations/lunary.py @@ -4,7 +4,6 @@ import traceback import dotenv import importlib -import sys import packaging @@ -18,13 +17,33 @@ def parse_usage(usage): "prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0, } +def parse_tool_calls(tool_calls): + if tool_calls is None: + return None + + def clean_tool_call(tool_call): + + serialized = { + "type": tool_call.type, + "id": tool_call.id, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + } + + return serialized + + return [clean_tool_call(tool_call) for tool_call in tool_calls] + def parse_messages(input): + if input is None: return None def clean_message(message): - # if is strin, return as is + # if is string, return as is if isinstance(message, str): return message @@ -38,9 +57,7 @@ def clean_message(message): # Only add tool_calls and function_call to res if they are set if message.get("tool_calls"): - serialized["tool_calls"] = message.get("tool_calls") - if message.get("function_call"): - serialized["function_call"] = message.get("function_call") + serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls")) return serialized @@ -93,8 +110,13 @@ def log_event( print_verbose(f"Lunary Logging - Logging request for model {model}") litellm_params = kwargs.get("litellm_params", {}) + optional_params = kwargs.get("optional_params", {}) metadata = litellm_params.get("metadata", {}) or {} + if optional_params: + # merge into extra + extra = {**extra, **optional_params} + tags = litellm_params.pop("tags", None) or [] if extra: @@ -104,7 +126,7 @@ def log_event( # keep only serializable types for param, value in extra.items(): - if not isinstance(value, (str, int, bool, float)): + if not isinstance(value, (str, int, bool, float)) and param != "tools": try: extra[param] = str(value) except: @@ -140,7 +162,7 @@ def log_event( metadata=metadata, runtime="litellm", tags=tags, - extra=extra, + params=extra, ) self.lunary_client.track_event( diff --git a/litellm/tests/test_lunary.py b/litellm/tests/test_lunary.py index cbf9364aff8a..c9a8afd57f1b 100644 --- a/litellm/tests/test_lunary.py +++ b/litellm/tests/test_lunary.py @@ -11,7 +11,6 @@ litellm.success_callback = ["lunary"] litellm.set_verbose = True - def test_lunary_logging(): try: response = completion( @@ -59,9 +58,46 @@ def test_lunary_logging_with_metadata(): except Exception as e: print(e) +#test_lunary_logging_with_metadata() + +def test_lunary_with_tools(): + + import litellm + + messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + + response = litellm.completion( + model="gpt-3.5-turbo-1106", + messages=messages, + tools=tools, + tool_choice="auto", # auto is default, but we'll be explicit + ) + + response_message = response.choices[0].message + print("\nLLM Response:\n", response.choices[0].message) -# test_lunary_logging_with_metadata() +#test_lunary_with_tools() def test_lunary_logging_with_streaming_and_metadata(): try: