Skip to content

Commit

Permalink
feat: added tests for the lamacpp client
Browse files Browse the repository at this point in the history
  • Loading branch information
umbertogriffo committed May 4, 2024
1 parent dd42c4e commit 22d5082
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ nest_asyncio = "~=1.5.8"
pytest = "~=7.2.1"
pytest-cov = "~=4.0.0"
pytest-mock = "~=3.10.0"
pytest-asyncio = "~=0.23.6"
pre-commit = "~=3.6.0"
ruff = "~=0.1.9"
httpx = "~=0.23.3"
Expand Down
44 changes: 44 additions & 0 deletions tests/test_lamacpp_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -41,3 +42,46 @@ def test_generate_answer(lamacpp_client):
prompt = "What is the capital city of Italy?"
generated_answer = lamacpp_client.generate_answer(prompt, max_new_tokens=10)
assert "rome" in generated_answer.lower()


def test_generate_stream_answer(lamacpp_client):
    """Streamed generation for the Italy-capital question should mention Rome."""
    question = "What is the capital city of Italy?"
    answer = lamacpp_client.stream_answer(question, max_new_tokens=10)
    assert "rome" in answer.lower()


def test_start_answer_iterator_streamer(lamacpp_client):
    """Chunks yielded by the token streamer should concatenate to an answer mentioning Rome."""
    question = "What is the capital city of Italy?"
    streamer = lamacpp_client.start_answer_iterator_streamer(question, max_new_tokens=10)
    # Each streamed chunk is an OpenAI-style completion dict; join its text pieces.
    answer = "".join(chunk["choices"][0]["text"] for chunk in streamer)
    assert "rome" in answer.lower()


@pytest.mark.asyncio
async def test_async_generate_answer(lamacpp_client):
    """Async generation should answer the Italy-capital question with Rome."""
    prompt = "What is the capital city of Italy?"
    # Await the coroutine directly: wrapping a single awaitable in
    # asyncio.gather() only adds overhead and forces indexing a
    # one-element result list.
    generated_answer = await lamacpp_client.async_generate_answer(prompt, max_new_tokens=10)
    assert "rome" in generated_answer.lower()


@pytest.mark.asyncio
async def test_async_start_answer_iterator_streamer(lamacpp_client):
    """The async token streamer should yield chunks that concatenate to mention Rome."""
    prompt = "What is the capital city of Italy?"
    # Await directly instead of asyncio.gather([task]): a single awaitable
    # needs no gathering, and this removes the 1-element list indirection.
    stream = await lamacpp_client.async_start_answer_iterator_streamer(prompt, max_new_tokens=10)
    generated_answer = ""
    for output in stream:
        # Each chunk is an OpenAI-style completion dict.
        generated_answer += output["choices"][0]["text"]
    assert "rome" in generated_answer.lower()


def test_parse_token(lamacpp_client):
    """parse_token should extract each streamed chunk's text; joined output mentions Rome."""
    question = "What is the capital city of Italy?"
    streamer = lamacpp_client.start_answer_iterator_streamer(question, max_new_tokens=10)
    answer = "".join(lamacpp_client.parse_token(chunk) for chunk in streamer)
    assert "rome" in answer.lower()

0 comments on commit 22d5082

Please sign in to comment.