Skip to content

Commit

Permalink
type checks
Browse files Browse the repository at this point in the history
  • Loading branch information
bboynton97 committed Jun 4, 2024
1 parent b019d94 commit 7255090
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
13 changes: 9 additions & 4 deletions autogen/logger/file_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
__all__ = ("FileLogger",)


def safe_serialize(obj):
def default(o):
def safe_serialize(obj: Any) -> str:
def default(o: Any) -> str:
if hasattr(o, "to_json"):
return o.to_json()
return str(o.to_json())
else:
return f"<<non-serializable: {type(o).__qualname__}>>"

Expand Down Expand Up @@ -82,6 +82,11 @@ def log_chat_completion(
Log a chat completion.
"""
thread_id = threading.get_ident()
source_name = None
if isinstance(source, str):
source_name = source
else:
source_name = source.name
try:
log_data = json.dumps(
{
Expand All @@ -95,7 +100,7 @@ def log_chat_completion(
"start_time": start_time,
"end_time": get_current_ts(),
"thread_id": thread_id,
"source_name": source.name,
"source_name": source_name,
}
)

Expand Down
14 changes: 10 additions & 4 deletions autogen/logger/sqlite_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
F = TypeVar("F", bound=Callable[..., Any])


def safe_serialize(obj):
def default(o):
def safe_serialize(obj: Any) -> str:
def default(o: Any) -> str:
if hasattr(o, "to_json"):
return o.to_json()
return str(o.to_json())
else:
return f"<<non-serializable: {type(o).__qualname__}>>"

Expand Down Expand Up @@ -234,6 +234,12 @@ def log_chat_completion(
else:
response_messages = json.dumps(to_dict(response), indent=4)

source_name = None
if isinstance(source, str):
source_name = source
else:
source_name = source.name

query = """
INSERT INTO chat_completions (
invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time, source_name
Expand All @@ -250,7 +256,7 @@ def log_chat_completion(
cost,
start_time,
end_time,
source.name,
source_name,
)

self._run_query(query=query, args=args)
Expand Down

0 comments on commit 7255090

Please sign in to comment.