Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Apr 23, 2024
1 parent 2928e91 commit 5e59276
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
4 changes: 2 additions & 2 deletions chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] == [[]]
assert res["distances"] == [[]]
assert res["metadatas"] == [[]]
assert res["included"] == ["embeddings", "distances", "metadatas"]
assert set(res["included"]) == set(["embeddings", "distances", "metadatas"])

res = coll.query(
query_embeddings=test_query_embeddings,
Expand All @@ -349,7 +349,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] is None
assert res["distances"] == [[], []]
assert res["metadatas"] == [[], []]
assert res["included"] == ["metadatas", "documents", "distances"]
assert set(res["included"]) == set(["metadatas", "documents", "distances"])


def test_boolean_metadata(api: ServerAPI) -> None:
Expand Down
30 changes: 26 additions & 4 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def test_persist_index_loading(api_fixture, request):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -118,6 +120,8 @@ def __call__(self, input):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -146,6 +150,8 @@ def __call__(self, input):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -260,6 +266,8 @@ def test_get_from_db(api):
for key in records.keys():
if (key in includes) or (key == "ids"):
assert len(records[key]) == 2
elif key == "included":
assert set(records[key]) == set(includes)
else:
assert records[key] is None

Expand Down Expand Up @@ -290,6 +298,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand All @@ -302,6 +312,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand All @@ -314,6 +326,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 2
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -437,6 +451,8 @@ def test_increment_index_on(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -489,6 +505,8 @@ def test_peek(api):
for key in peek.keys():
if key in ["embeddings", "documents", "metadatas"] or key == "ids":
assert len(peek[key]) == 2
elif key == "included":
assert set(peek[key]) == set(["embeddings", "metadatas", "documents"])
else:
assert peek[key] is None

Expand Down Expand Up @@ -1003,7 +1021,7 @@ def test_query_include(api):
assert items["embeddings"] is None
assert items["ids"][0][0] == "id1"
assert items["metadatas"][0][0]["int_value"] == 1
assert items["included"] == include
assert set(items["included"]) == set(include)

include = ["embeddings", "documents", "distances"]
items = collection.query(
Expand All @@ -1013,7 +1031,7 @@ def test_query_include(api):
)
assert items["metadatas"] is None
assert items["ids"][0][0] == "id1"
assert items["included"] == include
assert set(items["included"]) == set(include)

items = collection.query(
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
Expand All @@ -1039,14 +1057,14 @@ def test_get_include(api):
assert items["ids"][0] == "id1"
assert items["metadatas"][0]["int_value"] == 1
assert items["documents"][0] == "this document is first"
assert items["included"] == include
assert set(items["included"]) == set(include)

include = ["embeddings", "documents"]
items = collection.get(include=include)
assert items["metadatas"] is None
assert items["ids"][0] == "id1"
assert approx_equal(items["embeddings"][1][0], 1.2)
assert items["included"] == include
assert set(items["included"]) == set(include)

items = collection.get(include=[])
assert items["documents"] is None
Expand Down Expand Up @@ -1181,6 +1199,8 @@ def test_persist_index_loading_params(api, request):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -1299,6 +1319,8 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api):
for key in results.keys():
if key in includes or key == "ids":
assert len(results[key][0]) == 2
elif key == "included":
assert set(results[key]) == set(includes)
else:
assert results[key] is None

Expand Down

0 comments on commit 5e59276

Please sign in to comment.