Add an AWS (S3) target to the local benchmarks and per-timer chart markers.

Notes on this revision of the patch:
  * save_charts() iterates enumerate(self.timers), so idx is a real index.
  * get_uri() declares `global config`, so the config chosen for the URI
    type is the one cleanup_uri() and the ingest() calls actually receive,
    and LOCAL mode no longer reads an unassigned local name.
  * The TileDB Cloud config is stored as a plain dict (.dict()), matching
    what the ingest() call sites passed before this change.

diff --git a/apis/python/test/common.py b/apis/python/test/common.py
index ae703507a..49bae34c6 100644
--- a/apis/python/test/common.py
+++ b/apis/python/test/common.py
@@ -394,13 +394,17 @@ def setUpCloudToken():
     tiledb.cloud.login(token=token)
 
 
-def create_cloud_uri(name, folder_name=None):
+def create_cloud_uri(name, folder_name=None, aws_uri=False):
     namespace, storage_path, _ = groups._default_ns_path_cred()
     storage_path = storage_path.replace("//", "/").replace("/", "//", 1)
+
     if not folder_name:
         folder_name = random_name("vector_search")
-    test_path = f"tiledb://{namespace}/{storage_path}/{folder_name}"
-    return f"{test_path}/{name}"
+
+    if aws_uri:
+        return f"{storage_path}/{folder_name}/{name}"
+    else:
+        return f"tiledb://{namespace}/{storage_path}/{folder_name}/{name}"
 
 
 def delete_uri(uri, config):
diff --git a/apis/python/test/local-benchmarks.py b/apis/python/test/local-benchmarks.py
index 233886bbc..cce9b5af5 100644
--- a/apis/python/test/local-benchmarks.py
+++ b/apis/python/test/local-benchmarks.py
@@ -7,7 +7,6 @@
 
 import logging
 import os
-import shutil
 import tarfile
 import time
 import urllib.request
@@ -29,6 +28,7 @@
 class RemoteURIType(Enum):
     LOCAL = 1
     TILEDB = 2
+    AWS = 3
 
 
 ## Settings
@@ -170,7 +170,7 @@ def _summary_string(self):
             summary_str += "\n"
         return summary_str
 
-    def add_data_to_ingestion_time_vs_average_query_accuracy(self):
+    def add_data_to_ingestion_time_vs_average_query_accuracy(self, marker="o"):
         summary = self._summarize_data()
 
         for tag, data in summary.items():
@@ -183,9 +183,9 @@ def add_data_to_ingestion_time_vs_average_query_accuracy(self):
                     (data["ingestion"]["times"][i], average_accuracy)
                 )
             x, y = zip(*ingestion_times)
-            plt.scatter(y, x, marker="o", label=tag)
+            plt.scatter(y, x, marker=marker, label=tag)
 
-    def add_data_to_query_time_vs_accuracy(self):
+    def add_data_to_query_time_vs_accuracy(self, marker="o"):
         summary = self._summarize_data()
 
         for tag, data in summary.items():
@@ -195,7 +195,7 @@ def add_data_to_query_time_vs_accuracy(self):
                     (data["query"]["times"][i], data["query"]["accuracies"][i])
                 )
             x, y = zip(*query_times)
-            plt.plot(y, x, marker="o", label=tag)
+            plt.plot(y, x, marker=marker, label=tag)
 
     def save_charts(self):
         # Plot ingestion.
@@ -239,13 +239,17 @@ def new_timer(self, name):
         return timer
 
     def save_charts(self):
+        markers = ["o", "^", "D", "*", "P", "s", "2"]
+
         # Plot ingestion.
         plt.figure(figsize=(20, 12))
         plt.xlabel("Average Query Accuracy")
         plt.ylabel("Time (seconds)")
         plt.title("Ingestion Time vs Average Query Accuracy")
-        for timer in self.timers:
-            timer.add_data_to_ingestion_time_vs_average_query_accuracy()
+        for idx, timer in enumerate(self.timers):
+            timer.add_data_to_ingestion_time_vs_average_query_accuracy(
+                markers[idx % len(markers)]
+            )
         plt.legend()
         plt.savefig(os.path.join(RESULTS_DIR, "ingestion_time_vs_accuracy.png"))
         plt.close()
@@ -255,8 +259,8 @@ def save_charts(self):
         plt.xlabel("Accuracy")
         plt.ylabel("Time (seconds)")
         plt.title("Query Time vs Accuracy")
-        for timer in self.timers:
-            timer.add_data_to_query_time_vs_accuracy()
+        for idx, timer in enumerate(self.timers):
+            timer.add_data_to_query_time_vs_accuracy(markers[idx % len(markers)])
         plt.legend()
         plt.savefig(os.path.join(RESULTS_DIR, "query_time_vs_accuracy.png"))
         plt.close()
@@ -281,32 +285,45 @@ def download_and_extract(url, download_path, extract_path):
         logger.info("Finished extracting files.")
 
 
+config = {}
+
+
 def get_uri(tag):
+    global config  # shared with cleanup_uri() and the benchmark_* ingest calls
     index_name = f"index_{tag.replace('=', '_')}"
+    index_uri = ""
     if REMOTE_URI_TYPE == RemoteURIType.LOCAL:
         index_uri = os.path.join(TEMP_DIR, index_name)
-        logger.info(f"Local URI {index_uri}")
-        if os.path.exists(index_uri):
-            shutil.rmtree(index_uri)
-        return index_uri
     elif REMOTE_URI_TYPE == RemoteURIType.TILEDB:
         from common import create_cloud_uri
         from common import setUpCloudToken
 
         setUpCloudToken()
         index_uri = create_cloud_uri(index_name, "local_benchmarks")
-        logger.info(f"TileDB URI {index_uri}")
-        Index.delete_index(uri=index_uri, config=tiledb.cloud.Config())
-        return index_uri
+
+        config = tiledb.cloud.Config().dict()
+    elif REMOTE_URI_TYPE == RemoteURIType.AWS:
+        from common import create_cloud_uri
+        from common import setUpCloudToken
+
+        setUpCloudToken()
+        index_uri = create_cloud_uri(index_name, "local_benchmarks", True)
+
+        config = {
+            "vfs.s3.aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID"],
+            "vfs.s3.aws_secret_access_key": os.environ["AWS_SECRET_ACCESS_KEY"],
+            "vfs.s3.region": os.environ["AWS_REGION"],
+        }
     else:
         raise ValueError(f"Invalid REMOTE_URI_TYPE {REMOTE_URI_TYPE}")
+    logger.info(f"index_uri: {index_uri}")
+    Index.delete_index(index_uri, config)
+    return index_uri
 
 
-def cleanup_uri(index_uri):
-    if REMOTE_URI_TYPE == RemoteURIType.TILEDB:
-        from common import delete_uri
 
-        delete_uri(uri=index_uri, config=tiledb.cloud.Config())
+def cleanup_uri(index_uri):
+    Index.delete_index(index_uri, config)
 
 
 def benchmark_ivf_flat():
@@ -328,9 +345,7 @@ def benchmark_ivf_flat():
             index_type=index_type,
             index_uri=index_uri,
             source_uri=SIFT_BASE_PATH,
-            config=tiledb.cloud.Config().dict()
-            if REMOTE_URI_TYPE is not None
-            else None,
+            config=config,
             partitions=partitions,
             training_sampling_policy=TrainingSamplingPolicy.RANDOM,
         )
@@ -370,9 +385,7 @@ def benchmark_vamana():
                 index_type=index_type,
                 index_uri=index_uri,
                 source_uri=SIFT_BASE_PATH,
-                config=tiledb.cloud.Config().dict()
-                if REMOTE_URI_TYPE is not None
-                else None,
+                config=config,
                 l_build=l_build,
                 r_max_degree=r_max_degree,
                 training_sampling_policy=TrainingSamplingPolicy.RANDOM,
@@ -414,9 +427,7 @@ def benchmark_ivf_pq():
                 index_type=index_type,
                 index_uri=index_uri,
                 source_uri=SIFT_BASE_PATH,
-                config=tiledb.cloud.Config().dict()
-                if REMOTE_URI_TYPE is not None
-                else None,
+                config=config,
                 partitions=partitions,
                 training_sampling_policy=TrainingSamplingPolicy.RANDOM,
                 num_subspaces=num_subspaces,