diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index 8cdd021a0c5..2826caf12eb 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -338,7 +338,7 @@ def test_empty_filter(api: ServerAPI) -> None: assert res["embeddings"] == [[]] assert res["distances"] == [[]] assert res["metadatas"] == [[]] - assert res["included"] == ["embeddings", "distances", "metadatas"] + assert set(res["included"]) == set(["embeddings", "distances", "metadatas"]) res = coll.query( query_embeddings=test_query_embeddings, @@ -349,7 +349,7 @@ def test_empty_filter(api: ServerAPI) -> None: assert res["embeddings"] is None assert res["distances"] == [[], []] assert res["metadatas"] == [[], []] - assert res["included"] == ["metadatas", "documents", "distances"] + assert set(res["included"]) == set(["metadatas", "documents", "distances"]) def test_boolean_metadata(api: ServerAPI) -> None: diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 5054ceabec4..936dc7e2ed1 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -91,6 +91,8 @@ def test_persist_index_loading(api_fixture, request): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -118,6 +120,8 @@ def __call__(self, input): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -146,6 +150,8 @@ def __call__(self, input): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -260,6 +266,8 @@ def test_get_from_db(api): for key in records.keys(): if (key in includes) or (key == "ids"): assert len(records[key]) == 2 + elif key == "included": + assert set(records[key]) == set(includes) else: assert records[key] is None @@ -290,6 +298,8 @@ def test_get_nearest_neighbors(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -302,6 +312,8 @@ def test_get_nearest_neighbors(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -314,6 +326,8 @@ def test_get_nearest_neighbors(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 2 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -437,6 +451,8 @@ def test_increment_index_on(api): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -489,6 +505,8 @@ def test_peek(api): for key in peek.keys(): if key in ["embeddings", "documents", "metadatas"] or key == "ids": assert len(peek[key]) == 2 + elif key == "included": + assert set(peek[key]) == set(["embeddings", "metadatas", "documents"]) else: assert peek[key] is None @@ -1003,7 +1021,7 @@ def test_query_include(api): assert items["embeddings"] is None assert items["ids"][0][0] == "id1" assert items["metadatas"][0][0]["int_value"] == 1 - assert items["included"] == include + assert set(items["included"]) == set(include) include = ["embeddings", "documents", "distances"] items = collection.query( @@ -1013,7 +1031,7 @@ def test_query_include(api): ) assert items["metadatas"] is None assert items["ids"][0][0] == "id1" - assert items["included"] == include + assert set(items["included"]) == set(include) items = collection.query( query_embeddings=[[0, 0, 0], [1, 2, 1.2]], @@ -1039,14 +1057,14 @@ def test_get_include(api): assert items["ids"][0] == "id1" assert items["metadatas"][0]["int_value"] == 1 assert items["documents"][0] == "this document is first" - assert items["included"] == include + assert set(items["included"]) == set(include) include = ["embeddings", "documents"] items = collection.get(include=include) assert items["metadatas"] is None assert items["ids"][0] == "id1" assert approx_equal(items["embeddings"][1][0], 1.2) - assert items["included"] == include + assert set(items["included"]) == set(include) items = collection.get(include=[]) assert items["documents"] is None @@ -1181,6 +1199,8 @@ def test_persist_index_loading_params(api, request): for key in nn.keys(): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 + elif key == "included": + assert set(nn[key]) == set(includes) else: assert nn[key] is None @@ -1299,6 +1319,8 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api): for key in results.keys(): if key in includes or key == "ids": assert len(results[key][0]) == 2 + elif key == "included": + assert set(results[key]) == set(includes) else: assert results[key] is None