Skip to content

Commit

Permalink
Lunary: Fix tool calling
Browse files Browse the repository at this point in the history
  • Loading branch information
vincelwt committed May 3, 2024
1 parent 91971fa commit 3677d56
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 9 deletions.
36 changes: 29 additions & 7 deletions litellm/integrations/lunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import traceback
import dotenv
import importlib
import sys

import packaging

Expand All @@ -18,13 +17,33 @@ def parse_usage(usage):
"prompt": usage["prompt_tokens"] if "prompt_tokens" in usage else 0,
}

def parse_tool_calls(tool_calls):
if tool_calls is None:
return None

def clean_tool_call(tool_call):

serialized = {
"type": tool_call.type,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
}
}

return serialized

return [clean_tool_call(tool_call) for tool_call in tool_calls]


def parse_messages(input):

if input is None:
return None

def clean_message(message):
# if is strin, return as is
# if is string, return as is
if isinstance(message, str):
return message

Expand All @@ -38,9 +57,7 @@ def clean_message(message):

# Only add tool_calls and function_call to res if they are set
if message.get("tool_calls"):
serialized["tool_calls"] = message.get("tool_calls")
if message.get("function_call"):
serialized["function_call"] = message.get("function_call")
serialized["tool_calls"] = parse_tool_calls(message.get("tool_calls"))

return serialized

Expand Down Expand Up @@ -93,8 +110,13 @@ def log_event(
print_verbose(f"Lunary Logging - Logging request for model {model}")

litellm_params = kwargs.get("litellm_params", {})
optional_params = kwargs.get("optional_params", {})
metadata = litellm_params.get("metadata", {}) or {}

if optional_params:
# merge into extra
extra = {**extra, **optional_params}

tags = litellm_params.pop("tags", None) or []

if extra:
Expand All @@ -104,7 +126,7 @@ def log_event(

# keep only serializable types
for param, value in extra.items():
if not isinstance(value, (str, int, bool, float)):
if not isinstance(value, (str, int, bool, float)) and param != "tools":
try:
extra[param] = str(value)
except:
Expand Down Expand Up @@ -140,7 +162,7 @@ def log_event(
metadata=metadata,
runtime="litellm",
tags=tags,
extra=extra,
params=extra,
)

self.lunary_client.track_event(
Expand Down
40 changes: 38 additions & 2 deletions litellm/tests/test_lunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
litellm.success_callback = ["lunary"]
litellm.set_verbose = True


def test_lunary_logging():
try:
response = completion(
Expand Down Expand Up @@ -59,9 +58,46 @@ def test_lunary_logging_with_metadata():
except Exception as e:
print(e)

#test_lunary_logging_with_metadata()

def test_lunary_with_tools():

import litellm

messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]

response = litellm.completion(
model="gpt-3.5-turbo-1106",
messages=messages,
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
)

response_message = response.choices[0].message
print("\nLLM Response:\n", response.choices[0].message)

# test_lunary_logging_with_metadata()

#test_lunary_with_tools()

def test_lunary_logging_with_streaming_and_metadata():
try:
Expand Down

0 comments on commit 3677d56

Please sign in to comment.