Speed up bulk processor flush and cut release (#187)
* speed up bulk processor

* set versions and better defaults

* set versions and better defaults

* lint

* one more lint

* delete temp db

* fixed the queue flush of small chunks

* lint and unit tests
philkra committed Dec 14, 2023
1 parent 8a2456a commit 1eb8d7b
Showing 5 changed files with 120 additions and 75 deletions.
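For context, here is a minimal usage sketch of the helper this commit tunes, based only on the calls exercised in the tests below; the database name, table name, and record fields are illustrative:

from xata.client import XataClient
from xata.helpers import BulkProcessor

client = XataClient(db_name="my_db")  # illustrative database name
bp = BulkProcessor(client)  # new defaults: thread_pool_size=4, batch_size=50, flush_interval=2

# queue records; worker threads drain them into bulk inserts in the background
bp.put_records("Posts", [{"title": "hello", "text": "world"} for _ in range(1000)])

# drain both queues and join the workers before exiting (the behavior reworked in this commit)
bp.flush_queue()

print(bp.get_stats())  # e.g. {"total": 1000, "queue": 0, "failed_batches": 0, "total_batches": 20, "tables": {"Posts": 1000}}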
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "xata"
-version = "1.2.1"
+version = "1.2.2"
 description = "Python SDK for Xata.io"
 authors = ["Xata <support@xata.io>"]
 license = "Apache-2.0"
94 changes: 65 additions & 29 deletions tests/integration-tests/helpers_bulkprocessor_test.py
@@ -30,36 +30,44 @@
 class TestHelpersBulkProcessor(object):
     def setup_class(self):
         self.db_name = utils.get_db_name()
-        self.branch_name = "main"
-        self.client = XataClient(db_name=self.db_name, branch_name=self.branch_name)
+        self.client = XataClient(db_name=self.db_name)
         self.fake = Faker()

         assert self.client.databases().create(self.db_name).is_success()
         assert self.client.table().create("Posts").is_success()
         assert self.client.table().create("Users").is_success()

         # create schema
-        assert self.client.table().set_schema(
-            "Posts",
-            {
-                "columns": [
-                    {"name": "title", "type": "string"},
-                    {"name": "text", "type": "text"},
-                ]
-            },
-        ).is_success()
-        assert self.client.table().set_schema(
-            "Users",
-            {
-                "columns": [
-                    {"name": "username", "type": "string"},
-                    {"name": "email", "type": "string"},
-                ]
-            },
-        ).is_success()
+        assert (
+            self.client.table()
+            .set_schema(
+                "Posts",
+                {
+                    "columns": [
+                        {"name": "title", "type": "string"},
+                        {"name": "text", "type": "text"},
+                    ]
+                },
+            )
+            .is_success()
+        )
+        assert (
+            self.client.table()
+            .set_schema(
+                "Users",
+                {
+                    "columns": [
+                        {"name": "username", "type": "string"},
+                        {"name": "email", "type": "string"},
+                    ]
+                },
+            )
+            .is_success()
+        )

     def teardown_class(self):
-        assert self.client.databases().delete(self.db_name).is_success()
+        # assert self.client.databases().delete(self.db_name).is_success()
+        pass

     @pytest.fixture
     def record(self) -> dict:
@@ -70,7 +78,7 @@ def _get_record(self) -> dict:
"title": self.fake.company(),
"text": self.fake.text(),
}

def _get_user(self) -> dict:
return {
"username": self.fake.name(),
@@ -81,20 +89,22 @@ def test_bulk_insert_records(self, record: dict):
         bp = BulkProcessor(
             self.client,
             thread_pool_size=1,
+            batch_size=43,
         )
         bp.put_records("Posts", [self._get_record() for x in range(42)])
         bp.flush_queue()

-        r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
-        assert r.is_success()
-        assert "summaries" in r
-        assert r["summaries"][0]["proof"] == 42
-
         stats = bp.get_stats()
         assert stats["total"] == 42
         assert stats["queue"] == 0
         assert stats["failed_batches"] == 0
         assert stats["tables"]["Posts"] == 42
+        assert stats["total_batches"] == 1
+
+        r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
+        assert r.is_success()
+        assert "summaries" in r
+        assert r["summaries"][0]["proof"] == stats["total"]

     def test_flush_queue(self):
         assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()
@@ -112,15 +122,40 @@ def test_flush_queue(self):
         assert r.is_success()
         assert "summaries" in r
         assert r["summaries"][0]["proof"] == 1000

         stats = bp.get_stats()
         assert stats["total"] == 1000
         assert stats["queue"] == 0
         assert stats["failed_batches"] == 0
+        assert stats["total_batches"] == 20
         assert stats["tables"]["Posts"] == 1000

+    def test_flush_queue_many_threads(self):
+        assert self.client.sql().query('DELETE FROM "Users" WHERE 1 = 1').is_success()
+
+        bp = BulkProcessor(
+            self.client,
+            thread_pool_size=8,
+            batch_size=10,
+        )
+        bp.put_records("Users", [self._get_user() for x in range(750)])
+        bp.flush_queue()
+
+        r = self.client.data().summarize("Users", {"summaries": {"proof": {"count": "*"}}})
+        assert r.is_success()
+        assert "summaries" in r
+        assert r["summaries"][0]["proof"] == 750
+
+        stats = bp.get_stats()
+        assert stats["total"] == 750
+        assert stats["queue"] == 0
+        assert stats["failed_batches"] == 0
+        assert stats["total_batches"] == 75
+        assert stats["tables"]["Users"] == 750
+
     def test_multiple_tables(self):
         assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()
         assert self.client.sql().query('DELETE FROM "Users" WHERE 1 = 1').is_success()

         bp = BulkProcessor(
             self.client,
@@ -141,10 +176,11 @@ def test_multiple_tables(self):
         assert r.is_success()
         assert "summaries" in r
         assert r["summaries"][0]["proof"] == 33 * 7

         stats = bp.get_stats()
         assert stats["queue"] == 0
         assert stats["failed_batches"] == 0
+        assert stats["total_batches"] == 14
         assert stats["tables"]["Posts"] == 33 * 9
         assert stats["tables"]["Users"] == 33 * 7
         assert stats["total"] == stats["tables"]["Posts"] + stats["tables"]["Users"]
6 changes: 3 additions & 3 deletions tests/unit-tests/helpers_bulk_processor_test.py
@@ -35,15 +35,15 @@ def test_bulk_processor_init(self):

         with pytest.raises(Exception) as e:
             BulkProcessor(client, batch_size=-1)
-        assert str(e.value) == "batch size can not be less than one, default: 25"
+        assert str(e.value) == "batch size can not be less than one, default: 50"

         with pytest.raises(Exception) as e:
             BulkProcessor(client, flush_interval=-1)
-        assert str(e.value) == "flush interval can not be negative, default: 5.000000"
+        assert str(e.value) == "flush interval can not be negative, default: 2.000000"

         with pytest.raises(Exception) as e:
             BulkProcessor(client, processing_timeout=-1)
-        assert str(e.value) == "processing timeout can not be negative, default: 0.025000"
+        assert str(e.value) == "processing timeout can not be negative, default: 0.050000"

     def test_bulk_processor_stats(self):
         client = XataClient(api_key="api_key", workspace_id="ws_id")
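For reference, a short sketch of how this constructor validation surfaces to callers, mirroring the client setup used in the test above:

from xata.client import XataClient
from xata.helpers import BulkProcessor

client = XataClient(api_key="api_key", workspace_id="ws_id")
try:
    BulkProcessor(client, batch_size=-1)
except Exception as e:
    print(e)  # "batch size can not be less than one, default: 50"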
2 changes: 1 addition & 1 deletion xata/client.py
@@ -39,7 +39,7 @@

 # TODO this is a manual task, to keep in sync with pyproject.toml
 # could/should be automated to keep in sync
-__version__ = "1.2.1"
+__version__ = "1.2.2"

 PERSONAL_API_KEY_LOCATION = "~/.config/xata/key"
 DEFAULT_DATA_PLANE_DOMAIN = "xata.sh"
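Since the version string lives in two places, a release-time sanity check might look like this (a hedged sketch, not part of this commit):

from xata.client import __version__

# must match pyproject.toml until the TODO above is automated
assert __version__ == "1.2.2"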
91 changes: 50 additions & 41 deletions xata/helpers.py
@@ -27,11 +27,11 @@
 from .client import XataClient

 BP_DEFAULT_THREAD_POOL_SIZE = 4
-BP_DEFAULT_BATCH_SIZE = 25
-BP_DEFAULT_FLUSH_INTERVAL = 5
-BP_DEFAULT_PROCESSING_TIMEOUT = 0.025
+BP_DEFAULT_BATCH_SIZE = 50
+BP_DEFAULT_FLUSH_INTERVAL = 2
+BP_DEFAULT_PROCESSING_TIMEOUT = 0.05
 BP_DEFAULT_THROW_EXCEPTION = False
-BP_VERSION = "0.3.0"
+BP_VERSION = "0.3.1"
 TRX_MAX_OPERATIONS = 1000
 TRX_VERSION = "0.1.0"
 TRX_BACKOFF = 0.1
@@ -85,10 +85,13 @@ def __init__(
         self.flush_interval = flush_interval
         self.failed_batches_queue = []
         self.throw_exception = throw_exception
-        self.stats = {"total": 0, "queue": 0, "failed_batches": 0, "tables": {}}
+        self.stats = {"total": 0, "queue": 0, "failed_batches": 0, "total_batches": 0, "tables": {}}
+        self.stats_lock = Lock()

         self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

+        self.thread_workers = []
+        self.worker_active = True
         self.records = self.Records(self.batch_size, self.flush_interval, self.logger)

         for i in range(thread_pool_size):
@@ -110,12 +113,16 @@ def process(self, id: int):
                 self.processing_timeout,
             )
         )
-        while True:
+        while self.worker_active:
+            sleep_backoff = 5  # slow down if no records exist
+            time.sleep(self.processing_timeout * sleep_backoff)
+
+            # process
             batch = self.records.next_batch()
             if "table" in batch and len(batch["records"]) > 0:
                 try:
                     r = self.client.records().bulk_insert(batch["table"], {"records": batch["records"]})
-                    if r.status_code != 200:
+                    if not r.is_success():
                         self.logger.error(
                             "thread #%d: unable to process batch for table '%s', with error: %d - %s"
                             % (id, batch["table"], r.status_code, r.json())
@@ -137,14 +144,16 @@
"thread #%d: pushed a batch of %d records to table %s"
% (id, len(batch["records"]), batch["table"])
)
# with self.stats_lock:
self.stats["total"] += len(batch["records"])
self.stats["queue"] = self.records.size()
if batch["table"] not in self.stats["tables"]:
self.stats["tables"][batch["table"]] = 0
self.stats["tables"][batch["table"]] += len(batch["records"])
self.stats["total_batches"] += 1
except Exception as exc:
logging.error("thread #%d: %s" % (id, exc))
time.sleep(self.processing_timeout)
sleep_backoff = 1 # keep velocity

     def put_record(self, table_name: str, record: dict):
         """
@@ -179,26 +188,38 @@ def get_stats(self):
         """
         return self.stats

+    def get_queue_size(self) -> int:
+        with self.stats_lock:
+            return self.stats["queue"]
+
     def flush_queue(self):
         """
-        Flush all records from the queue.
+        Flush all records from the queue. Call this as you close the ingestion operation
+        https://github.com/xataio/xata-py/issues/184
         """
-        self.logger.debug("flushing queue with %d records .." % (self.records.size()))
-
         # force flush the records queue and shorten the processing times
         self.records.force_queue_flush()
         self.processing_timeout = 0.001
-        wait = 0.005 * len(self.thread_workers)
-
-        while self.records.size() > 0:
-            self.logger.debug("flushing queue with %d records." % self.stats["queue"])
-            time.sleep(wait)
+        # back off to wait for all threads run at least once
+        if self.records.size() == 0 or self.get_queue_size() == 0:
+            time.sleep(self.processing_timeout * len(self.thread_workers))

-        # Last poor mans check if queue is fully flushed
-        if self.records.size() > 0 or self.stats["queue"] > 0:
-            self.logger.debug("one more flush interval necessary with queue at %d records." % self.stats["queue"])
-            time.sleep(wait)
+        # ensure the full records queue is flushed first
+        while self.records.size() != 0:
+            time.sleep(self.processing_timeout * len(self.thread_workers))
+            self.logger.debug("flushing %d records to processing queue." % self.records.size())
+
+        # let's make really sure the queue is empty
+        while self.get_queue_size() > 0:
+            time.sleep(self.processing_timeout * len(self.thread_workers))
+            self.logger.debug("flushing processor queue with %d records." % self.stats["queue"])
+
+        self.worker_active = False
+        for worker in self.thread_workers:
+            worker.join()

     class Records(object):
         """
@@ -212,7 +233,6 @@ def __init__(self, batch_size: int, flush_interval: int, logger):
"""
self.batch_size = batch_size
self.flush_interval = flush_interval
self.force_flush = False
self.logger = logger

self.store = dict()
@@ -224,10 +244,8 @@ def force_queue_flush(self):
             Force next batch to be available
+            https://github.com/xataio/xata-py/issues/184
             """
-            with self.lock:
-                self.force_flush = True
-                self.flush_interval = 0.001
-                self.batch_size = 1
+            # push for immediate flushes
+            self.flush_interval = 0

         def put(self, table_name: str, records: list[dict]):
             """
@@ -250,38 +268,30 @@ def next_batch(self) -> dict:
             :returns dict
             """
-            if self.size() == 0:
-                return {}
             table_name = ""
             with self.lock:
                 names = list(self.store.keys())
+                if len(names) == 0:
+                    return {}

                 self.store_ptr += 1
                 if len(names) <= self.store_ptr:
                     self.store_ptr = 0
                 table_name = names[self.store_ptr]

             rs = []
+            if self.length(table_name) == 0:
+                return {"table": table_name, "records": rs}
+
             with self.store[table_name]["lock"]:
                 # flush interval exceeded
                 time_elapsed = time.time() - self.store[table_name]["flushed"]
-                flush_needed = time_elapsed > self.flush_interval
-                if flush_needed and len(self.store[table_name]["records"]) > 0:
-                    self.logger.debug(
-                        "flushing table '%s' with %d records after interval %s > %d"
-                        % (
-                            table_name,
-                            len(self.store[table_name]["records"]),
-                            time_elapsed,
-                            self.flush_interval,
-                        )
-                    )
+                flush_needed = time_elapsed >= self.flush_interval
                 # force flush table, batch size reached or timer exceeded
-                if self.force_flush or len(self.store[table_name]["records"]) >= self.batch_size or flush_needed:
+                if len(self.store[table_name]["records"]) >= self.batch_size or flush_needed:
                     self.store[table_name]["flushed"] = time.time()
                     rs = self.store[table_name]["records"][0 : self.batch_size]
                     del self.store[table_name]["records"][0 : self.batch_size]
-            return {"table": table_name, "records": rs}
+                return {"table": table_name, "records": rs}

         def length(self, table_name: str) -> int:
             """
@@ -296,8 +306,7 @@ def size(self) -> int:
"""
Get total size of stored records
"""
with self.lock:
return sum([len(self.store[n]["records"]) for n in self.store.keys()])
return sum([self.length(n) for n in self.store.keys()])


def to_rfc339(dt: datetime, tz=timezone.utc) -> str:
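Net effect of the new defaults and the new counter: a put_records() call of n records drains into ceil(n / batch_size) bulk inserts, which is what the added total_batches stat tracks, assuming records are enqueued faster than the flush interval fires. A worked check against the integration tests above:

import math

def expected_batches(n_records: int, batch_size: int) -> int:
    # each worker pops at most batch_size records per bulk insert;
    # the final short chunk is force-flushed by flush_queue()
    return math.ceil(n_records / batch_size)

assert expected_batches(42, 43) == 1     # test_bulk_insert_records
assert expected_batches(1000, 50) == 20  # test_flush_queue (default batch size)
assert expected_batches(750, 10) == 75   # test_flush_queue_many_threads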
