Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH]: clarify non-included fields in .get() and .query() responses #2028

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union, cast

import numpy as np
from pydantic import BaseModel, PrivateAttr
Expand All @@ -8,6 +8,7 @@

from chromadb.api.types import (
URI,
Omitted,
CollectionMetadata,
DataLoader,
Embedding,
Expand Down Expand Up @@ -37,6 +38,7 @@
maybe_cast_one_to_many_image,
maybe_cast_one_to_many_uri,
validate_ids,
valid_include_keys,
validate_include,
validate_metadata,
validate_metadatas,
Expand Down Expand Up @@ -226,9 +228,12 @@ def get(
):
get_results["data"] = self._data_loader(get_results["uris"])

# Remove URIs from the result if they weren't requested
if "uris" not in include:
get_results["uris"] = None
for key in valid_include_keys(allow_distances=False):
if key not in include:
# Casting to Any because key is of type Include, but GetResult does not have a distances key
get_results[cast(Any, key)] = Omitted(
f"Add '{key}' to `include` to return this field."
)

return get_results

Expand Down Expand Up @@ -360,9 +365,11 @@ def query(
self._data_loader(uris) for uris in query_results["uris"]
]

# Remove URIs from the result if they weren't requested
if "uris" not in include:
query_results["uris"] = None
for key in valid_include_keys(allow_distances=True):
if key not in include:
query_results[key] = Omitted(
f"Add '{key}' to `include` to return this field."
)

return query_results

Expand All @@ -382,7 +389,8 @@ def modify(
validate_metadata(metadata)
if "hnsw:space" in metadata:
raise ValueError(
"Changing the distance function of a collection once it is created is not supported currently.")
"Changing the distance function of a collection once it is created is not supported currently."
)

self._client._modify(id=self.id, new_name=name, new_metadata=metadata)
if name:
Expand Down
58 changes: 44 additions & 14 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@
T = TypeVar("T")
OneOrMany = Union[T, List[T]]


class Omitted:
"""Represents a returned property that was omitted from the response. It raises a ValueError when any property is accessed."""

def __init__(self, message: str):
self.message = message

def __getattr__(self, _):
raise ValueError(self.message)

def __getitem__(self, _):
raise ValueError(self.message)

def __str__(self):
return self.message

def __repr__(self):
return f'Omitted("{self.message}")'


E = TypeVar("E")
Omittable = Union[E, Omitted]

# URIs
URI = str
URIs = List[URI]
Expand Down Expand Up @@ -152,21 +175,21 @@ def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:

class GetResult(TypedDict):
ids: List[ID]
embeddings: Optional[List[Embedding]]
documents: Optional[List[Document]]
uris: Optional[URIs]
data: Optional[Loadable]
metadatas: Optional[List[Metadata]]
embeddings: Omittable[List[Embedding]]
documents: Omittable[List[Document]]
uris: Omittable[URIs]
data: Omittable[Loadable]
metadatas: Omittable[List[Metadata]]


class QueryResult(TypedDict):
ids: List[IDs]
embeddings: Optional[List[List[Embedding]]]
documents: Optional[List[List[Document]]]
uris: Optional[List[List[URI]]]
data: Optional[List[Loadable]]
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
embeddings: Omittable[List[List[Embedding]]]
documents: Omittable[List[List[Document]]]
uris: Omittable[List[List[URI]]]
data: Omittable[List[Loadable]]
metadatas: Omittable[List[List[Metadata]]]
distances: Omittable[List[List[float]]]


class IndexMetadata(TypedDict):
Expand Down Expand Up @@ -441,6 +464,15 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument:
return where_document


def valid_include_keys(allow_distances: bool) -> Include:
"""Returns a list of available keys for the include parameter."""
allowed_values: Include = ["embeddings", "documents", "metadatas", "uris", "data"]
if allow_distances:
allowed_values.append("distances")

return allowed_values


def validate_include(include: Include, allow_distances: bool) -> Include:
"""Validates include to ensure it is a list of strings. Since get does not allow distances, allow_distances is used
to control if distances is allowed"""
Expand All @@ -450,9 +482,7 @@ def validate_include(include: Include, allow_distances: bool) -> Include:
for item in include:
if not isinstance(item, str):
raise ValueError(f"Expected include item to be a str, got {item}")
allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
if allow_distances:
allowed_values.append("distances")
allowed_values = valid_include_keys(allow_distances=allow_distances)
if item not in allowed_values:
raise ValueError(
f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
Expand Down
3 changes: 2 additions & 1 deletion chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Metadatas,
Where,
WhereDocument,
Omitted,
)
import chromadb.test.property.strategies as strategies
import hypothesis.strategies as st
Expand Down Expand Up @@ -345,7 +346,7 @@ def test_empty_filter(api: ServerAPI) -> None:
n_results=3,
)
assert res["ids"] == [[], []]
assert res["embeddings"] is None
assert isinstance(res["embeddings"], Omitted)
assert res["distances"] == [[], []]
assert res["metadatas"] == [[], []]

Expand Down
69 changes: 47 additions & 22 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import chromadb
from chromadb.api.fastapi import FastAPI
from chromadb.api.types import QueryResult, EmbeddingFunction, Document
from chromadb.api.types import QueryResult, EmbeddingFunction, Document, Omitted
from chromadb.config import Settings
import chromadb.server.fastapi
import pytest
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_persist_index_loading(api_fixture, request):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


@pytest.mark.parametrize("api_fixture", [local_persist_api])
Expand All @@ -119,7 +119,7 @@ def __call__(self, input):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


@pytest.mark.parametrize("api_fixture", [local_persist_api])
Expand Down Expand Up @@ -147,7 +147,7 @@ def __call__(self, input):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)

assert nn["ids"] == [["id1"]]
assert nn["embeddings"] == [[[1, 2, 3]]]
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_get_from_db(api):
if (key in includes) or (key == "ids"):
assert len(records[key]) == 2
else:
assert records[key] is None
assert isinstance(records[key], Omitted)


def test_reset_db(api):
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_get_nearest_neighbors(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)

nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
Expand All @@ -303,7 +303,7 @@ def test_get_nearest_neighbors(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)

nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]],
Expand All @@ -315,7 +315,7 @@ def test_get_nearest_neighbors(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 2
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


def test_delete(api):
Expand Down Expand Up @@ -438,7 +438,7 @@ def test_increment_index_on(api):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


def test_add_a_collection(api):
Expand Down Expand Up @@ -999,7 +999,7 @@ def test_query_include(api):
include=["metadatas", "documents", "distances"],
n_results=1,
)
assert items["embeddings"] is None
assert isinstance(items["embeddings"], Omitted)
assert items["ids"][0][0] == "id1"
assert items["metadatas"][0][0]["int_value"] == 1

Expand All @@ -1008,18 +1008,18 @@ def test_query_include(api):
include=["embeddings", "documents", "distances"],
n_results=1,
)
assert items["metadatas"] is None
assert isinstance(items["metadatas"], Omitted)
assert items["ids"][0][0] == "id1"

items = collection.query(
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
include=[],
n_results=2,
)
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert items["distances"] is None
assert isinstance(items["documents"], Omitted)
assert isinstance(items["metadatas"], Omitted)
assert isinstance(items["embeddings"], Omitted)
assert isinstance(items["distances"], Omitted)
assert items["ids"][0][0] == "id1"
assert items["ids"][0][1] == "id2"

Expand All @@ -1030,20 +1030,20 @@ def test_get_include(api):
collection.add(**records)

items = collection.get(include=["metadatas", "documents"], where={"int_value": 1})
assert items["embeddings"] is None
assert isinstance(items["embeddings"], Omitted)
assert items["ids"][0] == "id1"
assert items["metadatas"][0]["int_value"] == 1
assert items["documents"][0] == "this document is first"

items = collection.get(include=["embeddings", "documents"])
assert items["metadatas"] is None
assert isinstance(items["metadatas"], Omitted)
assert items["ids"][0] == "id1"
assert approx_equal(items["embeddings"][1][0], 1.2)

items = collection.get(include=[])
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert isinstance(items["documents"], Omitted)
assert isinstance(items["metadatas"], Omitted)
assert isinstance(items["embeddings"], Omitted)
assert items["ids"][0] == "id1"

with pytest.raises(ValueError, match="include"):
Expand Down Expand Up @@ -1173,7 +1173,7 @@ def test_persist_index_loading_params(api, request):
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
else:
assert nn[key] is None
assert isinstance(nn[key], Omitted)


def test_add_large(api):
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def test_get_nearest_neighbors_where_n_results_more_than_element(api):
if key in includes or key == "ids":
assert len(results[key][0]) == 2
else:
assert results[key] is None
assert isinstance(results[key], Omitted)


def test_invalid_n_results_param(api):
Expand Down Expand Up @@ -1499,3 +1499,28 @@ def test_ssl_self_signed_with_verify_false(client_ssl):
client.heartbeat()
client_ssl.clear_system_cache()
assert "Unverified HTTPS request" in str(record[0].message)


def test_omitted_field_messages(api):
api.reset()
collection = api.create_collection("testspace")
collection.add(**records)

# Test .get()
document = collection.get(ids="id1")

assert "Add 'embeddings' to `include` to return this field" in str(document)

assert "Add 'embeddings' to `include` to return this field" in str(
document["embeddings"]
)

with pytest.raises(ValueError) as exc:
document["embeddings"][0]
assert "Add 'embeddings' to `include` to return this field" in str(exc.value)
assert exc.type == ValueError

# Test .query()
documents = collection.query(query_embeddings=[0, 0, 0], n_results=1)

assert "Add 'embeddings' to `include` to return this field" in str(documents)
Loading