Skip to content

Commit

Permalink
Add query with driver mode for ObjectIndex (#541)
Browse files Browse the repository at this point in the history
Add query with driver mode for `ObjectIndex`.

This add options to open an `ObjectIndex` for remote query execution, as well as provide resources and driver mode per query. This allows to both compute object embeddings and vector queries in the server side.
  • Loading branch information
NikolaosPapailiou authored Oct 8, 2024
1 parent a4a082f commit 491e09d
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 19 deletions.
28 changes: 20 additions & 8 deletions apis/python/src/tiledb/vector_search/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,16 @@ def _query_with_driver(
queries: np.ndarray,
k: int,
driver_mode=None,
driver_resources=None,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
driver_access_credentials_name=None,
**kwargs,
):
from tiledb.cloud import dag

if driver_resource_class and driver_resources:
raise TypeError("Cannot provide both resource_class and resources")

def query_udf(index_type, index_open_kwargs, query_kwargs):
from tiledb.vector_search.flat_index import FlatIndex
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
Expand Down Expand Up @@ -226,6 +230,9 @@ def query_udf(index_type, index_open_kwargs, query_kwargs):
self.index_open_kwargs,
query_kwargs,
name="vector-query-driver",
resource_class="large"
if (not driver_resources and not driver_resource_class)
else driver_resource_class,
resources=driver_resources,
image_name="vectorsearch",
access_credentials_name=driver_access_credentials_name,
Expand All @@ -239,7 +246,8 @@ def query(
queries: np.ndarray,
k: int,
driver_mode: Optional[Mode] = None,
driver_resources: Optional[str] = None,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
driver_access_credentials_name: Optional[str] = None,
**kwargs,
):
Expand All @@ -265,8 +273,11 @@ def query(
Number of results to return per query vector.
driver_mode: Mode
If not `None`, the query will be executed in a TileDB cloud taskgraph using the driver mode specified.
driver_resource_class: Optional[str]
If `driver_mode` was `REALTIME`, the resources class (`standard` or `large`) to use for the driver execution.
driver_resources: Optional[str]
If `driver_mode` was not `None`, the resources to use for the driver execution.
If `driver_mode` was `BATCH`, the resources to use for the driver execution.
Example `{"cpu": "1", "memory": "4Gi"}`
driver_access_credentials_name: Optional[str]
If `driver_mode` was not `None`, the access credentials name to use for the driver execution.
**kwargs
Expand Down Expand Up @@ -307,11 +318,12 @@ def flip_results(results):
"Cannot pass driver_mode=Mode.LOCAL to query() - use driver_mode=None to query locally."
)
results, indexes = self._query_with_driver(
queries,
k,
driver_mode,
driver_resources,
driver_access_credentials_name,
queries=queries,
k=k,
driver_mode=driver_mode,
driver_resources=driver_resources,
driver_resource_class=driver_resource_class,
driver_access_credentials_name=driver_access_credentials_name,
**kwargs,
)
if self.distance_metric == vspy.DistanceMetric.INNER_PRODUCT:
Expand Down
10 changes: 2 additions & 8 deletions apis/python/src/tiledb/vector_search/ivf_flat_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,19 +325,13 @@ def query(
self,
queries: np.ndarray,
k: int,
driver_mode: Mode = None,
driver_resources: Optional[str] = None,
driver_access_credentials_name: Optional[str] = None,
**kwargs,
):
if self.distance_metric == vspy.DistanceMetric.COSINE:
queries = normalize_vectors(queries)
return super().query(
queries,
k,
driver_mode,
driver_resources,
driver_access_credentials_name,
queries=queries,
k=k,
**kwargs,
)

Expand Down
163 changes: 163 additions & 0 deletions apis/python/src/tiledb/vector_search/object_api/object_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class ObjectIndex:
Timestamp to open the index at.
load_embedding: bool
Whether to load the embedding function into memory.
open_for_remote_query_execution: bool
If `True`, do not load the embedding model and any index data locally, and instead perform all query functionality in a TileDB Cloud taskgraph.
open_vector_index_for_remote_query_execution: bool
If `True`, do not load any index data in main memory locally, and instead load index data and perform vector queries in a TileDB Cloud taskgraph.
Compared to `open_for_remote_query_execution`, this loads the object embedding function and computes query object embeddings locally.
load_metadata_in_memory: bool
Whether to load the metadata array into memory.
environment_variables: Dict
Expand All @@ -67,6 +72,8 @@ def __init__(
uri: str,
config: Optional[Mapping[str, Any]] = None,
timestamp=None,
open_for_remote_query_execution: bool = False,
open_vector_index_for_remote_query_execution: bool = False,
load_embedding: bool = True,
load_metadata_in_memory: bool = True,
environment_variables: Dict = {},
Expand All @@ -80,7 +87,29 @@ def __init__(
self.uri = uri
self.config = config
self.timestamp = timestamp
self.open_for_remote_query_execution = open_for_remote_query_execution
self.open_vector_index_for_remote_query_execution = (
open_vector_index_for_remote_query_execution
)
if self.open_vector_index_for_remote_query_execution:
kwargs["open_for_remote_query_execution"] = True
self.load_embedding = load_embedding
self.load_metadata_in_memory = load_metadata_in_memory
self.environment_variables = environment_variables
self.kwargs = kwargs
self.index_open_kwargs = {
"uri": uri,
"config": config,
"timestamp": timestamp,
"load_embedding": load_embedding,
"load_metadata_in_memory": load_metadata_in_memory,
"environment_variables": environment_variables,
"kwargs": kwargs,
}

if self.open_for_remote_query_execution:
return

group = tiledb.Group(uri, "r")
self.index_type = group.meta["index_type"]
group.close()
Expand Down Expand Up @@ -156,6 +185,95 @@ def __init__(
if self.load_metadata_in_memory:
self.metadata_df = self.metadata_array.df[:]

def _query_with_driver(
self,
query_objects: np.ndarray,
k: int,
query_metadata: Optional[OrderedDict] = None,
metadata_array_cond: Optional[str] = None,
metadata_df_filter_fn: Optional[str] = None,
return_objects: bool = True,
return_metadata: bool = True,
driver_mode: Optional[Mode] = Mode.REALTIME,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
driver_access_credentials_name: Optional[str] = None,
extra_driver_modules: Optional[List[str]] = None,
object_index_source_code: Optional[str] = None,
**kwargs,
):
from tiledb.cloud import dag

if driver_resource_class and driver_resources:
raise TypeError("Cannot provide both resource_class and resources")

def query_udf(
index_open_kwargs,
query_kwargs,
extra_driver_modules: Optional[List[str]] = None,
object_index_source_code: Optional[str] = None,
):
def install_extra_driver_modules():
if extra_driver_modules is not None:
import os
import subprocess
import sys

sys.path.insert(0, "/home/udf/.local/bin")
sys.path.insert(0, "/home/udf/.local/lib/python3.9/site-packages")
os.environ["PATH"] = f"/home/udf/.local/bin:{os.environ['PATH']}"
for module in extra_driver_modules:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", module]
)

install_extra_driver_modules()
from tiledb.vector_search.object_api import object_index

if object_index_source_code is not None:
index = object_index.instantiate_object(
code=object_index_source_code,
class_name="ObjectIndex",
**index_open_kwargs,
)
else:
index = object_index.ObjectIndex(**index_open_kwargs)
# Query index
return index.query(**query_kwargs)

d = dag.DAG(
name="vector-query",
mode=driver_mode,
max_workers=1,
)
query_kwargs = {
"query_objects": query_objects,
"k": k,
"query_metadata": query_metadata,
"metadata_array_cond": metadata_array_cond,
"metadata_df_filter_fn": metadata_df_filter_fn,
"return_objects": return_objects,
"return_metadata": return_metadata,
}
query_kwargs.update(kwargs)
node = d.submit(
query_udf,
self.index_open_kwargs,
query_kwargs,
extra_driver_modules=extra_driver_modules,
object_index_source_code=object_index_source_code,
name="vector-query-driver",
resource_class="large"
if (not driver_resources and not driver_resource_class)
else driver_resource_class,
resources=driver_resources,
image_name="vectorsearch",
access_credentials_name=driver_access_credentials_name,
)
d.compute()
d.wait()
return node.result()

def query(
self,
query_objects: np.ndarray,
Expand All @@ -165,6 +283,11 @@ def query(
metadata_df_filter_fn: Optional[str] = None,
return_objects: bool = True,
return_metadata: bool = True,
driver_mode: Optional[Mode] = Mode.REALTIME,
driver_resource_class: Optional[str] = None,
driver_resources: Optional[Mapping[str, Any]] = None,
extra_driver_modules: Optional[List[str]] = None,
driver_access_credentials_name: Optional[str] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -226,6 +349,17 @@ def query(
Whether to return the objects themselves, or just the object IDs.
return_metadata: bool
Whether to return the metadata for the objects.
driver_mode: Mode
If not `None`, the query will be executed in a TileDB cloud taskgraph using the driver mode specified.
driver_resource_class: Optional[str]
If `driver_mode` was `REALTIME`, the resources class (`standard` or `large`) to use for the driver execution.
driver_resources: Optional[str]
If `driver_mode` was `BATCH`, the resources to use for the driver execution.
Example `{"cpu": "1", "memory": "4Gi"}`
extra_driver_modules: List[str], optional
A list of extra Python modules to install on the driver node.
driver_access_credentials_name: Optional[str]
If `driver_mode` was not `None`, the access credentials name to use for the driver execution.
**kwargs
Keyword arguments to pass to the index query method.
Expand All @@ -240,6 +374,35 @@ def query(
A tuple containing the distances, objects or object IDs, and optionally
the object metadata.
"""
if self.open_for_remote_query_execution:
if driver_mode is None or driver_mode == Mode.LOCAL:
raise TypeError(
f"Cannot pass driver_mode={driver_mode} to query() when using `open_for_remote_query_execution`."
)

object_index_source_code = get_source_code(self)
return self._query_with_driver(
query_objects=query_objects,
k=k,
query_metadata=query_metadata,
metadata_array_cond=metadata_array_cond,
metadata_df_filter_fn=metadata_df_filter_fn,
return_objects=return_objects,
return_metadata=return_metadata,
driver_mode=driver_mode,
driver_resource_class=driver_resource_class,
driver_resources=driver_resources,
extra_driver_modules=extra_driver_modules,
driver_access_credentials_name=driver_access_credentials_name,
object_index_source_code=object_index_source_code,
**kwargs,
)
elif self.open_vector_index_for_remote_query_execution:
kwargs["driver_mode"] = driver_mode
kwargs["driver_resource_class"] = driver_resource_class
kwargs["driver_resources"] = driver_resources
kwargs["driver_access_credentials_name"] = driver_access_credentials_name

if (
metadata_array_cond is not None or metadata_df_filter_fn is not None
) and self.object_metadata_array_uri is None:
Expand Down
34 changes: 31 additions & 3 deletions apis/python/test/test_object_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,21 @@ def assert_equal(


def evaluate_query(
index_type: str, index_uri, query_kwargs, dim_id, vector_dim_offset, config=None
index_type: str,
index_uri,
query_kwargs,
dim_id,
vector_dim_offset,
config=None,
open_for_remote_query_execution=False,
):
v_id = dim_id - vector_dim_offset

index = object_index.ObjectIndex(uri=index_uri, config=config)
index = object_index.ObjectIndex(
uri=index_uri,
open_for_remote_query_execution=open_for_remote_query_execution,
config=config,
)
distances, objects, metadata = index.query(
{"object": np.array([[dim_id, dim_id, dim_id, dim_id]])}, k=21, **query_kwargs
)
Expand Down Expand Up @@ -414,6 +424,19 @@ def test_object_index_ivf_flat_cloud(tmp_path):
vector_dim_offset=0,
config=config,
)
evaluate_query(
index_type="IVF_FLAT",
index_uri=index_uri,
query_kwargs={
"nprobe": 10,
"driver_mode": Mode.REALTIME,
"driver_resource_class": "standard",
},
dim_id=42,
vector_dim_offset=0,
config=config,
open_for_remote_query_execution=True,
)

# Add new data with a new reader
reader = TestReader(
Expand All @@ -440,10 +463,15 @@ def test_object_index_ivf_flat_cloud(tmp_path):
evaluate_query(
index_type="IVF_FLAT",
index_uri=index_uri,
query_kwargs={"nprobe": 10},
query_kwargs={
"nprobe": 10,
"driver_mode": Mode.REALTIME,
"driver_resource_class": "standard",
},
dim_id=1042,
vector_dim_offset=0,
config=config,
open_for_remote_query_execution=True,
)
delete_uri(index_uri, config)

Expand Down

0 comments on commit 491e09d

Please sign in to comment.