Skip to content

Commit

Permalink
[ENH] add included to .get() & .query() response (#2044)
Browse files: browse the repository at this point in the history
  • Loading branch information
codetheweb authored Apr 30, 2024
1 parent 773556f commit c46ea2a
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 10 deletions.
2 changes: 2 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def _get(
documents=body.get("documents", None),
data=None,
uris=body.get("uris", None),
included=body["included"],
)

@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -581,6 +582,7 @@ def _query(
documents=body.get("documents", None),
uris=body.get("uris", None),
data=None,
included=body["included"],
)

@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
Expand Down
3 changes: 3 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def _get(
documents=[] if "documents" in include else None,
uris=[] if "uris" in include else None,
data=[] if "data" in include else None,
included=include,
)

vectors: Sequence[t.VectorEmbeddingRecord] = []
Expand Down Expand Up @@ -574,6 +575,7 @@ def _get(
documents=documents if "documents" in include else None, # type: ignore
uris=uris if "uris" in include else None, # type: ignore
data=None,
included=include,
)

@trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -766,6 +768,7 @@ def _query(
documents=documents if documents else None,
uris=uris if uris else None,
data=None,
included=include,
)

@trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class GetResult(TypedDict):
uris: Optional[URIs]
data: Optional[Loadable]
metadatas: Optional[List[Metadata]]
included: Include


class QueryResult(TypedDict):
Expand All @@ -167,6 +168,7 @@ class QueryResult(TypedDict):
data: Optional[List[Loadable]]
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
included: Include


class IndexMetadata(TypedDict):
Expand Down
2 changes: 2 additions & 0 deletions chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] == [[]]
assert res["distances"] == [[]]
assert res["metadatas"] == [[]]
assert set(res["included"]) == set(["embeddings", "distances", "metadatas"])

res = coll.query(
query_embeddings=test_query_embeddings,
Expand All @@ -348,6 +349,7 @@ def test_empty_filter(api: ServerAPI) -> None:
assert res["embeddings"] is None
assert res["distances"] == [[], []]
assert res["metadatas"] == [[], []]
assert set(res["included"]) == set(["metadatas", "documents", "distances"])


def test_boolean_metadata(api: ServerAPI) -> None:
Expand Down
39 changes: 35 additions & 4 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def test_persist_index_loading(api_fixture, request):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -118,6 +120,8 @@ def __call__(self, input):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -146,6 +150,8 @@ def __call__(self, input):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -260,6 +266,8 @@ def test_get_from_db(api):
for key in records.keys():
if (key in includes) or (key == "ids"):
assert len(records[key]) == 2
elif key == "included":
assert set(records[key]) == set(includes)
else:
assert records[key] is None

Expand Down Expand Up @@ -290,6 +298,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand All @@ -302,6 +312,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand All @@ -314,6 +326,8 @@ def test_get_nearest_neighbors(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 2
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -437,6 +451,8 @@ def test_increment_index_on(api):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -489,6 +505,8 @@ def test_peek(api):
for key in peek.keys():
if key in ["embeddings", "documents", "metadatas"] or key == "ids":
assert len(peek[key]) == 2
elif key == "included":
assert set(peek[key]) == set(["embeddings", "metadatas", "documents"])
else:
assert peek[key] is None

Expand Down Expand Up @@ -994,22 +1012,26 @@ def test_query_include(api):
collection = api.create_collection("test_query_include")
collection.add(**records)

include = ["metadatas", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=["metadatas", "documents", "distances"],
include=include,
n_results=1,
)
assert items["embeddings"] is None
assert items["ids"][0][0] == "id1"
assert items["metadatas"][0][0]["int_value"] == 1
assert set(items["included"]) == set(include)

include = ["embeddings", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=["embeddings", "documents", "distances"],
include=include,
n_results=1,
)
assert items["metadatas"] is None
assert items["ids"][0][0] == "id1"
assert set(items["included"]) == set(include)

items = collection.query(
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
Expand All @@ -1029,22 +1051,27 @@ def test_get_include(api):
collection = api.create_collection("test_get_include")
collection.add(**records)

items = collection.get(include=["metadatas", "documents"], where={"int_value": 1})
include = ["metadatas", "documents"]
items = collection.get(include=include, where={"int_value": 1})
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["metadatas"][0]["int_value"] == 1
assert items["documents"][0] == "this document is first"
assert set(items["included"]) == set(include)

items = collection.get(include=["embeddings", "documents"])
include = ["embeddings", "documents"]
items = collection.get(include=include)
assert items["metadatas"] is None
assert items["ids"][0] == "id1"
assert approx_equal(items["embeddings"][1][0], 1.2)
assert set(items["included"]) == set(include)

items = collection.get(include=[])
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["included"] == []

with pytest.raises(ValueError, match="include"):
items = collection.get(include=["metadatas", "undefined"])
Expand Down Expand Up @@ -1172,6 +1199,8 @@ def test_persist_index_loading_params(api, request):
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None

Expand Down Expand Up @@ -1290,6 +1319,8 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api):
for key in results.keys():
if key in includes or key == "ids":
assert len(results[key][0]) == 2
elif key == "included":
assert set(results[key]) == set(includes)
else:
assert results[key] is None

Expand Down
12 changes: 7 additions & 5 deletions clients/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type WhereOperator = "$gt" | "$gte" | "$lt" | "$lte" | "$ne" | "$eq";

type OperatorExpression = {
[key in WhereOperator | InclusionOperator | LogicalOperator]?:
| LiteralValue
| ListLiteralValue;
| LiteralValue
| ListLiteralValue;
};

type BaseWhere = {
Expand All @@ -50,9 +50,9 @@ type WhereDocumentOperator = "$contains" | "$not_contains" | LogicalOperator;

export type WhereDocument = {
[key in WhereDocumentOperator]?:
| LiteralValue
| LiteralNumber
| WhereDocument[];
| LiteralValue
| LiteralNumber
| WhereDocument[];
};

export type CollectionType = {
Expand All @@ -67,6 +67,7 @@ export type GetResponse = {
documents: (null | Document)[];
metadatas: (null | Metadata)[];
error: null | string;
included: IncludeEnum[]
};

export type QueryResponse = {
Expand All @@ -75,6 +76,7 @@ export type QueryResponse = {
documents: (null | Document)[][];
metadatas: (null | Metadata)[][];
distances: null | number[][];
included: IncludeEnum[]
};

export type AddResponse = {
Expand Down
1 change: 1 addition & 0 deletions clients/js/test/get.collection.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ test("it should get a collection", async () => {
expect(results.ids.length).toBe(1);
expect(["test1"]).toEqual(expect.arrayContaining(results.ids));
expect(["test2"]).not.toEqual(expect.arrayContaining(results.ids));
expect(results.included).toEqual(expect.arrayContaining(["metadatas", "documents"]))

const results2 = await collection.get({ where: { test: "test1" } });
expect(results2).toBeDefined();
Expand Down
4 changes: 3 additions & 1 deletion clients/js/test/query.collection.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { EMBEDDINGS, IDS, METADATAS, DOCUMENTS } from "./data";
import { IEmbeddingFunction } from "../src/embeddings/IEmbeddingFunction";

export class TestEmbeddingFunction implements IEmbeddingFunction {
constructor() {}
constructor() { }

public async generate(texts: string[]): Promise<number[][]> {
let embeddings: number[][] = [];
Expand All @@ -29,6 +29,7 @@ test("it should query a collection", async () => {
expect(results).toBeInstanceOf(Object);
expect(["test1", "test2"]).toEqual(expect.arrayContaining(results.ids[0]));
expect(["test3"]).not.toEqual(expect.arrayContaining(results.ids[0]));
expect(results.included).toEqual(expect.arrayContaining(["metadatas", "documents"]))
});

// test where_document
Expand Down Expand Up @@ -68,6 +69,7 @@ test("it should get embedding with matching documents", async () => {
// expect(results2.embeddings[0][0]).toBeInstanceOf(Array);
expect(results2.embeddings![0].length).toBe(1);
expect(results2.embeddings![0][0]).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
expect(results2.included).toEqual(expect.arrayContaining(["embeddings"]))
});

test("it should exclude documents matching - not_contains", async () => {
Expand Down

0 comments on commit c46ea2a

Please sign in to comment.