diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index 9129c023df7..ed555a35538 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -13,6 +13,7 @@ Metadatas, Where, WhereDocument, + Omitted, ) import chromadb.test.property.strategies as strategies import hypothesis.strategies as st @@ -345,7 +346,7 @@ def test_empty_filter(api: ServerAPI) -> None: n_results=3, ) assert res["ids"] == [[], []] - assert res["embeddings"] is None + assert isinstance(res["embeddings"], Omitted) assert res["distances"] == [[], []] assert res["metadatas"] == [[], []] diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index d37a5a23eb2..e786bbaa9db 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -5,7 +5,7 @@ import chromadb from chromadb.api.fastapi import FastAPI -from chromadb.api.types import QueryResult, EmbeddingFunction, Document +from chromadb.api.types import QueryResult, EmbeddingFunction, Document, Omitted from chromadb.config import Settings import chromadb.server.fastapi import pytest @@ -92,7 +92,7 @@ def test_persist_index_loading(api_fixture, request): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) @pytest.mark.parametrize("api_fixture", [local_persist_api]) @@ -119,7 +119,7 @@ def __call__(self, input): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) @pytest.mark.parametrize("api_fixture", [local_persist_api]) @@ -147,7 +147,7 @@ def __call__(self, input): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) assert nn["ids"] == [["id1"]] assert nn["embeddings"] == [[[1, 2, 3]]] @@ -261,7 +261,7 @@ def test_get_from_db(api): if (key in includes) or (key == "ids"): assert len(records[key]) == 2 else: - assert records[key] is None + assert isinstance(records[key], Omitted) def test_reset_db(api): @@ -291,7 +291,7 @@ def test_get_nearest_neighbors(api): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) nn = collection.query( query_embeddings=[[1.1, 2.3, 3.2]], @@ -303,7 +303,7 @@ def test_get_nearest_neighbors(api): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) nn = collection.query( query_embeddings=[[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]], @@ -315,7 +315,7 @@ def test_get_nearest_neighbors(api): if (key in includes) or (key == "ids"): assert len(nn[key]) == 2 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) def test_delete(api): @@ -438,7 +438,7 @@ def test_increment_index_on(api): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) def test_add_a_collection(api): @@ -999,7 +999,7 @@ def test_query_include(api): include=["metadatas", "documents", "distances"], n_results=1, ) - assert items["embeddings"] is None + assert isinstance(items["embeddings"], Omitted) assert items["ids"][0][0] == "id1" assert items["metadatas"][0][0]["int_value"] == 1 @@ -1008,7 +1008,7 @@ def test_query_include(api): include=["embeddings", "documents", "distances"], n_results=1, ) - assert items["metadatas"] is None + assert isinstance(items["metadatas"], Omitted) assert items["ids"][0][0] == "id1" items = collection.query( @@ -1016,10 +1016,10 @@ def test_query_include(api): include=[], n_results=2, ) - assert items["documents"] is None - assert items["metadatas"] is None - assert items["embeddings"] is None - assert items["distances"] is None + assert isinstance(items["documents"], Omitted) + assert isinstance(items["metadatas"], Omitted) + assert isinstance(items["embeddings"], Omitted) + assert isinstance(items["distances"], Omitted) assert items["ids"][0][0] == "id1" assert items["ids"][0][1] == "id2" @@ -1030,20 +1030,20 @@ def test_get_include(api): collection.add(**records) items = collection.get(include=["metadatas", "documents"], where={"int_value": 1}) - assert items["embeddings"] is None + assert isinstance(items["embeddings"], Omitted) assert items["ids"][0] == "id1" assert items["metadatas"][0]["int_value"] == 1 assert items["documents"][0] == "this document is first" items = collection.get(include=["embeddings", "documents"]) - assert items["metadatas"] is None + assert isinstance(items["metadatas"], Omitted) assert items["ids"][0] == "id1" assert approx_equal(items["embeddings"][1][0], 1.2) items = collection.get(include=[]) - assert items["documents"] is None - assert items["metadatas"] is None - assert items["embeddings"] is None + assert isinstance(items["documents"], Omitted) + assert isinstance(items["metadatas"], Omitted) + assert isinstance(items["embeddings"], Omitted) assert items["ids"][0] == "id1" with pytest.raises(ValueError, match="include"): @@ -1173,7 +1173,7 @@ def test_persist_index_loading_params(api, request): if (key in includes) or (key == "ids"): assert len(nn[key]) == 1 else: - assert nn[key] is None + assert isinstance(nn[key], Omitted) def test_add_large(api): @@ -1291,7 +1291,7 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api): if key in includes or key == "ids": assert len(results[key][0]) == 2 else: - assert results[key] is None + assert isinstance(results[key], Omitted) def test_invalid_n_results_param(api):