Skip to content

Commit

Permalink
Merge pull request #29 from Mirascope/add-server-tests
Browse files Browse the repository at this point in the history
feat: add unittest for server
  • Loading branch information
willbakst authored Nov 26, 2024
2 parents 498f9d4 + 28f2f86 commit e288838
Show file tree
Hide file tree
Showing 20 changed files with 1,534 additions and 10 deletions.
34 changes: 24 additions & 10 deletions lilypad/server/services/versions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""The `VersionService` class for versions."""

from collections.abc import Sequence
from contextlib import suppress

from fastapi import HTTPException, status
from sqlmodel import col, func, select
Expand Down Expand Up @@ -88,19 +87,34 @@ def find_prompt_active_version(
def change_active_version(
self, project_id: int, new_active_version: VersionTable
) -> VersionTable:
"""Change the active version"""
with suppress(HTTPException):
active_version = self.find_prompt_active_version(
project_id, new_active_version.function_name
)
if active_version.id == new_active_version.id:
return active_version
active_version.is_active = False
self.session.add(active_version)
"""Change the active version for a function, deactivating any currently active versions.
Args:
project_id: The project ID
new_active_version: The version to make active
Returns:
The newly activated version
"""
# Deactivate all currently active versions for the same function
stmt = select(VersionTable).where(
VersionTable.project_id == project_id,
VersionTable.function_name == new_active_version.function_name,
VersionTable.is_active,
)
current_active_versions = self.session.exec(stmt).all()

for version in current_active_versions:
version.is_active = False
self.session.add(version)

# Activate the new version
new_active_version.is_active = True
self.session.add(new_active_version)
self.session.flush()

# Refresh to get latest state
self.session.refresh(new_active_version)
return new_active_version

def get_function_version_count(self, project_id: int, function_name: str) -> int:
Expand Down
Empty file added tests/server/_utils/__init__.py
Empty file.
96 changes: 96 additions & 0 deletions tests/server/_utils/test_spans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# tests/server/_utils/test_spans.py

from lilypad.server._utils.spans import (
convert_anthropic_messages,
convert_gemini_messages,
convert_openai_messages,
)


def test_convert_openai_messages():
"""Test converting OpenAI messages"""
messages = [
{
"name": "gen_ai.user.message",
"attributes": {"content": '["Hello, how are you?"]'},
},
{
"name": "gen_ai.choice",
"attributes": {
"index": 0,
"message": '{"role": "assistant", "content": "I am doing well, thank you!"}',
},
},
]

result = convert_openai_messages(messages)
assert len(result) == 2
assert result[0].role == "user"
assert result[0].content[0].text == "Hello, how are you?" # pyright: ignore [reportAttributeAccessIssue]
assert result[1].role == "assistant"
assert result[1].content[0].text == "I am doing well, thank you!" # pyright: ignore [reportAttributeAccessIssue]


def test_convert_gemini_messages():
"""Test converting Gemini messages"""
messages = [
{"name": "gen_ai.user.message", "attributes": {"content": '["Test message"]'}},
{
"name": "gen_ai.choice",
"attributes": {
"index": 0,
"message": '{"role": "assistant", "content": ["Response"]}',
},
},
]

result = convert_gemini_messages(messages)
assert len(result) == 2
assert result[0].role == "user"
assert result[0].content[0].text == "Test message" # pyright: ignore [reportAttributeAccessIssue]
assert result[1].role == "assistant"
assert result[1].content[0].text == "Response" # pyright: ignore [reportAttributeAccessIssue]


def test_convert_anthropic_messages():
"""Test converting Anthropic messages"""
messages = [
{"name": "gen_ai.user.message", "attributes": {"content": '["User input"]'}},
{
"name": "gen_ai.choice",
"attributes": {
"index": 0,
"message": '{"role": "assistant", "content": "Assistant response"}',
},
},
]

result = convert_anthropic_messages(messages)
assert len(result) == 2
assert result[0].role == "user"
assert result[0].content[0].text == "User input" # pyright: ignore [reportAttributeAccessIssue]
assert result[1].role == "assistant"
assert result[1].content[0].text == "Assistant response" # pyright: ignore [reportAttributeAccessIssue]


def test_invalid_message_content():
"""Test handling invalid message content."""
messages = [
{"name": "gen_ai.user.message", "attributes": {"content": "Invalid JSON"}},
{
"name": "gen_ai.choice",
"attributes": {
"index": 0,
"message": '{"role": "assistant", "content": "Response"}',
},
},
]

result = convert_openai_messages(messages)
assert result[0].content[0].text == "Invalid JSON" # pyright: ignore [reportAttributeAccessIssue]

result = convert_anthropic_messages(messages)
assert result[0].content[0].text == "Invalid JSON" # pyright: ignore [reportAttributeAccessIssue]

result = convert_gemini_messages(messages)
assert result[0].content[0].text == "Invalid JSON" # pyright: ignore [reportAttributeAccessIssue]
31 changes: 31 additions & 0 deletions tests/server/_utils/test_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from lilypad.server._utils.versions import construct_function


def test_construct_function():
"""Test constructing function code"""
arg_types = {"text": "str", "temperature": "float"}
function_name = "test_function"

code = construct_function(arg_types, function_name)
assert "@lilypad.prompt()" in code
assert "def test_function(text: str, temperature: float)" in code
assert "-> str" in code


def test_construct_function_with_configure():
"""Test constructing function code with configure flag"""
arg_types = {"text": "str"}
function_name = "test_function"

code = construct_function(arg_types, function_name, configure=True)
assert "lilypad.configure()" in code
assert "@lilypad.prompt()" in code
assert "def test_function(text: str)" in code


def test_construct_function_no_args():
"""Test constructing function with no arguments"""
code = construct_function({}, "test_function")
assert "@lilypad.prompt()" in code
assert "def test_function()" in code
assert "-> str" in code
1 change: 1 addition & 0 deletions tests/server/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""tests.server.api.__init__.py"""
88 changes: 88 additions & 0 deletions tests/server/api/v0/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Pytest configuration for FastAPI tests."""

from collections.abc import AsyncGenerator, Generator

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine

from lilypad.server.api.v0.main import api
from lilypad.server.db.session import get_session

# Create a single test engine for all tests
TEST_DATABASE_URL = "sqlite:///:memory:"
test_engine = create_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
)


@pytest.fixture(scope="session", autouse=True)
def create_test_database():
"""Create test database once for test session."""
SQLModel.metadata.create_all(test_engine)
yield
SQLModel.metadata.drop_all(test_engine)


@pytest.fixture
def session() -> Generator[Session, None, None]:
"""Create a fresh database session for each test.
Yields:
Session: The database session
"""
connection = test_engine.connect()
transaction = connection.begin()
session = Session(bind=connection)

try:
yield session
finally:
session.close()
transaction.rollback()
connection.close()


@pytest.fixture
def get_test_session(
session: Session,
) -> Generator[AsyncGenerator[Session, None], None, None]:
"""Override the get_session dependency for FastAPI.
Args:
session: The test database session
Yields:
AsyncGenerator[Session, None]: Async session generator
"""

async def override_get_session() -> AsyncGenerator[Session, None]:
try:
yield session
finally:
pass # Session is handled by the session fixture

return override_get_session # pyright: ignore [reportReturnType]


@pytest.fixture
def client(
session: Session, get_test_session: AsyncGenerator[Session, None]
) -> TestClient: # pyright: ignore [reportInvalidTypeForm]
"""Create a test client with database session dependency override.
Args:
session: The database session
get_test_session: Session dependency override
Returns:
TestClient: FastAPI test client
"""
api.dependency_overrides[get_session] = get_test_session # pyright: ignore [reportArgumentType]

client = TestClient(api)
try:
yield client # pyright: ignore [reportReturnType]
finally:
api.dependency_overrides.clear()
69 changes: 69 additions & 0 deletions tests/server/api/v0/test_functions_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Tests for the functions API."""

from collections.abc import Generator

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session

from lilypad.server.models import FunctionTable, ProjectTable


@pytest.fixture
def test_project(session: Session) -> Generator[ProjectTable, None, None]:
"""Create a test project.
Args:
session: Database session
Yields:
ProjectTable: Test project
"""
project = ProjectTable(name="test_project")
session.add(project)
session.commit()
session.refresh(project)
yield project


@pytest.fixture
def test_function(
session: Session, test_project: ProjectTable
) -> Generator[FunctionTable, None, None]:
"""Create a test function.
Args:
session: Database session
test_project: Parent project
Yields:
FunctionTable: Test function
"""
function = FunctionTable(
project_id=test_project.id,
name="test_function",
hash="test_hash",
code="def test(): pass",
)
session.add(function)
session.commit()
session.refresh(function)
yield function


def test_get_empty_function_names(client: TestClient, test_project: ProjectTable):
"""Test getting function names when no functions exist."""
response = client.get(f"/projects/{test_project.id}/functions/names")
assert response.status_code == 200
assert response.json() == []


def test_get_function_names(
client: TestClient, test_project: ProjectTable, test_function: FunctionTable
):
"""Test getting function names returns expected names."""
response = client.get(f"/projects/{test_project.id}/functions/names")
assert response.status_code == 200
names = response.json()
assert len(names) == 1
assert names[0] == test_function.name
Loading

0 comments on commit e288838

Please sign in to comment.