Skip to content

Commit

Permalink
add load.log10(lamini) to support lamini sdk and add example
Browse files Browse the repository at this point in the history
  • Loading branch information
wenzhe-log10 committed Apr 17, 2024
1 parent ac2e177 commit 9847707
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
11 changes: 11 additions & 0 deletions examples/logging/lamini_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import lamini

from log10.load import log10


log10(lamini)

llm = lamini.Lamini("meta-llama/Llama-2-7b-chat-hf")
response = llm.generate("What's 2 + 9 * 3?")

print(response)
30 changes: 25 additions & 5 deletions log10/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def _get_stack_trace():
]


def _init_log_row(func, **kwargs):
def _init_log_row(func, *args, **kwargs):
kwargs_copy = deepcopy(kwargs)

log_row = {
Expand Down Expand Up @@ -497,6 +497,11 @@ def _init_log_row(func, **kwargs):
else:
kwargs_copy[key] = value
kwargs_copy.pop("generation_config")
elif "lamini" in func.__module__:
log_row["kind"] = "chat"
kwargs_copy.update(
{"model": args[1]["model_name"], "messages": [{"role": "user", "content": args[1]["prompt"]}]}
)
elif "mistralai" in func.__module__:
log_row["kind"] = "chat"
elif "openai" in func.__module__:
Expand All @@ -516,7 +521,7 @@ def wrapper(*args, **kwargs):
result_queue = queue.Queue()

try:
log_row = _init_log_row(func, **kwargs)
log_row = _init_log_row(func, *args, **kwargs)

with timed_block(sync_log_text + " call duration"):
if USE_ASYNC:
Expand Down Expand Up @@ -593,11 +598,9 @@ def wrapper(*args, **kwargs):
elif "vertexai" in func.__module__:
response = output
reason = response.candidates[0].finish_reason.name
import uuid

ret_response = {
"id": str(uuid.uuid4()),
"object": "completion",
"object": "chat.completion",
"choices": [
{
"index": 0,
Expand Down Expand Up @@ -628,6 +631,19 @@ def wrapper(*args, **kwargs):

if "choices" in response:
response = flatten_response(response)
elif "lamini" in func.__module__:
response = {
"id": str(uuid.uuid4()),
"object": "chat.completion",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {"role": "assistant", "content": output["output"]},
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
elif "mistralai" in func.__module__:
if "stream" in func.__qualname__:
log_row["response"] = response
Expand Down Expand Up @@ -802,6 +818,10 @@ def log10(module, DEBUG_=False, USE_ASYNC_=True):
attr = module.resources.messages.Messages
method = getattr(attr, "create")
setattr(attr, "create", intercepting_decorator(method))
elif module.__name__ == "lamini":
attr = module.api.utils.completion.Completion
method = getattr(attr, "generate")
setattr(attr, "generate", intercepting_decorator(method))
elif module.__name__ == "mistralai" and getattr(module, "_log10_patched", False) is False:
attr = module.client.MistralClient
method = getattr(attr, "chat")
Expand Down

0 comments on commit 9847707

Please sign in to comment.