Skip to content

Commit

Permalink
Merge pull request #4 from Mirascope/track-calls
Browse files Browse the repository at this point in the history
Track calls
  • Loading branch information
Brendan Kao authored Sep 7, 2024
2 parents c9f07ce + 4042c1a commit 0b3df08
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 13 deletions.
1 change: 1 addition & 0 deletions examples/sqlite_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
client = OpenAI()


@lilypad.trace
def recommend_book(genre: str) -> str | None:
"""Recommends a `genre` book using OpenAI"""
prompt = lilypad.prompt(recommend_book)(genre)
Expand Down
Binary file removed lilypad/app/database.db
Binary file not shown.
41 changes: 40 additions & 1 deletion lilypad/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from fastapi import Depends, FastAPI, Form, Header, HTTPException, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sqlmodel import Session, select

from lilypad.app.db.session import get_session
from lilypad.app.models import ProjectTable, PromptVersionTable
from lilypad.app.models import CallTable, ProjectTable, PromptVersionTable

app = FastAPI()
templates = Jinja2Templates(directory="templates")
Expand Down Expand Up @@ -163,3 +164,41 @@ async def view_version(
"project": prompt_version.project,
},
)


class CallCreate(BaseModel):
"""Call create model."""

project_name: str
input: str
output: str


@app.post("/calls")
async def create_calls(
session: Annotated[Session, Depends(get_session)], call_create: CallCreate
) -> bool:
"""Creates a logged call."""
project = session.exec(
select(ProjectTable).where(ProjectTable.name == call_create.project_name)
).first()

if not project:
raise HTTPException(status_code=404)

prompt_version = sorted(
project.prompt_versions, key=lambda x: x.id or 0, reverse=True
)[:1][0]

if not prompt_version:
raise HTTPException(status_code=404)

call = CallTable(
prompt_version_id=prompt_version.id,
input=call_create.input,
output=call_create.output,
)
session.add(call)
session.commit()
session.flush()
return True
3 changes: 2 additions & 1 deletion lilypad/app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""This module initializes the models package."""

from .base_sql_model import BaseSQLModel
from .calls import CallTable
from .projects import ProjectTable
from .prompt_versions import PromptVersionTable

__all__ = ["BaseSQLModel", "ProjectTable", "PromptVersionTable"]
__all__ = ["BaseSQLModel", "CallTable", "ProjectTable", "PromptVersionTable"]
30 changes: 30 additions & 0 deletions lilypad/app/models/calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Project model"""

import datetime
from typing import TYPE_CHECKING

from sqlmodel import Field, Relationship

from lilypad.app.models import BaseSQLModel

from .table_names import CALL_TABLE_NAME, PROMPT_VERSION_TABLE_NAME

if TYPE_CHECKING:
from lilypad.app.models import PromptVersionTable


class CallTable(BaseSQLModel, table=True):
"""Call model"""

__tablename__ = CALL_TABLE_NAME # type: ignore

id: int = Field(default=None, primary_key=True)
prompt_version_id: int = Field(
default=None, foreign_key=f"{PROMPT_VERSION_TABLE_NAME}.id"
)
input: str = Field(nullable=False)
output: str = Field(nullable=False)
created_at: datetime.datetime = Field(
default=datetime.datetime.now(datetime.UTC), nullable=False
)
prompt_version: "PromptVersionTable" = Relationship(back_populates="calls")
2 changes: 1 addition & 1 deletion lilypad/app/models/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ProjectTable(BaseSQLModel, table=True):

__tablename__ = PROJECT_TABLE_NAME # type: ignore

id: int | None = Field(default=None, primary_key=True)
id: int = Field(default=None, primary_key=True)
name: str = Field(nullable=False, unique=True)
created_at: datetime.datetime = Field(
default=datetime.datetime.now(datetime.UTC), nullable=False
Expand Down
9 changes: 6 additions & 3 deletions lilypad/app/models/prompt_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
from .table_names import PROJECT_TABLE_NAME, PROMPT_VERSION_TABLE_NAME

if TYPE_CHECKING:
from lilypad.app.models import ProjectTable
from lilypad.app.models import CallTable, ProjectTable


class PromptVersionTable(BaseSQLModel, table=True):
"""Prompt version model"""

__tablename__ = PROMPT_VERSION_TABLE_NAME # type: ignore

id: int | None = Field(default=None, primary_key=True)
project_id: int | None = Field(default=None, foreign_key=f"{PROJECT_TABLE_NAME}.id")
id: int = Field(default=None, primary_key=True)
project_id: int = Field(default=None, foreign_key=f"{PROJECT_TABLE_NAME}.id")
prompt_template: str = Field(nullable=False)
created_at: datetime.datetime = Field(
default=datetime.datetime.now(datetime.UTC), nullable=False
Expand All @@ -29,3 +29,6 @@ class PromptVersionTable(BaseSQLModel, table=True):
)

project: "ProjectTable" = Relationship(back_populates="prompt_versions")
calls: list["CallTable"] = Relationship(
back_populates="prompt_version", cascade_delete=True
)
1 change: 1 addition & 0 deletions lilypad/app/models/table_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

PROJECT_TABLE_NAME = "projects"
PROMPT_VERSION_TABLE_NAME = "prompt_versions"
CALL_TABLE_NAME = "calls"
26 changes: 24 additions & 2 deletions lilypad/trace.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""A decorator for tracing functions."""

import inspect
import json
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, overload
from typing import Any, ParamSpec, TypeVar, overload

import requests
from openai import OpenAI

_P = ParamSpec("_P")
Expand Down Expand Up @@ -39,6 +42,25 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:

@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return fn(*args, **kwargs)
url = "http://localhost:8000/calls"

params_dict: dict[str, Any] = {}
bound_args = inspect.signature(fn).bind(*args, **kwargs)
for param_name, param_value in bound_args.arguments.items():
params_dict[param_name] = param_value
output = fn(*args, **kwargs)
input = json.dumps(params_dict)

data = {
"project_name": fn.__name__,
"input": input,
"output": output,
}

try:
requests.post(url, json=data)
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}") # noqa: T201
return output

return inner
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ classifiers = [
]
dependencies = [
"mirascope>=1.1.2",
"fastapi[standard]>=0.113",
"sqlmodel>=0.0.22",
"psycopg2-binary>=2.9.9",
"fastapi[standard]>=0.114.0",
]

[project.urls]
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0b3df08

Please sign in to comment.