Skip to content

Commit

Permalink
Clarify non-included fields in .get() and .query() responses
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Apr 19, 2024
1 parent 729e657 commit 78b0e0c
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 22 deletions.
24 changes: 16 additions & 8 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union
from typing import TYPE_CHECKING, Optional, Tuple, Any, Union, cast

import numpy as np
from pydantic import BaseModel, PrivateAttr
Expand All @@ -8,6 +8,7 @@

from chromadb.api.types import (
URI,
Omitted,
CollectionMetadata,
DataLoader,
Embedding,
Expand Down Expand Up @@ -37,6 +38,7 @@
maybe_cast_one_to_many_image,
maybe_cast_one_to_many_uri,
validate_ids,
valid_include_keys,
validate_include,
validate_metadata,
validate_metadatas,
Expand Down Expand Up @@ -226,9 +228,12 @@ def get(
):
get_results["data"] = self._data_loader(get_results["uris"])

# Remove URIs from the result if they weren't requested
if "uris" not in include:
get_results["uris"] = None
for key in valid_include_keys(allow_distances=False):
if key not in include:
# Casting to Any because key is of type Include, but GetResult does not have a distances key
get_results[cast(Any, key)] = Omitted(
f"Add '{key}' to `include` to return this field."
)

return get_results

Expand Down Expand Up @@ -360,9 +365,11 @@ def query(
self._data_loader(uris) for uris in query_results["uris"]
]

# Remove URIs from the result if they weren't requested
if "uris" not in include:
query_results["uris"] = None
for key in valid_include_keys(allow_distances=True):
if key not in include:
query_results[key] = Omitted(
f"Add '{key}' to `include` to return this field."
)

return query_results

Expand All @@ -382,7 +389,8 @@ def modify(
validate_metadata(metadata)
if "hnsw:space" in metadata:
raise ValueError(
"Changing the distance function of a collection once it is created is not supported currently.")
"Changing the distance function of a collection once it is created is not supported currently."
)

self._client._modify(id=self.id, new_name=name, new_metadata=metadata)
if name:
Expand Down
58 changes: 44 additions & 14 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@
T = TypeVar("T")
OneOrMany = Union[T, List[T]]


class Omitted:
    """Placeholder for a response field that was left out of the result.

    Accessing any regular attribute or indexing into the placeholder raises a
    ValueError carrying ``message``, so callers find out why the data is
    missing instead of silently working with an empty value. ``str()`` returns
    the message itself and ``repr()`` wraps it as ``Omitted("...")``.

    Dunder lookups that Python performs implicitly (``hasattr`` probes, the
    ``copy`` and ``pickle`` protocols, ...) raise AttributeError instead, as
    those protocols require, so a result containing placeholders can still be
    copied or pickled.
    """

    def __init__(self, message: str) -> None:
        self.message = message

    def __getattr__(self, name: str):
        # __getattr__ only runs after normal lookup has failed.
        # - Dunder names are protocol probes (e.g. copy looks for __deepcopy__
        #   and __setstate__); they must see AttributeError, not ValueError.
        # - "message" can only be missing on a half-built instance (created
        #   by copy/pickle without __init__); reading self.message below
        #   would otherwise re-enter __getattr__ and recurse without bound.
        if name == "message" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        raise ValueError(self.message)

    def __getitem__(self, _: object):
        raise ValueError(self.message)

    def __str__(self) -> str:
        return self.message

    def __repr__(self) -> str:
        return f'Omitted("{self.message}")'


E = TypeVar("E")
Omittable = Union[E, Omitted]

# URIs
URI = str
URIs = List[URI]
Expand Down Expand Up @@ -152,21 +175,21 @@ def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images:

class GetResult(TypedDict):
ids: List[ID]
embeddings: Optional[List[Embedding]]
documents: Optional[List[Document]]
uris: Optional[URIs]
data: Optional[Loadable]
metadatas: Optional[List[Metadata]]
embeddings: Omittable[List[Embedding]]
documents: Omittable[List[Document]]
uris: Omittable[URIs]
data: Omittable[Loadable]
metadatas: Omittable[List[Metadata]]


class QueryResult(TypedDict):
ids: List[IDs]
embeddings: Optional[List[List[Embedding]]]
documents: Optional[List[List[Document]]]
uris: Optional[List[List[URI]]]
data: Optional[List[Loadable]]
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
embeddings: Omittable[List[List[Embedding]]]
documents: Omittable[List[List[Document]]]
uris: Omittable[List[List[URI]]]
data: Omittable[List[Loadable]]
metadatas: Omittable[List[List[Metadata]]]
distances: Omittable[List[List[float]]]


class IndexMetadata(TypedDict):
Expand Down Expand Up @@ -441,6 +464,15 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument:
return where_document


def valid_include_keys(allow_distances: bool) -> Include:
    """Return the keys accepted by the ``include`` parameter.

    The five base keys are always present; ``"distances"`` is appended as the
    last entry only when ``allow_distances`` is truthy. A new list is built
    on every call, so callers may mutate the result freely.
    """
    base_keys: Include = ["embeddings", "documents", "metadatas", "uris", "data"]
    return (base_keys + ["distances"]) if allow_distances else base_keys


def validate_include(include: Include, allow_distances: bool) -> Include:
"""Validates include to ensure it is a list of strings. Since get does not allow distances, allow_distances is used
to control if distances is allowed"""
Expand All @@ -450,9 +482,7 @@ def validate_include(include: Include, allow_distances: bool) -> Include:
for item in include:
if not isinstance(item, str):
raise ValueError(f"Expected include item to be a str, got {item}")
allowed_values = ["embeddings", "documents", "metadatas", "uris", "data"]
if allow_distances:
allowed_values.append("distances")
allowed_values = valid_include_keys(allow_distances=allow_distances)
if item not in allowed_values:
raise ValueError(
f"Expected include item to be one of {', '.join(allowed_values)}, got {item}"
Expand Down
25 changes: 25 additions & 0 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,3 +1499,28 @@ def test_ssl_self_signed_with_verify_false(client_ssl):
client.heartbeat()
client_ssl.clear_system_cache()
assert "Unverified HTTPS request" in str(record[0].message)


def test_omitted_field_messages(api):
    """Fields not listed in `include` must explain how to request them.

    Checks both .get() and .query(): the hint has to show up when the whole
    result is printed, when the omitted field itself is printed, and as the
    ValueError raised when the omitted field is indexed.
    """
    expected = "Add 'embeddings' to `include` to return this field"

    api.reset()
    collection = api.create_collection("testspace")
    collection.add(**records)

    # Test .get()
    document = collection.get(ids="id1")

    assert expected in str(document)
    assert expected in str(document["embeddings"])

    with pytest.raises(ValueError) as exc:
        document["embeddings"][0]
    # These checks must sit after the `with` block: statements placed inside
    # it after the raising line would never execute.
    assert expected in str(exc.value)
    assert exc.type is ValueError

    # Test .query()
    documents = collection.query(query_embeddings=[0, 0, 0], n_results=1)

    assert expected in str(documents)

0 comments on commit 78b0e0c

Please sign in to comment.