Skip to content

Commit

Permalink
Add benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
KenyonY committed Dec 26, 2023
1 parent 8b5a02f commit ab99887
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 10 deletions.
Binary file added .github/img/benchmark.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,16 @@ print(len(db))
in this case, you can use `db.write_immediately()` to immediately trigger a write operation.

### Benchmark
todo
![benchmark](.github/img/benchmark.png)

Test content: write N=10,000 NumPy array vectors (each 1000-dimensional), then read them all back by traversal.

Execute the test:
```bash
cd benchmark/
pytest -s -v run.py
```


### Use Cases
- **Key-Value Structure:**
Expand Down
20 changes: 11 additions & 9 deletions README_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@
- **线程安全**:仅使用必要的锁来确保安全的并发访问同时又能兼顾性能。


## TODO

- [x] 客户端-服务器架构
- [ ] 性能测试

---

## 快速入门
Expand Down Expand Up @@ -124,17 +119,24 @@ print(len(db))
此时可使用`db.write_immediately()`来立即触发写入操作。

### Benchmark
todo
![benchmark](.github/img/benchmark.png)

测试内容:对N=10,000 条1000维的numpy array进行写入和遍历读取

执行测试:
```bash
cd benchmark/
pytest -s -v run.py
```

### 适用场景

- **键-值型结构**
用于保存简单的键值结构数据
适用于保存简单的键值结构数据
- **高频写入**
非常适合需要高频插入/更新数据的场景
适合需要高频插入/更新数据的场景
- **机器学习**
`flaxkv`十分适合用于保存机器学习中的各种嵌入向量、图像、文本和其它键-值结构的大型数据集。
适用于保存机器学习中的各种嵌入向量、图像、文本和其它键-值结构的大型数据集。


## 引用
Expand Down
79 changes: 79 additions & 0 deletions benchmark/db_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

import shelve
import shutil
import time
from pathlib import Path

from rocksdict import Options, Rdict


class RocksDict:
    """Dict-style adapter around a ``rocksdict.Rdict`` store for the benchmark."""

    # Path handed to Rdict when an instance opens its store.
    rdict_path = "./test_rocksdict"

    def __init__(self):
        self.db = Rdict(self.rdict_path)

    # -- mapping protocol: every call is forwarded to the wrapped store --

    def __getitem__(self, key):
        return self.db[key]

    def __setitem__(self, key, value):
        self.db[key] = value

    def __delitem__(self, key):
        del self.db[key]

    def __contains__(self, key):
        return key in self.db

    def __iter__(self):
        return iter(self.db)

    def keys(self):
        return self.db.keys()

    def items(self):
        return self.db.items()

    def destroy(self):
        """Close the wrapped store, then call ``Rdict.destroy`` on its path."""
        store_path = self.rdict_path
        self.db.close()
        Rdict.destroy(store_path)


class ShelveDict:
    """Dict-style adapter around a ``shelve`` database for the benchmark.

    The shelf is kept inside its own directory (``root_path``) so that
    ``destroy`` can remove it by deleting that directory.
    """

    # Directory that holds the shelf; removed wholesale by destroy().
    root_path = Path("./shelve_db")
    # Base filename passed to shelve.open().
    db_path = str(root_path / "test_shelve")

    def __init__(self):
        # Create the directory per instance rather than once in the class
        # body: destroy() deletes it, so a later instance would otherwise
        # fail to open its database. This also avoids touching the
        # filesystem merely because the module was imported.
        self.root_path.mkdir(parents=True, exist_ok=True)
        self.sd = shelve.open(self.db_path)

    # -- mapping protocol: every call is forwarded to the open shelf --

    def __getitem__(self, key):
        return self.sd[key]

    def __setitem__(self, key, value):
        self.sd[key] = value

    def __delitem__(self, key):
        del self.sd[key]

    def __contains__(self, key):
        return key in self.sd

    def __iter__(self):
        return iter(self.sd)

    def __len__(self):
        return len(self.sd)

    def keys(self):
        return self.sd.keys()

    def items(self):
        return self.sd.items()

    def destroy(self):
        """Close the shelf and delete ``root_path`` with everything in it."""
        self.sd.close()
        shutil.rmtree(self.root_path)
125 changes: 125 additions & 0 deletions benchmark/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

import numpy as np
import pandas as pd
import pytest
from db_class import RocksDict, ShelveDict
from rich import print
from sparrow import MeasureTime

from flaxkv import FlaxKV

# Timings collected by test_benchmark, keyed by backend name; each value is a
# {"write": ..., "read": ...} dict. Read by the print_info finalizer.
benchmark_info = {}

# Number of vectors test_benchmark writes and reads back per backend.
N = 10_000


def prepare_data(n, dim=1000):
    """Build ``n`` random vectors keyed ``'vector-0'`` .. ``'vector-{n-1}'``.

    Args:
        n: Number of vectors to generate.
        dim: Length of each vector. Defaults to 1000, the size that was
            previously hard-coded.

    Returns:
        A dict mapping each key to the 1-D array returned by
        ``np.random.rand(dim)``, in insertion order.
    """
    return {f'vector-{i}': np.random.rand(dim) for i in range(n)}


@pytest.fixture(scope="module", autouse=True)
def print_info(request):
    """Report the collected timings once the module's tests have finished.

    Registers a finalizer that turns ``benchmark_info`` into a DataFrame
    (one row per backend, columns ``write`` and ``read``), prints it sorted
    by read time, and shows a bar chart of both columns.
    """

    def plot(df: pd.DataFrame):
        """Draw write/read bars per backend on a log-scaled time axis."""
        import matplotlib.pyplot as plt

        # Work on a copy so the caller's frame keeps its index; the backend
        # names end up in the "index" column.
        df = df.reset_index()
        plt.figure(figsize=(10, 6))
        write_color = '#ADD8E6'
        read_color = '#3EB489'
        # Both series use align='edge'; the negative width places the write
        # bar to the left of the tick and the read bar to the right, so the
        # two no longer overlap and hide each other.
        plt.bar(
            df["index"],
            df["write"],
            width=-0.4,
            color=write_color,
            label='Write',
            align='edge',
        )
        plt.bar(
            df["index"],
            df["read"],
            width=0.4,
            color=read_color,
            label='Read',
            align='edge',
        )

        plt.title(f"Read and Write (N={N}) 1000-dim vectors")
        plt.xlabel("DB Type")
        plt.ylabel("Time (seconds)")
        plt.yscale('log')
        plt.xticks(rotation=20)
        plt.legend(title="Operation")
        plt.show()

    def print_result():
        """Print the results table and plot it; no-op when nothing was recorded."""
        # With no recorded results the frame has no "read" column and the
        # sort below would raise, masking the real test failures.
        if not benchmark_info:
            print()
            print("No benchmark results were collected.")
            return
        df = pd.DataFrame(benchmark_info).T
        df = df.sort_values(by="read", ascending=True)
        print()
        print(df)
        plot(df)

    request.addfinalizer(print_result)


@pytest.fixture(
    params=[
        "dict",
        "flaxkv-LMDB",
        "flaxkv-LevelDB",
        "RocksDict",
        "Shelve",
    ]
)
def temp_db(request):
    """Yield ``(db, name)`` for each backend under test, then tear it down.

    Yields:
        A tuple of the opened store and the parameter string naming it.

    Raises:
        ValueError: If the parameter does not name a known backend.
    """
    if request.param == "flaxkv-LMDB":
        db = FlaxKV('benchmark', backend='lmdb')
    elif request.param == "flaxkv-LevelDB":
        db = FlaxKV('benchmark', backend='leveldb')
    elif request.param == "RocksDict":
        db = RocksDict()
    elif request.param == "Shelve":
        db = ShelveDict()
    elif request.param == "dict":
        db = {}
    else:
        # A bare `raise` here has no active exception to re-raise and would
        # surface as an unrelated RuntimeError.
        raise ValueError(f"Unknown db type: {request.param!r}")
    yield db, request.param

    # Plain dicts have nothing to clean up.
    destroy = getattr(db, "destroy", None)
    if destroy is None:
        return
    try:
        destroy()
    except Exception as exc:
        # Teardown stays best-effort, but the failure is reported instead of
        # silently dropped, and KeyboardInterrupt/SystemExit now propagate.
        import warnings

        warnings.warn(f"failed to destroy {request.param} store: {exc!r}")


def benchmark(db, db_name, n=200):
    """Time writing ``n`` generated vectors into ``db`` and reading them back.

    Args:
        db: Store supporting ``db[key] = value``, ``db[key]`` and ``keys()``.
        db_name: Label used in the timing messages.
        n: Number of vectors to generate with ``prepare_data``.

    Returns:
        ``(write_cost, read_cost)`` as returned by
        ``MeasureTime.show_interval`` -- presumably elapsed seconds; confirm
        against the ``sparrow`` documentation.
    """
    import random

    print("\n--------------------------")
    # Data generation happens before the timer starts, so it is not measured.
    data = prepare_data(n)
    mt = MeasureTime().start()
    # The timed loops contain only the store operation itself.
    for key, value in data.items():
        db[key] = value
    write_cost = mt.show_interval(f"{db_name} write")

    # Read back in random order; listing and shuffling the keys happens
    # before the timer is restarted.
    keys = list(db.keys())
    random.shuffle(keys)

    mt.start()
    for key in keys:
        _ = db[key]
    read_cost = mt.show_interval(f"{db_name} read (traverse elements) ")
    print("--------------------------")
    return write_cost, read_cost


def test_benchmark(temp_db):
    """Benchmark one backend with N vectors and record its timings by name."""
    store, label = temp_db
    w_cost, r_cost = benchmark(store, db_name=label, n=N)
    benchmark_info[label] = dict(write=w_cost, read=r_cost)

0 comments on commit ab99887

Please sign in to comment.