From a6d5b44bbdf9a78283692296f5ca0e79d3853dee Mon Sep 17 00:00:00 2001
From: Wenzhe Xue
Date: Wed, 17 Apr 2024 10:02:01 -0700
Subject: [PATCH] add load.log10(lamini) to support lamini sdk and add example

---
 examples/logging/lamini_generate.py | 11 +++++++++++
 log10/load.py                       | 30 ++++++++++++++++++++++++-----
 2 files changed, 36 insertions(+), 5 deletions(-)
 create mode 100644 examples/logging/lamini_generate.py

diff --git a/examples/logging/lamini_generate.py b/examples/logging/lamini_generate.py
new file mode 100644
index 00000000..e49bf632
--- /dev/null
+++ b/examples/logging/lamini_generate.py
@@ -0,0 +1,11 @@
+import lamini
+
+from log10.load import log10
+
+
+log10(lamini)
+
+llm = lamini.Lamini("meta-llama/Llama-2-7b-chat-hf")
+response = llm.generate("What's 2 + 9 * 3?")
+
+print(response)
diff --git a/log10/load.py b/log10/load.py
index f091e9e7..0ae23bc5 100644
--- a/log10/load.py
+++ b/log10/load.py
@@ -436,7 +436,7 @@ def _get_stack_trace():
     ]
 
 
-def _init_log_row(func, **kwargs):
+def _init_log_row(func, *args, **kwargs):
     kwargs_copy = deepcopy(kwargs)
 
     log_row = {
@@ -497,6 +497,11 @@ def _init_log_row(func, **kwargs):
             else:
                 kwargs_copy[key] = value
         kwargs_copy.pop("generation_config")
+    elif "lamini" in func.__module__:
+        log_row["kind"] = "chat"
+        kwargs_copy.update(
+            {"model": args[1]["model_name"], "messages": [{"role": "user", "content": args[1]["prompt"]}]}
+        )
     elif "mistralai" in func.__module__:
         log_row["kind"] = "chat"
     elif "openai" in func.__module__:
@@ -516,7 +521,7 @@ def wrapper(*args, **kwargs):
         result_queue = queue.Queue()
 
         try:
-            log_row = _init_log_row(func, **kwargs)
+            log_row = _init_log_row(func, *args, **kwargs)
 
             with timed_block(sync_log_text + " call duration"):
                 if USE_ASYNC:
@@ -593,11 +598,9 @@ def wrapper(*args, **kwargs):
             elif "vertexai" in func.__module__:
                 response = output
                 reason = response.candidates[0].finish_reason.name
-                import uuid
-
                 ret_response = {
                     "id": str(uuid.uuid4()),
-                    "object": "completion",
+                    "object": "chat.completion",
                     "choices": [
                         {
                             "index": 0,
@@ -628,6 +631,19 @@ def wrapper(*args, **kwargs):
 
                 if "choices" in response:
                     response = flatten_response(response)
+            elif "lamini" in func.__module__:
+                response = {
+                    "id": str(uuid.uuid4()),
+                    "object": "chat.completion",
+                    "choices": [
+                        {
+                            "index": 0,
+                            "finish_reason": "stop",
+                            "message": {"role": "assistant", "content": output["output"]},
+                        }
+                    ],
+                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+                }
             elif "mistralai" in func.__module__:
                 if "stream" in func.__qualname__:
                     log_row["response"] = response
@@ -802,6 +818,10 @@ def log10(module, DEBUG_=False, USE_ASYNC_=True):
         attr = module.resources.messages.Messages
         method = getattr(attr, "create")
         setattr(attr, "create", intercepting_decorator(method))
+    elif module.__name__ == "lamini":
+        attr = module.api.utils.completion.Completion
+        method = getattr(attr, "generate")
+        setattr(attr, "generate", intercepting_decorator(method))
     elif module.__name__ == "mistralai" and getattr(module, "_log10_patched", False) is False:
         attr = module.client.MistralClient
         method = getattr(attr, "chat")
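
Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the interception pattern the lamini branches above wire up, for anyone who wants to see the mechanism without reading the rest of load.py. DemoCompletion and traced are hypothetical stand-ins for lamini.api.utils.completion.Completion and intercepting_decorator, and the sketch assumes, as the patch does, that generate receives its request as a single positional dict, so args[0] is self and args[1] carries model_name and prompt.

import functools
import uuid


class DemoCompletion:
    """Hypothetical stand-in for lamini.api.utils.completion.Completion."""

    def generate(self, params):
        # Pretend SDK call; the patch assumes the real Lamini payload is a
        # dict with an "output" key.
        return {"output": f"echo: {params['prompt']}"}


def traced(func):
    """Hypothetical stand-in for intercepting_decorator, reduced to the
    request/response mapping this patch adds for Lamini."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # As in _init_log_row: args[0] is self, args[1] is the request dict.
        request = {
            "model": args[1]["model_name"],
            "messages": [{"role": "user", "content": args[1]["prompt"]}],
        }
        output = func(*args, **kwargs)
        # As in the new lamini branch of wrapper: normalize the SDK payload
        # to an OpenAI-style chat.completion body with zeroed token usage.
        response = {
            "id": str(uuid.uuid4()),
            "object": "chat.completion",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {"role": "assistant", "content": output["output"]},
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
        print("logged request:", request)
        print("logged response:", response)
        # Hand the original payload back so callers see no difference.
        return output

    return wrapper


# Same mechanism as the new branch in log10(): swap the method on the class.
DemoCompletion.generate = traced(DemoCompletion.generate)

print(DemoCompletion().generate({"model_name": "demo-model", "prompt": "What's 2 + 9 * 3?"}))

Running the sketch prints the OpenAI-style request/response pair that the real wrapper would hand to the log10 logger, then the untouched SDK-style return value, which is why the patched llm.generate(...) call in the example script keeps working unchanged.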