From 7255090a74c4eae3557afd7ab94039af3a2a4054 Mon Sep 17 00:00:00 2001 From: Braelyn Boynton Date: Mon, 3 Jun 2024 23:53:09 -0700 Subject: [PATCH] type checks --- autogen/logger/file_logger.py | 13 +++++++++---- autogen/logger/sqlite_logger.py | 14 ++++++++++---- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index 965245bef33..15b2c457e42 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -26,10 +26,10 @@ __all__ = ("FileLogger",) -def safe_serialize(obj): - def default(o): +def safe_serialize(obj: Any) -> str: + def default(o: Any) -> str: if hasattr(o, "to_json"): - return o.to_json() + return str(o.to_json()) else: return f"<>" @@ -82,6 +82,11 @@ def log_chat_completion( Log a chat completion. """ thread_id = threading.get_ident() + source_name = None + if isinstance(source, str): + source_name = source + else: + source_name = source.name try: log_data = json.dumps( { @@ -95,7 +100,7 @@ def log_chat_completion( "start_time": start_time, "end_time": get_current_ts(), "thread_id": thread_id, - "source_name": source.name, + "source_name": source_name, } ) diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 7704a3638fd..fb66c893b0c 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -28,10 +28,10 @@ F = TypeVar("F", bound=Callable[..., Any]) -def safe_serialize(obj): - def default(o): +def safe_serialize(obj: Any) -> str: + def default(o: Any) -> str: if hasattr(o, "to_json"): - return o.to_json() + return str(o.to_json()) else: return f"<>" @@ -234,6 +234,12 @@ def log_chat_completion( else: response_messages = json.dumps(to_dict(response), indent=4) + source_name = None + if isinstance(source, str): + source_name = source + else: + source_name = source.name + query = """ INSERT INTO chat_completions ( invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time, source_name @@ -250,7 +256,7 @@ def log_chat_completion( cost, start_time, end_time, - source.name, + source_name, ) self._run_query(query=query, args=args)