Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add query with driver mode for ObjectIndex #541

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,16 @@ def _query_with_driver(
queries: np.ndarray,
k: int,
driver_mode=None,
driver_resources=None,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
driver_access_credentials_name=None,
**kwargs,
):
from tiledb.cloud import dag

if driver_resource_class and driver_resources:
raise TypeError("Cannot provide both resource_class and resources")

def query_udf(index_type, index_open_kwargs, query_kwargs):
from tiledb.vector_search.flat_index import FlatIndex
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
Expand Down Expand Up @@ -226,6 +230,9 @@ def query_udf(index_type, index_open_kwargs, query_kwargs):
self.index_open_kwargs,
query_kwargs,
name="vector-query-driver",
resource_class="large"
if (not driver_resources and not driver_resource_class)
else driver_resource_class,
resources=driver_resources,
image_name="vectorsearch",
access_credentials_name=driver_access_credentials_name,
Expand All @@ -239,7 +246,8 @@ def query(
queries: np.ndarray,
k: int,
driver_mode: Optional[Mode] = None,
driver_resources: Optional[str] = None,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
driver_access_credentials_name: Optional[str] = None,
**kwargs,
):
Expand All @@ -265,8 +273,11 @@ def query(
Number of results to return per query vector.
driver_mode: Mode
If not `None`, the query will be executed in a TileDB cloud taskgraph using the driver mode specified.
driver_resource_class: Optional[str]
If `driver_mode` was `REALTIME`, the resources class (`standard` or `large`) to use for the driver execution.
driver_resources: Optional[str]
If `driver_mode` was not `None`, the resources to use for the driver execution.
If `driver_mode` was `BATCH`, the resources to use for the driver execution.
Example `{"cpu": "1", "memory": "4Gi"}`
driver_access_credentials_name: Optional[str]
If `driver_mode` was not `None`, the access credentials name to use for the driver execution.
**kwargs
Expand Down Expand Up @@ -307,11 +318,12 @@ def flip_results(results):
"Cannot pass driver_mode=Mode.LOCAL to query() - use driver_mode=None to query locally."
)
results, indexes = self._query_with_driver(
queries,
k,
driver_mode,
driver_resources,
driver_access_credentials_name,
queries=queries,
k=k,
driver_mode=driver_mode,
driver_resources=driver_resources,
driver_resource_class=driver_resource_class,
driver_access_credentials_name=driver_access_credentials_name,
**kwargs,
)
if self.distance_metric == vspy.DistanceMetric.INNER_PRODUCT:
Expand Down
10 changes: 2 additions & 8 deletions apis/python/src/tiledb/vector_search/ivf_flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,19 +325,13 @@ def query(
self,
queries: np.ndarray,
k: int,
driver_mode: Mode = None,
driver_resources: Optional[str] = None,
driver_access_credentials_name: Optional[str] = None,
**kwargs,
):
if self.distance_metric == vspy.DistanceMetric.COSINE:
queries = normalize_vectors(queries)
return super().query(
queries,
k,
driver_mode,
driver_resources,
driver_access_credentials_name,
queries=queries,
k=k,
**kwargs,
)

Expand Down
163 changes: 163 additions & 0 deletions apis/python/src/tiledb/vector_search/object_api/object_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class ObjectIndex:
Timestamp to open the index at.
load_embedding: bool
Whether to load the embedding function into memory.
open_for_remote_query_execution: bool
If `True`, do not load the embedding model and any index data locally, and instead perform all query functionality in a TileDB Cloud taskgraph.
open_vector_index_for_remote_query_execution: bool
If `True`, do not load any index data in main memory locally, and instead load index data and perform vector queries in a TileDB Cloud taskgraph.
Compared to `open_for_remote_query_execution`, this loads the object embedding function and computes query object embeddings locally.
load_metadata_in_memory: bool
Whether to load the metadata array into memory.
environment_variables: Dict
Expand All @@ -67,6 +72,8 @@ def __init__(
uri: str,
config: Optional[Mapping[str, Any]] = None,
timestamp=None,
open_for_remote_query_execution: bool = False,
open_vector_index_for_remote_query_execution: bool = False,
load_embedding: bool = True,
load_metadata_in_memory: bool = True,
environment_variables: Dict = {},
Expand All @@ -80,7 +87,29 @@ def __init__(
self.uri = uri
self.config = config
self.timestamp = timestamp
self.open_for_remote_query_execution = open_for_remote_query_execution
self.open_vector_index_for_remote_query_execution = (
open_vector_index_for_remote_query_execution
)
if self.open_vector_index_for_remote_query_execution:
kwargs["open_for_remote_query_execution"] = True
self.load_embedding = load_embedding
self.load_metadata_in_memory = load_metadata_in_memory
self.environment_variables = environment_variables
self.kwargs = kwargs
self.index_open_kwargs = {
"uri": uri,
"config": config,
"timestamp": timestamp,
"load_embedding": load_embedding,
"load_metadata_in_memory": load_metadata_in_memory,
"environment_variables": environment_variables,
"kwargs": kwargs,
}

if self.open_for_remote_query_execution:
return

group = tiledb.Group(uri, "r")
self.index_type = group.meta["index_type"]
group.close()
Expand Down Expand Up @@ -156,6 +185,95 @@ def __init__(
if self.load_metadata_in_memory:
self.metadata_df = self.metadata_array.df[:]

def _query_with_driver(
self,
query_objects: np.ndarray,
k: int,
query_metadata: Optional[OrderedDict] = None,
metadata_array_cond: Optional[str] = None,
metadata_df_filter_fn: Optional[str] = None,
return_objects: bool = True,
return_metadata: bool = True,
driver_mode: Optional[Mode] = Mode.REALTIME,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
driver_access_credentials_name: Optional[str] = None,
extra_driver_modules: Optional[List[str]] = None,
object_index_source_code: Optional[str] = None,
**kwargs,
):
from tiledb.cloud import dag

if driver_resource_class and driver_resources:
raise TypeError("Cannot provide both resource_class and resources")

def query_udf(
index_open_kwargs,
query_kwargs,
extra_driver_modules: Optional[List[str]] = None,
object_index_source_code: Optional[str] = None,
):
def install_extra_driver_modules():
if extra_driver_modules is not None:
import os
import subprocess
import sys

sys.path.insert(0, "/home/udf/.local/bin")
sys.path.insert(0, "/home/udf/.local/lib/python3.9/site-packages")
os.environ["PATH"] = f"/home/udf/.local/bin:{os.environ['PATH']}"
for module in extra_driver_modules:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", module]
)

install_extra_driver_modules()
from tiledb.vector_search.object_api import object_index

if object_index_source_code is not None:
index = object_index.instantiate_object(
code=object_index_source_code,
class_name="ObjectIndex",
**index_open_kwargs,
)
else:
index = object_index.ObjectIndex(**index_open_kwargs)
# Query index
return index.query(**query_kwargs)

d = dag.DAG(
name="vector-query",
mode=driver_mode,
max_workers=1,
)
query_kwargs = {
"query_objects": query_objects,
"k": k,
"query_metadata": query_metadata,
"metadata_array_cond": metadata_array_cond,
"metadata_df_filter_fn": metadata_df_filter_fn,
"return_objects": return_objects,
"return_metadata": return_metadata,
}
query_kwargs.update(kwargs)
node = d.submit(
query_udf,
self.index_open_kwargs,
query_kwargs,
extra_driver_modules=extra_driver_modules,
object_index_source_code=object_index_source_code,
name="vector-query-driver",
resource_class="large"
if (not driver_resources and not driver_resource_class)
else driver_resource_class,
resources=driver_resources,
image_name="vectorsearch",
access_credentials_name=driver_access_credentials_name,
)
d.compute()
d.wait()
return node.result()

def query(
self,
query_objects: np.ndarray,
Expand All @@ -165,6 +283,11 @@ def query(
metadata_df_filter_fn: Optional[str] = None,
return_objects: bool = True,
return_metadata: bool = True,
driver_mode: Optional[Mode] = Mode.REALTIME,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
extra_driver_modules: Optional[List[str]] = None,
driver_access_credentials_name: Optional[str] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -226,6 +349,17 @@ def query(
Whether to return the objects themselves, or just the object IDs.
return_metadata: bool
Whether to return the metadata for the objects.
driver_mode: Mode
If not `None`, the query will be executed in a TileDB cloud taskgraph using the driver mode specified.
driver_resource_class: Optional[str]
If `driver_mode` was `REALTIME`, the resources class (`standard` or `large`) to use for the driver execution.
driver_resources: Optional[str]
If `driver_mode` was `BATCH`, the resources to use for the driver execution.
Example `{"cpu": "1", "memory": "4Gi"}`
extra_driver_modules: List[str], optional
A list of extra Python modules to install on the driver node.
driver_access_credentials_name: Optional[str]
If `driver_mode` was not `None`, the access credentials name to use for the driver execution.
**kwargs
Keyword arguments to pass to the index query method.

Expand All @@ -240,6 +374,35 @@ def query(
A tuple containing the distances, objects or object IDs, and optionally
the object metadata.
"""
if self.open_for_remote_query_execution:
if driver_mode is None or driver_mode == Mode.LOCAL:
raise TypeError(
f"Cannot pass driver_mode={driver_mode} to query() when using `open_for_remote_query_execution`."
)

object_index_source_code = get_source_code(self)
return self._query_with_driver(
query_objects=query_objects,
k=k,
query_metadata=query_metadata,
metadata_array_cond=metadata_array_cond,
metadata_df_filter_fn=metadata_df_filter_fn,
return_objects=return_objects,
return_metadata=return_metadata,
driver_mode=driver_mode,
driver_resource_class=driver_resource_class,
driver_resources=driver_resources,
extra_driver_modules=extra_driver_modules,
driver_access_credentials_name=driver_access_credentials_name,
object_index_source_code=object_index_source_code,
**kwargs,
)
elif self.open_vector_index_for_remote_query_execution:
kwargs["driver_mode"] = driver_mode
kwargs["driver_resource_class"] = driver_resource_class
kwargs["driver_resources"] = driver_resources
kwargs["driver_access_credentials_name"] = driver_access_credentials_name

if (
metadata_array_cond is not None or metadata_df_filter_fn is not None
) and self.object_metadata_array_uri is None:
Expand Down
34 changes: 31 additions & 3 deletions apis/python/test/test_object_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,21 @@ def assert_equal(


def evaluate_query(
index_type: str, index_uri, query_kwargs, dim_id, vector_dim_offset, config=None
index_type: str,
index_uri,
query_kwargs,
dim_id,
vector_dim_offset,
config=None,
open_for_remote_query_execution=False,
):
v_id = dim_id - vector_dim_offset

index = object_index.ObjectIndex(uri=index_uri, config=config)
index = object_index.ObjectIndex(
uri=index_uri,
open_for_remote_query_execution=open_for_remote_query_execution,
config=config,
)
distances, objects, metadata = index.query(
{"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=21, **query_kwargs
)
Expand Down Expand Up @@ -414,6 +424,19 @@ def test_object_index_ivf_flat_cloud(tmp_path):
vector_dim_offset=0,
config=config,
)
evaluate_query(
index_type="IVF_FLAT",
index_uri=index_uri,
query_kwargs={
"nprobe": 10,
"driver_mode": Mode.REALTIME,
"driver_resource_class": "standard",
},
dim_id=42,
vector_dim_offset=0,
config=config,
open_for_remote_query_execution=True,
)

# Add new data with a new reader
reader = TestReader(
Expand All @@ -440,10 +463,15 @@ def test_object_index_ivf_flat_cloud(tmp_path):
evaluate_query(
index_type="IVF_FLAT",
index_uri=index_uri,
query_kwargs={"nprobe": 10},
query_kwargs={
"nprobe": 10,
"driver_mode": Mode.REALTIME,
"driver_resource_class": "standard",
},
dim_id=1042,
vector_dim_offset=0,
config=config,
open_for_remote_query_execution=True,
)
delete_uri(index_uri, config)

Expand Down
Loading