Skip to content

Commit

Permalink
Support AWS index URI in local-benchmarks.py (#516)
Browse files Browse the repository at this point in the history
  • Loading branch information
jparismorgan authored Sep 6, 2024
1 parent 8b4a53e commit 701b9cc
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 32 deletions.
10 changes: 7 additions & 3 deletions apis/python/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,13 +394,17 @@ def setUpCloudToken():
tiledb.cloud.login(token=token)


def create_cloud_uri(name, folder_name=None):
def create_cloud_uri(name, folder_name=None, aws_uri=False):
namespace, storage_path, _ = groups._default_ns_path_cred()
storage_path = storage_path.replace("//", "/").replace("/", "//", 1)

if not folder_name:
folder_name = random_name("vector_search")
test_path = f"tiledb://{namespace}/{storage_path}/{folder_name}"
return f"{test_path}/{name}"

if aws_uri:
return f"{storage_path}/{folder_name}/{name}"
else:
return f"tiledb://{namespace}/{storage_path}/{folder_name}/{name}"


def delete_uri(uri, config):
Expand Down
68 changes: 39 additions & 29 deletions apis/python/test/local-benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import logging
import os
import shutil
import tarfile
import time
import urllib.request
Expand All @@ -29,6 +28,7 @@
class RemoteURIType(Enum):
LOCAL = 1
TILEDB = 2
AWS = 3


## Settings
Expand Down Expand Up @@ -170,7 +170,7 @@ def _summary_string(self):
summary_str += "\n"
return summary_str

def add_data_to_ingestion_time_vs_average_query_accuracy(self):
def add_data_to_ingestion_time_vs_average_query_accuracy(self, marker="o"):
summary = self._summarize_data()

for tag, data in summary.items():
Expand All @@ -183,9 +183,9 @@ def add_data_to_ingestion_time_vs_average_query_accuracy(self):
(data["ingestion"]["times"][i], average_accuracy)
)
x, y = zip(*ingestion_times)
plt.scatter(y, x, marker="o", label=tag)
plt.scatter(y, x, marker=marker, label=tag)

def add_data_to_query_time_vs_accuracy(self):
def add_data_to_query_time_vs_accuracy(self, marker="o"):
summary = self._summarize_data()

for tag, data in summary.items():
Expand All @@ -195,7 +195,7 @@ def add_data_to_query_time_vs_accuracy(self):
(data["query"]["times"][i], data["query"]["accuracies"][i])
)
x, y = zip(*query_times)
plt.plot(y, x, marker="o", label=tag)
plt.plot(y, x, marker=marker, label=tag)

def save_charts(self):
# Plot ingestion.
Expand Down Expand Up @@ -239,13 +239,17 @@ def new_timer(self, name):
return timer

def save_charts(self):
markers = ["o", "^", "D", "*", "P", "s", "2"]

# Plot ingestion.
plt.figure(figsize=(20, 12))
plt.xlabel("Average Query Accuracy")
plt.ylabel("Time (seconds)")
plt.title("Ingestion Time vs Average Query Accuracy")
for timer in self.timers:
timer.add_data_to_ingestion_time_vs_average_query_accuracy()
for idx, timer in self.timers:
timer.add_data_to_ingestion_time_vs_average_query_accuracy(
markers[idx % len(markers)]
)
plt.legend()
plt.savefig(os.path.join(RESULTS_DIR, "ingestion_time_vs_accuracy.png"))
plt.close()
Expand All @@ -255,8 +259,8 @@ def save_charts(self):
plt.xlabel("Accuracy")
plt.ylabel("Time (seconds)")
plt.title("Query Time vs Accuracy")
for timer in self.timers:
timer.add_data_to_query_time_vs_accuracy()
for idx, timer in self.timers:
timer.add_data_to_query_time_vs_accuracy(markers[idx % len(markers)])
plt.legend()
plt.savefig(os.path.join(RESULTS_DIR, "query_time_vs_accuracy.png"))
plt.close()
Expand All @@ -281,32 +285,44 @@ def download_and_extract(url, download_path, extract_path):
logger.info("Finished extracting files.")


config = {}


def get_uri(tag):
index_name = f"index_{tag.replace('=', '_')}"
index_uri = ""
if REMOTE_URI_TYPE == RemoteURIType.LOCAL:
index_uri = os.path.join(TEMP_DIR, index_name)
logger.info(f"Local URI {index_uri}")
if os.path.exists(index_uri):
shutil.rmtree(index_uri)
return index_uri
elif REMOTE_URI_TYPE == RemoteURIType.TILEDB:
from common import create_cloud_uri
from common import setUpCloudToken

setUpCloudToken()
index_uri = create_cloud_uri(index_name, "local_benchmarks")
logger.info(f"TileDB URI {index_uri}")
Index.delete_index(uri=index_uri, config=tiledb.cloud.Config())
return index_uri

config = tiledb.cloud.Config()
elif REMOTE_URI_TYPE == RemoteURIType.AWS:
from common import create_cloud_uri
from common import setUpCloudToken

setUpCloudToken()
index_uri = create_cloud_uri(index_name, "local_benchmarks", True)

config = {
"vfs.s3.aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID"],
"vfs.s3.aws_secret_access_key": os.environ["AWS_SECRET_ACCESS_KEY"],
"vfs.s3.region": os.environ["AWS_REGION"],
}
else:
raise ValueError(f"Invalid REMOTE_URI_TYPE {REMOTE_URI_TYPE}")

logger.info(f"index_uri: {index_uri}")
Index.delete_index(index_uri, config)
return index_uri

def cleanup_uri(index_uri):
if REMOTE_URI_TYPE == RemoteURIType.TILEDB:
from common import delete_uri

delete_uri(uri=index_uri, config=tiledb.cloud.Config())
def cleanup_uri(index_uri):
Index.delete_index(index_uri, config)


def benchmark_ivf_flat():
Expand All @@ -328,9 +344,7 @@ def benchmark_ivf_flat():
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
config=tiledb.cloud.Config().dict()
if REMOTE_URI_TYPE is not None
else None,
config=config,
partitions=partitions,
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
)
Expand Down Expand Up @@ -370,9 +384,7 @@ def benchmark_vamana():
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
config=tiledb.cloud.Config().dict()
if REMOTE_URI_TYPE is not None
else None,
config=config,
l_build=l_build,
r_max_degree=r_max_degree,
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
Expand Down Expand Up @@ -414,9 +426,7 @@ def benchmark_ivf_pq():
index_type=index_type,
index_uri=index_uri,
source_uri=SIFT_BASE_PATH,
config=tiledb.cloud.Config().dict()
if REMOTE_URI_TYPE is not None
else None,
config=config,
partitions=partitions,
training_sampling_policy=TrainingSamplingPolicy.RANDOM,
num_subspaces=num_subspaces,
Expand Down

0 comments on commit 701b9cc

Please sign in to comment.