Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Apr 19, 2024
1 parent 78b0e0c commit 315a680
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
3 changes: 2 additions & 1 deletion chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Metadatas,
Where,
WhereDocument,
Omitted,
)
import chromadb.test.property.strategies as strategies
import hypothesis.strategies as st
Expand Down Expand Up @@ -345,7 +346,7 @@ def test_empty_filter(api: ServerAPI) -> None:
n_results=3,
)
assert res["ids"] == [[], []]
assert res["embeddings"] is None
assert isinstance(res["embeddings"], Omitted)
assert res["distances"] == [[], []]
assert res["metadatas"] == [[], []]

Expand Down
44 changes: 22 additions & 22 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import chromadb
from chromadb.api.fastapi import FastAPI
from chromadb.api.types import QueryResult, EmbeddingFunction, Document
from chromadb.api.types import QueryResult, EmbeddingFunction, Document, Omitted
from chromadb.config import Settings
import chromadb.server.fastapi
import pytest
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_persist_index_loading(api_fixture, request):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


@pytest.mark.parametrize("api_fixture", [local_persist_api])
Expand All @@ -119,7 +119,7 @@ def __call__(self, input):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


@pytest.mark.parametrize("api_fixture", [local_persist_api])
Expand Down Expand Up @@ -147,7 +147,7 @@ def __call__(self, input):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)

assert nn["ids"] == [["id1"]]
assert nn["embeddings"] == [[[1, 2, 3]]]
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_get_from_db(api):
if (key in includes) or (key == "ids"):
assert len(records[key]) == 2
else:
assert records[key] is None
assert isinstance(records[key], Omitted)


def test_reset_db(api):
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_get_nearest_neighbors(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)

nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
Expand All @@ -303,7 +303,7 @@ def test_get_nearest_neighbors(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)

nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]],
Expand All @@ -315,7 +315,7 @@ def test_get_nearest_neighbors(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 2
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


def test_delete(api):
Expand Down Expand Up @@ -438,7 +438,7 @@ def test_increment_index_on(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


def test_add_a_collection(api):
Expand Down Expand Up @@ -999,7 +999,7 @@ def test_query_include(api):
include=["metadatas", "documents", "distances"],
n_results=1,
)
assert items["embeddings"] is None
assert isinstance(items["embeddings"], Omitted)
assert items["ids"][0][0] == "id1"
assert items["metadatas"][0][0]["int_value"] == 1

Expand All @@ -1008,18 +1008,18 @@ def test_query_include(api):
include=["embeddings", "documents", "distances"],
n_results=1,
)
assert items["metadatas"] is None
assert isinstance(items["metadatas"], Omitted)
assert items["ids"][0][0] == "id1"

items = collection.query(
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
include=[],
n_results=2,
)
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert items["distances"] is None
assert isinstance(items["documents"], Omitted)
assert isinstance(items["metadatas"], Omitted)
assert isinstance(items["embeddings"], Omitted)
assert isinstance(items["distances"], Omitted)
assert items["ids"][0][0] == "id1"
assert items["ids"][0][1] == "id2"

Expand All @@ -1030,20 +1030,20 @@ def test_get_include(api):
collection.add(**records)

items = collection.get(include=["metadatas", "documents"], where={"int_value": 1})
assert items["embeddings"] is None
assert isinstance(items["embeddings"], Omitted)
assert items["ids"][0] == "id1"
assert items["metadatas"][0]["int_value"] == 1
assert items["documents"][0] == "this document is first"

items = collection.get(include=["embeddings", "documents"])
assert items["metadatas"] is None
assert isinstance(items["metadatas"], Omitted)
assert items["ids"][0] == "id1"
assert approx_equal(items["embeddings"][1][0], 1.2)

items = collection.get(include=[])
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert isinstance(items["documents"], Omitted)
assert isinstance(items["metadatas"], Omitted)
assert isinstance(items["embeddings"], Omitted)
assert items["ids"][0] == "id1"

with pytest.raises(ValueError, match="include"):
Expand Down Expand Up @@ -1173,7 +1173,7 @@ def test_persist_index_loading_params(api, request):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


def test_add_large(api):
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api):
if key in includes or key == "ids":
assert len(results[key][0]) == 2
else:
assert results[key] is None
assert isinstance(results[key], Omitted)


def test_invalid_n_results_param(api):
Expand Down

0 comments on commit 315a680

Please sign in to comment.