Skip to content

Commit

Permalink
fix: comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Brendan committed Oct 28, 2024
1 parent b75c38d commit 839c748
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 105 deletions.
38 changes: 37 additions & 1 deletion lilypad/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
import os
import time
import webbrowser
from collections.abc import AsyncIterable, Callable, Coroutine, Iterable
from collections.abc import AsyncIterable, Callable, Coroutine, Generator, Iterable
from contextlib import _GeneratorContextManager, contextmanager
from functools import partial, wraps
from importlib import import_module
from pathlib import Path
from typing import Any, ParamSpec, TypeVar, cast, get_args, get_origin, get_type_hints

from mirascope.core import base as mb
from opentelemetry.trace import get_tracer
from opentelemetry.trace.span import Span
from opentelemetry.util.types import AttributeValue
from pydantic import BaseModel

from lilypad.models import FnParamsPublic, VersionPublic
Expand Down Expand Up @@ -317,6 +320,39 @@ def iterable() -> Iterable[str]:
return decorator


def get_custom_context_manager(
version: VersionPublic,
arg_types: dict[str, str],
arg_values: dict[str, Any],
is_async: bool,
) -> Callable[..., _GeneratorContextManager[Span]]:
@contextmanager
def custom_context_manager(
fn: Callable,
) -> Generator[Span, Any, None]:
tracer = get_tracer("lilypad")
with tracer.start_as_current_span(f"{fn.__name__}") as span:
attributes: dict[str, AttributeValue] = {
"lilypad.project_id": lilypad_client.project_id
if lilypad_client.project_id
else 0,
"lilypad.function_name": fn.__name__,
"lilypad.version": version.version if version.version else "",
"lilypad.version_id": version.id,
"lilypad.arg_types": json.dumps(arg_types),
"lilypad.arg_values": json.dumps(arg_values),
"lilypad.lexical_closure": version.llm_fn.code,
"lilypad.prompt_template": version.fn_params.prompt_template
if version.fn_params
else "",
"lilypad.is_async": is_async,
}
span.set_attributes(attributes)
yield span

return custom_context_manager


def handle_call_response(
result: mb.BaseCallResponse, fn: Callable, span: Span | None
) -> None:
Expand Down
139 changes: 40 additions & 99 deletions lilypad/llm_fn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""The lilypad `llm_fn` decorator."""

import inspect
import json
from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager
from collections.abc import Callable, Coroutine
from functools import wraps
from typing import (
Any,
Expand All @@ -15,14 +13,12 @@
)

from mirascope.integrations import middleware_factory
from opentelemetry.trace import get_tracer
from opentelemetry.trace.span import Span
from opentelemetry.util.types import AttributeValue

from lilypad._trace import trace
from lilypad.server import client

from ._utils import (
get_custom_context_manager,
get_llm_function_version,
handle_call_response,
handle_call_response_async,
Expand Down Expand Up @@ -85,9 +81,10 @@ def decorator(
async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> _R:
arg_types, arg_values = inspect_arguments(fn, *args, **kwargs)
version = get_llm_function_version(fn, arg_types, synced)
is_mirascope_call = hasattr(fn, "__mirascope_call__")

if not synced:
trace_decorator = trace(
if not synced and not is_mirascope_call:
decorator = trace(
project_id=lilypad_client.project_id,
version_id=version.id,
arg_types=arg_types,
Expand All @@ -98,39 +95,12 @@ async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> _R:
else "",
version=version.version,
)
return cast(_R, await trace_decorator(fn)(*args, **kwargs))
return cast(_R, await decorator(fn)(*args, **kwargs))

if not version.fn_params:
raise ValueError(f"Synced function {fn.__name__} has no params.")

fn_params = version.fn_params

@contextmanager
def custom_context_manager(
fn: Callable,
) -> Generator[Span, Any, None]:
tracer = get_tracer("lilypad")
with tracer.start_as_current_span(f"{fn.__name__}") as span:
attributes: dict[str, AttributeValue] = {
"lilypad.project_id": lilypad_client.project_id
if lilypad_client.project_id
else 0,
"lilypad.function_name": fn.__name__,
"lilypad.version": version.version
if version.version
else "",
"lilypad.version_id": version.id,
"lilypad.arg_types": json.dumps(arg_types),
"lilypad.arg_values": json.dumps(arg_values),
"lilypad.lexical_closure": version.llm_fn.code,
"lilypad.prompt_template": fn_params.prompt_template,
"lilypad.is_async": True,
}
span.set_attributes(attributes)
yield span

synced_decorator_async = middleware_factory(
custom_context_manager=custom_context_manager,
decorator = middleware_factory(
custom_context_manager=get_custom_context_manager(
version, arg_types, arg_values, True
),
handle_call_response=handle_call_response,
handle_call_response_async=handle_call_response_async,
handle_stream=handle_stream,
Expand All @@ -140,11 +110,15 @@ def custom_context_manager(
handle_structured_stream=handle_structured_stream,
handle_structured_stream_async=handle_structured_stream_async,
)
if not synced and is_mirascope_call:
return cast(_R, await decorator(fn)(*args, **kwargs))

if not version.fn_params:
raise ValueError(f"Synced function {fn.__name__} has no params.")

return await traced_synced_llm_function_constructor(
version.fn_params, synced_decorator_async
)(fn)(*args, **kwargs)
version.fn_params, decorator
)(fn)(*args, **kwargs) # pyright: ignore [reportReturnType]

return inner_async

Expand All @@ -154,65 +128,26 @@ def custom_context_manager(
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
arg_types, arg_values = inspect_arguments(fn, *args, **kwargs)
version = get_llm_function_version(fn, arg_types, synced)
is_mirascope_call = hasattr(fn, "__mirascope_call__")

@contextmanager
def custom_context_manager(
fn: Callable,
) -> Generator[Span, Any, None]:
tracer = get_tracer("lilypad")
with tracer.start_as_current_span(f"{fn.__name__}") as span:
attributes: dict[str, AttributeValue] = {
"lilypad.project_id": lilypad_client.project_id
if lilypad_client.project_id
else 0,
"lilypad.function_name": fn.__name__,
"lilypad.version": version.version
if version.version
else "",
"lilypad.version_id": version.id,
"lilypad.arg_types": json.dumps(arg_types),
"lilypad.arg_values": json.dumps(arg_values),
"lilypad.lexical_closure": version.llm_fn.code,
"lilypad.prompt_template": version.fn_params.prompt_template
if version.fn_params
else "",
"lilypad.is_async": False,
}
span.set_attributes(attributes)
yield span

if not synced:
if hasattr(fn, "__mirascope_call__"):
decorator = middleware_factory(
custom_context_manager=custom_context_manager,
handle_call_response=handle_call_response,
handle_call_response_async=handle_call_response_async,
handle_stream=handle_stream,
handle_stream_async=handle_stream_async,
handle_response_model=handle_response_model,
handle_response_model_async=handle_response_model_async,
handle_structured_stream=handle_structured_stream,
handle_structured_stream_async=handle_structured_stream_async,
)
else:
decorator = trace(
project_id=lilypad_client.project_id,
version_id=version.id,
arg_types=arg_types,
arg_values=arg_values,
lexical_closure=version.llm_fn.code,
prompt_template=version.fn_params.prompt_template
if version.fn_params
else "",
version=version.version,
)
if not synced and not is_mirascope_call:
decorator = trace(
project_id=lilypad_client.project_id,
version_id=version.id,
arg_types=arg_types,
arg_values=arg_values,
lexical_closure=version.llm_fn.code,
prompt_template=version.fn_params.prompt_template
if version.fn_params
else "",
version=version.version,
)
return cast(_R, decorator(fn)(*args, **kwargs))

if not version.fn_params:
raise ValueError(f"Synced function {fn.__name__} has no params.")

synced_decorator = middleware_factory(
custom_context_manager=custom_context_manager,
decorator = middleware_factory(
custom_context_manager=get_custom_context_manager(
version, arg_types, arg_values, False
),
handle_call_response=handle_call_response,
handle_call_response_async=handle_call_response_async,
handle_stream=handle_stream,
Expand All @@ -222,8 +157,14 @@ def custom_context_manager(
handle_structured_stream=handle_structured_stream,
handle_structured_stream_async=handle_structured_stream_async,
)
if not synced and is_mirascope_call:
return cast(_R, decorator(fn)(*args, **kwargs))

if not version.fn_params:
raise ValueError(f"Synced function {fn.__name__} has no params.")

return traced_synced_llm_function_constructor(
version.fn_params, synced_decorator
version.fn_params, decorator
)(fn)(*args, **kwargs) # pyright: ignore [reportReturnType]

return inner
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ classifiers = [
"Topic :: Software Development :: Libraries",
]
dependencies = [
"mirascope>=1.5.2",
"mirascope>=1.6.1",
"sqlmodel>=0.0.22",
"psycopg2-binary>=2.9.9",
"fastapi[standard]>=0.114.0",
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 839c748

Please sign in to comment.