-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from Mirascope/add-server-tests
feat: add unittest for server
- Loading branch information
Showing
20 changed files
with
1,534 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# tests/server/_utils/test_spans.py | ||
|
||
from lilypad.server._utils.spans import ( | ||
convert_anthropic_messages, | ||
convert_gemini_messages, | ||
convert_openai_messages, | ||
) | ||
|
||
|
||
def test_convert_openai_messages(): | ||
"""Test converting OpenAI messages""" | ||
messages = [ | ||
{ | ||
"name": "gen_ai.user.message", | ||
"attributes": {"content": '["Hello, how are you?"]'}, | ||
}, | ||
{ | ||
"name": "gen_ai.choice", | ||
"attributes": { | ||
"index": 0, | ||
"message": '{"role": "assistant", "content": "I am doing well, thank you!"}', | ||
}, | ||
}, | ||
] | ||
|
||
result = convert_openai_messages(messages) | ||
assert len(result) == 2 | ||
assert result[0].role == "user" | ||
assert result[0].content[0].text == "Hello, how are you?" # pyright: ignore [reportAttributeAccessIssue] | ||
assert result[1].role == "assistant" | ||
assert result[1].content[0].text == "I am doing well, thank you!" # pyright: ignore [reportAttributeAccessIssue] | ||
|
||
|
||
def test_convert_gemini_messages(): | ||
"""Test converting Gemini messages""" | ||
messages = [ | ||
{"name": "gen_ai.user.message", "attributes": {"content": '["Test message"]'}}, | ||
{ | ||
"name": "gen_ai.choice", | ||
"attributes": { | ||
"index": 0, | ||
"message": '{"role": "assistant", "content": ["Response"]}', | ||
}, | ||
}, | ||
] | ||
|
||
result = convert_gemini_messages(messages) | ||
assert len(result) == 2 | ||
assert result[0].role == "user" | ||
assert result[0].content[0].text == "Test message" # pyright: ignore [reportAttributeAccessIssue] | ||
assert result[1].role == "assistant" | ||
assert result[1].content[0].text == "Response" # pyright: ignore [reportAttributeAccessIssue] | ||
|
||
|
||
def test_convert_anthropic_messages(): | ||
"""Test converting Anthropic messages""" | ||
messages = [ | ||
{"name": "gen_ai.user.message", "attributes": {"content": '["User input"]'}}, | ||
{ | ||
"name": "gen_ai.choice", | ||
"attributes": { | ||
"index": 0, | ||
"message": '{"role": "assistant", "content": "Assistant response"}', | ||
}, | ||
}, | ||
] | ||
|
||
result = convert_anthropic_messages(messages) | ||
assert len(result) == 2 | ||
assert result[0].role == "user" | ||
assert result[0].content[0].text == "User input" # pyright: ignore [reportAttributeAccessIssue] | ||
assert result[1].role == "assistant" | ||
assert result[1].content[0].text == "Assistant response" # pyright: ignore [reportAttributeAccessIssue] | ||
|
||
|
||
def test_invalid_message_content(): | ||
"""Test handling invalid message content.""" | ||
messages = [ | ||
{"name": "gen_ai.user.message", "attributes": {"content": "Invalid JSON"}}, | ||
{ | ||
"name": "gen_ai.choice", | ||
"attributes": { | ||
"index": 0, | ||
"message": '{"role": "assistant", "content": "Response"}', | ||
}, | ||
}, | ||
] | ||
|
||
result = convert_openai_messages(messages) | ||
assert result[0].content[0].text == "Invalid JSON" # pyright: ignore [reportAttributeAccessIssue] | ||
|
||
result = convert_anthropic_messages(messages) | ||
assert result[0].content[0].text == "Invalid JSON" # pyright: ignore [reportAttributeAccessIssue] | ||
|
||
result = convert_gemini_messages(messages) | ||
assert result[0].content[0].text == "Invalid JSON" # pyright: ignore [reportAttributeAccessIssue] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from lilypad.server._utils.versions import construct_function | ||
|
||
|
||
def test_construct_function(): | ||
"""Test constructing function code""" | ||
arg_types = {"text": "str", "temperature": "float"} | ||
function_name = "test_function" | ||
|
||
code = construct_function(arg_types, function_name) | ||
assert "@lilypad.prompt()" in code | ||
assert "def test_function(text: str, temperature: float)" in code | ||
assert "-> str" in code | ||
|
||
|
||
def test_construct_function_with_configure(): | ||
"""Test constructing function code with configure flag""" | ||
arg_types = {"text": "str"} | ||
function_name = "test_function" | ||
|
||
code = construct_function(arg_types, function_name, configure=True) | ||
assert "lilypad.configure()" in code | ||
assert "@lilypad.prompt()" in code | ||
assert "def test_function(text: str)" in code | ||
|
||
|
||
def test_construct_function_no_args(): | ||
"""Test constructing function with no arguments""" | ||
code = construct_function({}, "test_function") | ||
assert "@lilypad.prompt()" in code | ||
assert "def test_function()" in code | ||
assert "-> str" in code |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""tests.server.api.__init__.py""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""Pytest configuration for FastAPI tests.""" | ||
|
||
from collections.abc import AsyncGenerator, Generator | ||
|
||
import pytest | ||
from fastapi.testclient import TestClient | ||
from sqlmodel import Session, SQLModel, create_engine | ||
|
||
from lilypad.server.api.v0.main import api | ||
from lilypad.server.db.session import get_session | ||
|
||
# Create a single test engine for all tests | ||
TEST_DATABASE_URL = "sqlite:///:memory:" | ||
test_engine = create_engine( | ||
TEST_DATABASE_URL, | ||
connect_args={"check_same_thread": False}, | ||
) | ||
|
||
|
||
@pytest.fixture(scope="session", autouse=True) | ||
def create_test_database(): | ||
"""Create test database once for test session.""" | ||
SQLModel.metadata.create_all(test_engine) | ||
yield | ||
SQLModel.metadata.drop_all(test_engine) | ||
|
||
|
||
@pytest.fixture | ||
def session() -> Generator[Session, None, None]: | ||
"""Create a fresh database session for each test. | ||
Yields: | ||
Session: The database session | ||
""" | ||
connection = test_engine.connect() | ||
transaction = connection.begin() | ||
session = Session(bind=connection) | ||
|
||
try: | ||
yield session | ||
finally: | ||
session.close() | ||
transaction.rollback() | ||
connection.close() | ||
|
||
|
||
@pytest.fixture | ||
def get_test_session( | ||
session: Session, | ||
) -> Generator[AsyncGenerator[Session, None], None, None]: | ||
"""Override the get_session dependency for FastAPI. | ||
Args: | ||
session: The test database session | ||
Yields: | ||
AsyncGenerator[Session, None]: Async session generator | ||
""" | ||
|
||
async def override_get_session() -> AsyncGenerator[Session, None]: | ||
try: | ||
yield session | ||
finally: | ||
pass # Session is handled by the session fixture | ||
|
||
return override_get_session # pyright: ignore [reportReturnType] | ||
|
||
|
||
@pytest.fixture | ||
def client( | ||
session: Session, get_test_session: AsyncGenerator[Session, None] | ||
) -> TestClient: # pyright: ignore [reportInvalidTypeForm] | ||
"""Create a test client with database session dependency override. | ||
Args: | ||
session: The database session | ||
get_test_session: Session dependency override | ||
Returns: | ||
TestClient: FastAPI test client | ||
""" | ||
api.dependency_overrides[get_session] = get_test_session # pyright: ignore [reportArgumentType] | ||
|
||
client = TestClient(api) | ||
try: | ||
yield client # pyright: ignore [reportReturnType] | ||
finally: | ||
api.dependency_overrides.clear() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Tests for the functions API.""" | ||
|
||
from collections.abc import Generator | ||
|
||
import pytest | ||
from fastapi.testclient import TestClient | ||
from sqlmodel import Session | ||
|
||
from lilypad.server.models import FunctionTable, ProjectTable | ||
|
||
|
||
@pytest.fixture | ||
def test_project(session: Session) -> Generator[ProjectTable, None, None]: | ||
"""Create a test project. | ||
Args: | ||
session: Database session | ||
Yields: | ||
ProjectTable: Test project | ||
""" | ||
project = ProjectTable(name="test_project") | ||
session.add(project) | ||
session.commit() | ||
session.refresh(project) | ||
yield project | ||
|
||
|
||
@pytest.fixture | ||
def test_function( | ||
session: Session, test_project: ProjectTable | ||
) -> Generator[FunctionTable, None, None]: | ||
"""Create a test function. | ||
Args: | ||
session: Database session | ||
test_project: Parent project | ||
Yields: | ||
FunctionTable: Test function | ||
""" | ||
function = FunctionTable( | ||
project_id=test_project.id, | ||
name="test_function", | ||
hash="test_hash", | ||
code="def test(): pass", | ||
) | ||
session.add(function) | ||
session.commit() | ||
session.refresh(function) | ||
yield function | ||
|
||
|
||
def test_get_empty_function_names(client: TestClient, test_project: ProjectTable): | ||
"""Test getting function names when no functions exist.""" | ||
response = client.get(f"/projects/{test_project.id}/functions/names") | ||
assert response.status_code == 200 | ||
assert response.json() == [] | ||
|
||
|
||
def test_get_function_names( | ||
client: TestClient, test_project: ProjectTable, test_function: FunctionTable | ||
): | ||
"""Test getting function names returns expected names.""" | ||
response = client.get(f"/projects/{test_project.id}/functions/names") | ||
assert response.status_code == 200 | ||
names = response.json() | ||
assert len(names) == 1 | ||
assert names[0] == test_function.name |
Oops, something went wrong.