-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
225 additions
and
10 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from __future__ import annotations | ||
|
||
import shelve | ||
import shutil | ||
import time | ||
from pathlib import Path | ||
|
||
from rocksdict import Options, Rdict | ||
|
||
|
||
class RocksDict:
    """Dict-style adapter over a ``rocksdict.Rdict`` store.

    Every mapping operation is forwarded unchanged to the underlying
    ``Rdict`` held in ``self.db``, so the benchmark can drive it through the
    same interface as the other backends.
    """

    # Path handed to ``Rdict`` when opening and when destroying the store.
    rdict_path = "./test_rocksdict"

    def __init__(self):
        self.db = Rdict(self.rdict_path)

    # ---- single-item access -------------------------------------------

    def __getitem__(self, key):
        """Return the value stored under ``key``."""
        return self.db[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        self.db[key] = value

    def __delitem__(self, key):
        """Remove ``key`` from the store."""
        del self.db[key]

    def __contains__(self, key):
        """Report whether ``key`` is present in the store."""
        return key in self.db

    # ---- iteration ----------------------------------------------------

    def __iter__(self):
        """Iterate over the underlying store."""
        return iter(self.db)

    def keys(self):
        """Return the underlying store's ``keys()`` result."""
        return self.db.keys()

    def items(self):
        """Return the underlying store's ``items()`` result."""
        return self.db.items()

    # ---- teardown -----------------------------------------------------

    def destroy(self):
        """Close the open handle, then call ``Rdict.destroy`` on the path."""
        self.db.close()
        Rdict.destroy(self.rdict_path)
|
||
|
||
class ShelveDict:
    """Dict-style wrapper around a ``shelve`` database stored in ``./shelve_db``.

    Mapping operations are forwarded to the open shelf in ``self.sd``.
    ``destroy()`` closes the shelf and deletes the whole ``root_path``
    directory.
    """

    # Directory that holds the shelf files; removed entirely by destroy().
    root_path = Path("./shelve_db")
    # Created at class-definition time, as before. ``exist_ok=True`` already
    # covers the "directory is present" case, so no separate exists() check.
    root_path.mkdir(parents=True, exist_ok=True)

    # File name (without backend-specific suffix) passed to ``shelve.open``.
    db_path = str(root_path / "test_shelve")

    def __init__(self):
        # destroy() deletes root_path, so make sure it exists again before
        # opening; otherwise a second instance in the same process fails
        # inside shelve.open because the parent directory is gone.
        self.root_path.mkdir(parents=True, exist_ok=True)
        self.sd = shelve.open(self.db_path)

    def __getitem__(self, key):
        """Return the value stored under ``key``."""
        return self.sd[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        self.sd[key] = value

    def __delitem__(self, key):
        """Remove ``key`` from the shelf."""
        del self.sd[key]

    def __contains__(self, key):
        """Report whether ``key`` is present in the shelf."""
        return key in self.sd

    def __iter__(self):
        """Iterate over the shelf's keys."""
        return iter(self.sd)

    def __len__(self):
        """Return the number of entries in the shelf."""
        return len(self.sd)

    def keys(self):
        """Return the shelf's ``keys()`` view."""
        return self.sd.keys()

    def items(self):
        """Return the shelf's ``items()`` view."""
        return self.sd.items()

    def destroy(self):
        """Close the shelf and delete the ``root_path`` directory tree."""
        self.sd.close()
        shutil.rmtree(self.root_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from db_class import RocksDict, ShelveDict | ||
from rich import print | ||
from sparrow import MeasureTime | ||
|
||
from flaxkv import FlaxKV | ||
|
||
# Number of key/value pairs each backend writes and then reads back.
N = 10_000

# Results recorded by test_benchmark, keyed by backend name; each value is a
# {"write": ..., "read": ...} dict that the print_info finalizer tabulates.
benchmark_info = {}
|
||
|
||
def prepare_data(n, dim=1000):
    """Build ``n`` random vectors keyed ``vector-0`` .. ``vector-{n-1}``.

    Args:
        n: Number of entries to generate; ``0`` yields an empty dict.
        dim: Length of each vector. Defaults to 1000, the size used by the
            benchmark in this module.

    Returns:
        A dict mapping each key to a 1-D array produced by
        ``np.random.rand(dim)``.
    """
    return {f'vector-{i}': np.random.rand(dim) for i in range(n)}
|
||
|
||
@pytest.fixture(scope="module", autouse=True)
def print_info(request):
    """Register a finalizer that prints and plots the collected timings.

    The fixture itself provides no value; it only schedules ``print_result``
    to run when the module-scoped fixture is torn down, at which point
    ``benchmark_info`` holds whatever the benchmark tests recorded.
    """

    def plot(df: pd.DataFrame):
        """Draw write and read times per backend as side-by-side bars."""
        import matplotlib.pyplot as plt

        # Work on a copy so the caller's frame keeps its original index; the
        # former index becomes the "index" column used for the x axis.
        df = df.reset_index()
        plt.figure(figsize=(10, 6))
        write_color = '#ADD8E6'
        read_color = '#3EB489'
        # With align='edge', a negative width places the bar to the left of
        # the tick and a positive width to the right, so the two series sit
        # next to each other instead of overlapping.
        plt.bar(
            df["index"],
            df["write"],
            width=-0.4,
            color=write_color,
            label='Write',
            align='edge',
        )
        plt.bar(
            df["index"],
            df["read"],
            width=0.4,
            color=read_color,
            label='Read',
            align='edge',
        )

        plt.title(f"Read and Write (N={N}) 1000-dim vectors")
        plt.xlabel("DB Type")
        plt.ylabel("Time (seconds)")
        plt.yscale('log')
        plt.xticks(rotation=20)
        plt.legend(title="Operation")
        plt.show()

    def print_result():
        """Tabulate ``benchmark_info`` sorted by read time, then plot it."""
        if not benchmark_info:
            # Nothing was recorded (e.g. every benchmark failed or was
            # deselected); an empty frame has no "read" column to sort by.
            return
        df = pd.DataFrame(benchmark_info).T
        df = df.sort_values(by="read", ascending=True)
        print()
        print(df)
        plot(df)

    request.addfinalizer(print_result)
|
||
|
||
@pytest.fixture(
    params=[
        "dict",
        "flaxkv-LMDB",
        "flaxkv-LevelDB",
        "RocksDict",
        "Shelve",
    ]
)
def temp_db(request):
    """Yield ``(db, name)`` for each backend, then clean the backend up.

    Yields:
        A tuple of the freshly created store and the parameter name that
        selected it.

    Raises:
        ValueError: If ``request.param`` is not one of the known backends.
    """
    name = request.param

    if name == "flaxkv-LMDB":
        db = FlaxKV('benchmark', backend='lmdb')
    elif name == "flaxkv-LevelDB":
        db = FlaxKV('benchmark', backend='leveldb')
    elif name == "RocksDict":
        db = RocksDict()
    elif name == "Shelve":
        db = ShelveDict()
    elif name == "dict":
        db = {}
    else:
        # A bare ``raise`` here has no active exception to re-raise and only
        # produces an unhelpful RuntimeError; name the offending value.
        raise ValueError(f"Unknown benchmark backend: {name!r}")

    yield db, name

    # The plain dict has no destroy(); skip it explicitly rather than relying
    # on a swallowed AttributeError.
    destroy = getattr(db, "destroy", None)
    if destroy is None:
        return
    try:
        destroy()
    except Exception as exc:
        # Cleanup stays best-effort so a teardown problem does not fail the
        # benchmark, but it is surfaced instead of silently discarded.
        import warnings

        warnings.warn(f"cleanup of {name} failed: {exc!r}")
|
||
|
||
def benchmark(db, db_name, n=200):
    """Time writing ``n`` vectors into ``db`` and reading them back shuffled.

    Args:
        db: Store supporting ``db[key] = value``, ``db[key]`` and ``keys()``.
        db_name: Label used in the timing messages.
        n: Number of entries to generate via ``prepare_data``.

    Returns:
        ``(write_cost, read_cost)`` exactly as returned by
        ``MeasureTime.show_interval`` for the write and read phases
        (presumably seconds, per the plot label elsewhere in this file --
        TODO confirm against ``sparrow.MeasureTime``).
    """
    import random

    print("\n--------------------------")
    # Generate the payload before starting the timer so data creation is not
    # counted as write time.
    data = prepare_data(n)
    mt = MeasureTime().start()
    for key, value in data.items():
        db[key] = value

    write_cost = mt.show_interval(f"{db_name} write")

    # Read in random order rather than insertion order.
    keys = list(db.keys())
    random.shuffle(keys)

    mt.start()
    for key in keys:
        # Only the lookup is of interest; the value itself is discarded.
        _ = db[key]
    read_cost = mt.show_interval(f"{db_name} read (traverse elements) ")
    print("--------------------------")
    return write_cost, read_cost
|
||
|
||
def test_benchmark(temp_db):
    """Run the benchmark on one backend and record its timings.

    ``temp_db`` supplies ``(store, label)``; the resulting write and read
    costs are stored in ``benchmark_info`` under that label.
    """
    store, label = temp_db
    write_time, read_time = benchmark(store, n=N, db_name=label)
    benchmark_info[label] = dict(write=write_time, read=read_time)