Skip to content

Commit

Permalink
update test_collection
Browse files Browse the repository at this point in the history
  • Loading branch information
bjchambers committed Jan 30, 2024
1 parent 7e2a6e8 commit 030a6c3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pytest-docker-fixtures = {extras = ["pg"], version = "^1.3.18"}
asgi-lifespan = "^2.1.0"
openapi-python-client = "^0.17.2"
poethepoet = "^0.24.4"
dewy-client = { path = "./dewy-client" }

[build-system]
requires = ["poetry-core"]
Expand Down
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from dewy_client import Client

pytest_plugins = ["pytest_docker_fixtures"]

Expand Down Expand Up @@ -37,8 +38,10 @@ async def app(pg, event_loop):


@pytest.fixture(scope="session")
async def client(app) -> AsyncClient:
async with AsyncClient(app=app, base_url="http://test") as client:
async def client(app) -> Client:
async with AsyncClient(app=app, base_url="http://test") as httpx_client:
client = Client(base_url="http://test")
client.set_async_httpx_client(httpx_client)
yield client


Expand Down
31 changes: 14 additions & 17 deletions tests/test_collection.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
import random
import string

from dewy_client.api.default import add_collection, get_collection, list_collections
from dewy_client.models import CollectionCreate

async def test_create_collection(client):
name = "".join(random.choices(string.ascii_lowercase, k=5))
create_response = await client.put("/api/collections/", json={"name": name})
assert create_response.status_code == 200
collection = await add_collection.asyncio(client=client, body=CollectionCreate(
name = name
))

json = create_response.json()
assert json["name"] == name
assert json["text_embedding_model"] == "openai:text-embedding-ada-002"
assert json["text_distance_metric"] == "cosine"
assert collection.name == name
assert collection.text_embedding_model == "openai:text-embedding-ada-002"
assert collection.text_distance_metric == "cosine"

collection_id = json["id"]
collection_id = collection.id

list_response = await client.get("/api/collections/")
assert list_response.status_code == 200
list_response = await list_collections.asyncio(client=client)

# "find" the collection with the new collection ID, since
# other tests may have created other collections
json = list_response.json()
collection_row = next(x for x in list_response.json() if x["id"] == collection_id)
collection_row = next(x for x in list_response if x.id == collection_id)
assert collection_row is not None
assert collection_row["name"] == name
assert collection_row.name == name

get_response = await client.get(f"/api/collections/{collection_id}")
assert get_response.status_code == 200
get_response = await get_collection.asyncio(collection_id, client=client)

json = get_response.json()
assert collection_row["name"] == name
assert get_response.name == name

0 comments on commit 030a6c3

Please sign in to comment.