Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added back multiprocessing to benchmark scripts #1246

Merged
merged 1 commit into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions docs/references/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,34 +66,35 @@ docker run -d --name infinity -v $HOME/infinity:/var/infinity --ulimit nofile=50

4. Run Benchmark:

Drop file cache before benchmark query latency.
Drop file cache before benchmark.

```bash
echo 3 | sudo tee /proc/sys/vm/drop_caches
```

Tasks of the Python script `run.py` include:
- Delete the original data.
- Re-insert the data.
- Calculate the time to insert data and build index.
- Calculate query latency.
- Calculate QPS.
- Generate fulltext query set.
- Measure the time to import data and build index.
- Measure the query latency.
- Measure the QPS.

```bash
$ python run.py -h
usage: run.py [-h] [--generate] [--import] [--query] [--query-express QUERY_EXPRESS] [--engine ENGINE] [--dataset DATASET]
usage: run.py [-h] [--generate] [--import] [--query QUERY] [--query-express QUERY_EXPRESS] [--concurrency CONCURRENCY] [--engine ENGINE] [--dataset DATASET]

RAG Database Benchmark

options:
-h, --help show this help message and exit
--generate Generate fulltext queries based on the dataset
--import Import data set into database engine
--query Run single client to benchmark query latency
--query-express QUERY_EXPRESS
Run multiple clients in express mode to benchmark QPS
--engine ENGINE database engine to benchmark, one of: all, infinity, qdrant, elasticsearch
--dataset DATASET data set to benchmark, one of: all, gist, sift, geonames, enwiki
-h, --help show this help message and exit
--generate Generate fulltext query set based on the dataset (default: False)
--import Import dataset into database engine (default: False)
--query QUERY Run the query set only once using given number of clients with recording the result and latency. This is for result validation and latency analysis (default: 0)
--query-express QUERY_EXPRESS
Run the query set randomly using given number of clients without recording the result and latency. This is for QPS measurement. (default: 0)
--concurrency CONCURRENCY
Choose concurrency mechanism, one of: mp - multiprocessing(recommended), mt - multithreading. (default: mp)
--engine ENGINE Choose database engine to benchmark, one of: infinity, qdrant, elasticsearch (default: infinity)
--dataset DATASET Choose dataset to benchmark, one of: gist, sift, geonames, enwiki (default: enwiki)
```

Following are commands for engine `infinity` and dataset `enwiki`:
Expand Down
202 changes: 167 additions & 35 deletions python/benchmark/clients/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import h5py
import numpy as np
import threading
import multiprocessing


class BaseClient:
Expand All @@ -25,11 +26,19 @@ def __init__(self, conf_path: str) -> None:
self.data = None
self.queries = list()
self.clients = list()
self.lock = threading.Lock()
self.next_begin = 0
self.results = []
self.done_queries = 0
self.active_threads = 0
# Following are for multithreading
self.mt_lock = threading.Lock()
self.mt_next_begin = 0
self.mt_done_queries = 0
self.mt_active_workers = 0
self.mt_results = []
# Following are for multiprocessing
self.mp_manager = multiprocessing.Manager()
self.mp_lock = multiprocessing.Lock()
self.mp_next_begin = multiprocessing.Value("i", 0, lock=False)
self.mp_done_queries = multiprocessing.Value("i", 0, lock=False)
self.mp_active_workers = multiprocessing.Value("i", 0, lock=False)
self.mp_results = self.mp_manager.list()

@abstractmethod
def upload(self):
Expand All @@ -39,7 +48,7 @@ def upload(self):
pass

@abstractmethod
def setup_clients(self, num_threads=1):
def setup_clients(self, num_workers=1):
pass

@abstractmethod
Expand All @@ -60,8 +69,8 @@ def download_data(self, url, target_path):
else:
subprocess.run(["wget", "-O", target_path, url], check=True)

def search(self, is_express=False, num_threads=1):
self.setup_clients(num_threads)
def search_mt(self, is_express=False, num_workers=1):
self.setup_clients(num_workers)

query_path = os.path.join(self.path_prefix, self.data["query_path"])
_, ext = os.path.splitext(query_path)
Expand All @@ -81,17 +90,17 @@ def search(self, is_express=False, num_threads=1):
query = json.loads(line)["vector"]
self.queries.append(query)

self.active_threads = num_threads
self.mt_active_workers = num_workers
threads = []
for i in range(num_threads):
for i in range(num_workers):
threads.append(
threading.Thread(
target=self.search_thread_mainloop,
args=[is_express, i],
daemon=True,
)
)
for i in range(num_threads):
for i in range(num_workers):
threads[i].start()

report_qps_sec = 60
Expand All @@ -107,28 +116,28 @@ def search(self, is_express=False, num_threads=1):
done_queries_prev = 0
done_queries_curr = 0

while self.active_threads > 0:
while self.mt_active_workers > 0:
time.sleep(sleep_sec)
sleep_cnt += 1
if sleep_cnt < report_qps_sec / sleep_sec:
continue
sleep_cnt = 0
now = time.time()
if done_warm_up:
with self.lock:
done_queries_curr = self.done_queries
with self.mt_lock:
done_queries_curr = self.mt_done_queries
avg_start = done_queries_curr / (now - start)
avg_interval = (done_queries_curr - done_queries_prev) / (
now - report_prev
)
done_queries_prev = done_queries_curr
report_prev = now
logging.info(
f"average QPS since {start_str}: {avg_start}, average QPS of last interval:{avg_interval}"
f"average QPS since {start_str}: {avg_start}, average QPS of last interval: {avg_interval}"
)
else:
with self.lock:
self.done_queries = 0
with self.mt_lock:
self.mt_done_queries = 0
start = now
start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
report_prev = now
Expand All @@ -137,10 +146,10 @@ def search(self, is_express=False, num_threads=1):
"Collecting statistics for 30 minutes. Print statistics so far every minute. Type Ctrl+C to quit."
)

for i in range(num_threads):
for i in range(num_workers):
threads[i].join()
if not is_express:
self.save_and_check_results(self.results)
self.save_and_check_results(self.mt_results)

def search_thread_mainloop(self, is_express: bool, client_id: int):
query_batch = 100
Expand All @@ -152,37 +161,159 @@ def search_thread_mainloop(self, is_express: bool, client_id: int):
for i in range(query_batch):
query_id = local_rng.randrange(0, num_queries)
_ = self.do_single_query(query_id, client_id)
with self.lock:
self.done_queries += query_batch
with self.mt_lock:
self.mt_done_queries += query_batch
else:
begin = 0
end = 0
local_results = list()
while end < num_queries:
with self.lock:
self.done_queries += end - begin
begin = self.next_begin
with self.mt_lock:
self.mt_done_queries += end - begin
begin = self.mt_next_begin
end = begin + query_batch
if end > num_queries:
end = num_queries
self.next_begin = end
self.mt_next_begin = end
for query_id in range(begin, end):
start = time.time()
result = self.do_single_query(query_id, client_id)
latency = (time.time() - start) * 1000
result = [(query_id, latency)] + result
local_results.append(result)
with self.lock:
self.done_queries += end - begin
self.results += local_results
with self.lock:
self.active_threads -= 1
with self.mt_lock:
self.mt_done_queries += end - begin
self.mt_results += local_results
with self.mt_lock:
self.mt_active_workers -= 1

def search_mp(self, is_express=False, num_workers=1):
    """Benchmark the query workload with multiple worker processes.

    Loads the query set from ``self.data["query_path"]`` (``.txt`` for
    fulltext mode; ``.hdf5`` or ``.jsonl`` for vector mode), spawns
    ``num_workers`` processes running ``search_process_mainloop``, and
    logs throughput periodically until all workers exit.

    :param is_express: when True, measure QPS (random queries after a
        warm-up period, results and latency not recorded); when False,
        run the query set once and save/check the results.
    :param num_workers: number of worker processes to spawn.
    """
    query_path = os.path.join(self.path_prefix, self.data["query_path"])
    _, ext = os.path.splitext(query_path)
    if self.data["mode"] == "fulltext":
        assert ext == ".txt"
        # Use a context manager so the file handle is closed deterministically.
        with open(query_path, "r") as f:
            for line in f:
                self.queries.append(line.strip())
    else:
        # Fix: this was a bare `==` comparison with no effect; the intent
        # (mirroring the fulltext branch above) is an assertion.
        assert self.data["mode"] == "vector"
        if ext == ".hdf5":
            with h5py.File(query_path, "r") as f:
                self.queries = list(f["test"])
        else:
            # Fix: os.path.splitext keeps the leading dot, so the original
            # `ext == "jsonl"` could never be true.
            assert ext == ".jsonl"
            with open(query_path, "r") as f:
                for line in f:
                    self.queries.append(json.loads(line)["vector"])

    # Publish the worker count before any worker can decrement it.
    self.mp_active_workers.value = num_workers
    workers = []
    for _ in range(num_workers):
        workers.append(
            multiprocessing.Process(
                target=self.search_process_mainloop,
                args=[is_express],
                daemon=True,
            )
        )
    for worker in workers:
        worker.start()

    report_qps_sec = 60  # report (and warm-up) interval in seconds
    sleep_sec = 10
    sleep_cnt = 0
    done_warm_up = True
    if is_express:
        logging.info(f"Let database warm-up for {report_qps_sec} seconds")
        done_warm_up = False
    start = time.time()
    start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    report_prev = start
    done_queries_prev = 0
    done_queries_curr = 0

    # Monitor loop: poll worker liveness every sleep_sec seconds and log
    # QPS every report_qps_sec seconds until all workers have finished.
    while True:
        with self.mp_lock:
            active_workers = self.mp_active_workers.value
        if active_workers <= 0:
            break
        time.sleep(sleep_sec)
        sleep_cnt += 1
        if sleep_cnt < report_qps_sec / sleep_sec:
            continue
        sleep_cnt = 0
        now = time.time()
        if done_warm_up:
            with self.mp_lock:
                done_queries_curr = self.mp_done_queries.value
            avg_start = done_queries_curr / (now - start)
            avg_interval = (done_queries_curr - done_queries_prev) / (
                now - report_prev
            )
            done_queries_prev = done_queries_curr
            report_prev = now
            logging.info(
                f"average QPS since {start_str}: {avg_start}, average QPS of last interval: {avg_interval}"
            )
        else:
            # Warm-up just ended: reset the counter and clocks so the QPS
            # statistics exclude the warm-up period.
            with self.mp_lock:
                self.mp_done_queries.value = 0
            start = now
            start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
            report_prev = now
            done_warm_up = True
            logging.info(
                "Collecting statistics for 30 minutes. Print statistics so far every minute. Type Ctrl+C to quit."
            )

    for worker in workers:
        worker.join()
    if not is_express:
        self.save_and_check_results(self.mp_results)

def search_process_mainloop(self, is_express: bool):
    """Worker-process entry point for the multiprocessing benchmark.

    In express mode, issues random queries in batches until a 30-minute
    deadline and only counts them (for QPS measurement). Otherwise,
    workers claim disjoint batches of the query set via the shared
    ``mp_next_begin`` cursor and record per-query results and latency.

    :param is_express: True for QPS mode, False for latency/result mode.
    """
    self.setup_clients(1)  # socket is unsafe to share among workers
    query_batch = 100  # queries claimed/counted per batch
    num_queries = len(self.queries)
    if is_express:
        local_rng = random.Random()  # independent RNG per worker process
        deadline = time.time() + 30 * 60  # 30 minutes
        while time.time() < deadline:
            for i in range(query_batch):
                query_id = local_rng.randrange(0, num_queries)
                _ = self.do_single_query(query_id, 0)
            # Only the count is shared; results are discarded in express mode.
            with self.mp_lock:
                self.mp_done_queries.value += query_batch
    else:
        begin = 0
        end = 0
        local_results = list()
        while end < num_queries:
            # Atomically report the previous batch as done and claim the
            # next [begin, end) batch from the shared cursor.
            with self.mp_lock:
                self.mp_done_queries.value += end - begin
                begin = self.mp_next_begin.value
                end = begin + query_batch
                if end > num_queries:
                    end = num_queries
                self.mp_next_begin.value = end
            for query_id in range(begin, end):
                start = time.time()
                result = self.do_single_query(query_id, 0)
                latency = (time.time() - start) * 1000  # milliseconds
                # Prepend (query_id, latency) so results can be sorted and
                # matched against ground truth later.
                result = [(query_id, latency)] + result
                local_results.append(result)
        # Flush the final batch count and this worker's results once, to
        # minimize traffic through the Manager proxy.
        with self.mp_lock:
            self.mp_done_queries.value += end - begin
            self.mp_results += local_results
    # Signal the monitor loop that this worker has finished.
    with self.mp_lock:
        self.mp_active_workers.value -= 1

def save_and_check_results(self, results: list[list[Any]]):
"""
Compare the search results with ground truth to calculate recall.
"""
self.results.sort(key=lambda x: x[0][0])
results = sorted(results, key=lambda x: x[0][0])
if "result_path" in self.data:
result_path = self.data["result_path"]
with open(result_path, "w") as f:
Expand Down Expand Up @@ -242,7 +373,8 @@ def run_experiment(self, args):
self.upload()
finish_time = time.time()
logging.info(f"upload finish, cost time = {finish_time - start_time}")
elif args.query >= 1:
self.search(is_express=False, num_threads=args.query)
elif args.query_express >= 1:
self.search(is_express=True, num_threads=args.query_express)
elif args.query >= 1 or args.query_express >= 1:
is_express = True if args.query_express >= 1 else False
search_func = self.search_mp if args.concurrency == "mp" else self.search_mt
num_workers = max(args.query, args.query_express)
search_func(is_express, num_workers)
Loading