Skip to content

Commit

Permalink
Add `included` to .get() & .query() responses
Browse files (browse the repository at this point in the history)
The new `included` key echoes the `include` list that was requested, which clarifies why fields like `embeddings` are returned as `None` when they were not included.
  • Loading branch information
codetheweb committed Apr 23, 2024
1 parent e5ec1b3 commit 5202d38
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
2 changes: 2 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def _get(
documents=body.get("documents", None),
data=None,
uris=body.get("uris", None),
included=body["included"],
)

@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -581,6 +582,7 @@ def _query(
documents=body.get("documents", None),
uris=body.get("uris", None),
data=None,
included=body["included"],
)

@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
Expand Down
3 changes: 3 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def _get(
documents=[] if "documents" in include else None,
uris=[] if "uris" in include else None,
data=[] if "data" in include else None,
included=include,
)

vectors: Sequence[t.VectorEmbeddingRecord] = []
Expand Down Expand Up @@ -574,6 +575,7 @@ def _get(
documents=documents if "documents" in include else None, # type: ignore
uris=uris if "uris" in include else None, # type: ignore
data=None,
included=include,
)

@trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -766,6 +768,7 @@ def _query(
documents=documents if documents else None,
uris=uris if uris else None,
data=None,
included=include,
)

@trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class GetResult(TypedDict):
uris: Optional[URIs]
data: Optional[Loadable]
metadatas: Optional[List[Metadata]]
included: Include


class QueryResult(TypedDict):
Expand All @@ -167,6 +168,7 @@ class QueryResult(TypedDict):
data: Optional[List[Loadable]]
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
included: Include


class IndexMetadata(TypedDict):
Expand Down
2 changes: 2 additions & 0 deletions chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] == [[]]
assert res["distances"] == [[]]
assert res["metadatas"] == [[]]
assert res["included"] == ["embeddings", "distances", "metadatas"]

res = coll.query(
query_embeddings=test_query_embeddings,
Expand All @@ -348,6 +349,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] is None
assert res["distances"] == [[], []]
assert res["metadatas"] == [[], []]
assert res["included"] == ["metadatas", "documents", "distances"]


def test_boolean_metadata(api: ServerAPI) -> None:
Expand Down
17 changes: 13 additions & 4 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,22 +994,26 @@ def test_query_include(api):
collection = api.create_collection("test_query_include")
collection.add(**records)

include = ["metadatas", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=["metadatas", "documents", "distances"],
include=include,
n_results=1,
)
assert items["embeddings"] is None
assert items["ids"][0][0] == "id1"
assert items["metadatas"][0][0]["int_value"] == 1
assert items["included"] == include

include = ["embeddings", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=["embeddings", "documents", "distances"],
include=include,
n_results=1,
)
assert items["metadatas"] is None
assert items["ids"][0][0] == "id1"
assert items["included"] == include

items = collection.query(
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
Expand All @@ -1029,22 +1033,27 @@ def test_get_include(api):
collection = api.create_collection("test_get_include")
collection.add(**records)

items = collection.get(include=["metadatas", "documents"], where={"int_value": 1})
include = ["metadatas", "documents"]
items = collection.get(include=include, where={"int_value": 1})
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["metadatas"][0]["int_value"] == 1
assert items["documents"][0] == "this document is first"
assert items["included"] == include

items = collection.get(include=["embeddings", "documents"])
include = ["embeddings", "documents"]
items = collection.get(include=include)
assert items["metadatas"] is None
assert items["ids"][0] == "id1"
assert approx_equal(items["embeddings"][1][0], 1.2)
assert items["included"] == include

items = collection.get(include=[])
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["included"] == []

with pytest.raises(ValueError, match="include"):
items = collection.get(include=["metadatas", "undefined"])
Expand Down

0 comments on commit 5202d38

Please sign in to comment.