diff --git a/pyproject.toml b/pyproject.toml
index fa6d046..7b6cc1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "xata"
-version = "1.2.1"
+version = "1.2.2"
 description = "Python SDK for Xata.io"
 authors = ["Xata <support@xata.io>"]
 license = "Apache-2.0"
diff --git a/tests/integration-tests/helpers_bulkprocessor_test.py b/tests/integration-tests/helpers_bulkprocessor_test.py
index f15a14f..4bc5289 100644
--- a/tests/integration-tests/helpers_bulkprocessor_test.py
+++ b/tests/integration-tests/helpers_bulkprocessor_test.py
@@ -30,8 +30,7 @@ class TestHelpersBulkProcessor(object):
     def setup_class(self):
         self.db_name = utils.get_db_name()
-        self.branch_name = "main"
-        self.client = XataClient(db_name=self.db_name, branch_name=self.branch_name)
+        self.client = XataClient(db_name=self.db_name)
         self.fake = Faker()
 
         assert self.client.databases().create(self.db_name).is_success()
 
@@ -39,27 +38,35 @@ def setup_class(self):
         assert self.client.table().create("Users").is_success()
 
         # create schema
-        assert self.client.table().set_schema(
-            "Posts",
-            {
-                "columns": [
-                    {"name": "title", "type": "string"},
-                    {"name": "text", "type": "text"},
-                ]
-            },
-        ).is_success()
-        assert self.client.table().set_schema(
-            "Users",
-            {
-                "columns": [
-                    {"name": "username", "type": "string"},
-                    {"name": "email", "type": "string"},
-                ]
-            },
-        ).is_success()
+        assert (
+            self.client.table()
+            .set_schema(
+                "Posts",
+                {
+                    "columns": [
+                        {"name": "title", "type": "string"},
+                        {"name": "text", "type": "text"},
+                    ]
+                },
+            )
+            .is_success()
+        )
+        assert (
+            self.client.table()
+            .set_schema(
+                "Users",
+                {
+                    "columns": [
+                        {"name": "username", "type": "string"},
+                        {"name": "email", "type": "string"},
+                    ]
+                },
+            )
+            .is_success()
+        )
 
     def teardown_class(self):
         assert self.client.databases().delete(self.db_name).is_success()
 
     @pytest.fixture
     def record(self) -> dict:
@@ -70,7 +77,7 @@ def _get_record(self) -> dict:
             "title": self.fake.company(),
             "text": self.fake.text(),
         }
-        
+
     def _get_user(self) -> dict:
         return {
             "username": self.fake.name(),
@@ -81,20 +88,22 @@ def test_bulk_insert_records(self, record: dict):
         bp = BulkProcessor(
             self.client,
             thread_pool_size=1,
+            batch_size=43,
         )
         bp.put_records("Posts", [self._get_record() for x in range(42)])
         bp.flush_queue()
 
-        r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
-        assert r.is_success()
-        assert "summaries" in r
-        assert r["summaries"][0]["proof"] == 42
-
         stats = bp.get_stats()
         assert stats["total"] == 42
         assert stats["queue"] == 0
         assert stats["failed_batches"] == 0
         assert stats["tables"]["Posts"] == 42
+        assert stats["total_batches"] == 1
+
+        r = self.client.data().summarize("Posts", {"summaries": {"proof": {"count": "*"}}})
+        assert r.is_success()
+        assert "summaries" in r
+        assert r["summaries"][0]["proof"] == stats["total"]
 
     def test_flush_queue(self):
         assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()
@@ -112,15 +121,40 @@ def test_flush_queue(self):
         assert r.is_success()
         assert "summaries" in r
         assert r["summaries"][0]["proof"] == 1000
-        
+
         stats = bp.get_stats()
         assert stats["total"] == 1000
         assert stats["queue"] == 0
         assert stats["failed_batches"] == 0
+        assert stats["total_batches"] == 20
         assert stats["tables"]["Posts"] == 1000
 
+    def test_flush_queue_many_threads(self):
+        assert self.client.sql().query('DELETE FROM "Users" WHERE 1 = 1').is_success()
+
+        bp = BulkProcessor(
+            self.client,
+            thread_pool_size=8,
+            batch_size=10,
+        )
+        bp.put_records("Users", [self._get_user() for x in range(750)])
+        bp.flush_queue()
+
+        r = self.client.data().summarize("Users", {"summaries": {"proof": {"count": "*"}}})
+        assert r.is_success()
+        assert "summaries" in r
+        assert r["summaries"][0]["proof"] == 750
+
+        stats = bp.get_stats()
+        assert stats["total"] == 750
+        assert stats["queue"] == 0
+        assert stats["failed_batches"] == 0
+        assert stats["total_batches"] == 75
+        assert stats["tables"]["Users"] == 750
+
     def test_multiple_tables(self):
         assert self.client.sql().query('DELETE FROM "Posts" WHERE 1 = 1').is_success()
+        assert self.client.sql().query('DELETE FROM "Users" WHERE 1 = 1').is_success()
 
         bp = BulkProcessor(
             self.client,
@@ -141,10 +175,11 @@ def test_multiple_tables(self):
         assert r.is_success()
         assert "summaries" in r
         assert r["summaries"][0]["proof"] == 33 * 7
-        
+
         stats = bp.get_stats()
         assert stats["queue"] == 0
         assert stats["failed_batches"] == 0
+        assert stats["total_batches"] == 14
         assert stats["tables"]["Posts"] == 33 * 9
         assert stats["tables"]["Users"] == 33 * 7
         assert stats["total"] == stats["tables"]["Posts"] + stats["tables"]["Users"]
diff --git a/tests/unit-tests/helpers_bulk_processor_test.py b/tests/unit-tests/helpers_bulk_processor_test.py
index 19b5c72..3eedf38 100644
--- a/tests/unit-tests/helpers_bulk_processor_test.py
+++ b/tests/unit-tests/helpers_bulk_processor_test.py
@@ -35,15 +35,15 @@ def test_bulk_processor_init(self):
 
         with pytest.raises(Exception) as e:
             BulkProcessor(client, batch_size=-1)
-        assert str(e.value) == "batch size can not be less than one, default: 25"
+        assert str(e.value) == "batch size can not be less than one, default: 50"
 
         with pytest.raises(Exception) as e:
             BulkProcessor(client, flush_interval=-1)
-        assert str(e.value) == "flush interval can not be negative, default: 5.000000"
+        assert str(e.value) == "flush interval can not be negative, default: 2.000000"
 
         with pytest.raises(Exception) as e:
             BulkProcessor(client, processing_timeout=-1)
-        assert str(e.value) == "processing timeout can not be negative, default: 0.025000"
+        assert str(e.value) == "processing timeout can not be negative, default: 0.050000"
 
     def test_bulk_processor_stats(self):
         client = XataClient(api_key="api_key", workspace_id="ws_id")
diff --git a/xata/client.py b/xata/client.py
index 5ec7c18..168968f 100644
--- a/xata/client.py
+++ b/xata/client.py
@@ -39,7 +39,7 @@
 
 # TODO this is a manual task, to keep in sync with pyproject.toml
 # could/should be automated to keep in sync
-__version__ = "1.2.1"
+__version__ = "1.2.2"
 
 PERSONAL_API_KEY_LOCATION = "~/.config/xata/key"
 DEFAULT_DATA_PLANE_DOMAIN = "xata.sh"
diff --git a/xata/helpers.py b/xata/helpers.py
index b4fdd4b..b9cb601 100644
--- a/xata/helpers.py
+++ b/xata/helpers.py
@@ -27,11 +27,11 @@ from .client import XataClient
 
 BP_DEFAULT_THREAD_POOL_SIZE = 4
-BP_DEFAULT_BATCH_SIZE = 25
-BP_DEFAULT_FLUSH_INTERVAL = 5
-BP_DEFAULT_PROCESSING_TIMEOUT = 0.025
+BP_DEFAULT_BATCH_SIZE = 50
+BP_DEFAULT_FLUSH_INTERVAL = 2
+BP_DEFAULT_PROCESSING_TIMEOUT = 0.05
 BP_DEFAULT_THROW_EXCEPTION = False
-BP_VERSION = "0.3.0"
+BP_VERSION = "0.3.1"
 TRX_MAX_OPERATIONS = 1000
 TRX_VERSION = "0.1.0"
 TRX_BACKOFF = 0.1
 
@@ -85,10 +85,13 @@ def __init__(
         self.flush_interval = flush_interval
         self.failed_batches_queue = []
         self.throw_exception = throw_exception
-        self.stats = {"total": 0, "queue": 0, "failed_batches": 0, "tables": {}}
+
+        self.stats = {"total": 0, "queue": 0, "failed_batches": 0, "total_batches": 0, "tables": {}}
+        self.stats_lock = Lock()
         self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
 
         self.thread_workers = []
+        self.worker_active = True
         self.records = self.Records(self.batch_size, self.flush_interval, self.logger)
 
         for i in range(thread_pool_size):
@@ -110,12 +113,16 @@ def process(self, id: int):
                 self.processing_timeout,
             )
         )
-        while True:
+        sleep_backoff = 1
+        while self.worker_active:
+            time.sleep(self.processing_timeout * sleep_backoff)
+            sleep_backoff = 5  # slow down if no records exist
+            # process the next batch
             batch = self.records.next_batch()
             if "table" in batch and len(batch["records"]) > 0:
                 try:
                     r = self.client.records().bulk_insert(batch["table"], {"records": batch["records"]})
-                    if r.status_code != 200:
+                    if not r.is_success():
                         self.logger.error(
                             "thread #%d: unable to process batch for table '%s', with error: %d - %s"
                             % (id, batch["table"], r.status_code, r.json())
@@ -137,14 +144,16 @@ def process(self, id: int):
                         "thread #%d: pushed a batch of %d records to table %s"
                         % (id, len(batch["records"]), batch["table"])
                     )
-                    self.stats["total"] += len(batch["records"])
-                    self.stats["queue"] = self.records.size()
-                    if batch["table"] not in self.stats["tables"]:
-                        self.stats["tables"][batch["table"]] = 0
-                    self.stats["tables"][batch["table"]] += len(batch["records"])
+                    with self.stats_lock:
+                        self.stats["total"] += len(batch["records"])
+                        self.stats["queue"] = self.records.size()
+                        if batch["table"] not in self.stats["tables"]:
+                            self.stats["tables"][batch["table"]] = 0
+                        self.stats["tables"][batch["table"]] += len(batch["records"])
+                        self.stats["total_batches"] += 1
                 except Exception as exc:
                     logging.error("thread #%d: %s" % (id, exc))
-            time.sleep(self.processing_timeout)
+                sleep_backoff = 1  # keep velocity
 
     def put_record(self, table_name: str, record: dict):
         """
@@ -179,26 +188,38 @@ def get_stats(self):
         """
         return self.stats
 
+    def get_queue_size(self) -> int:
+        with self.stats_lock:
+            return self.stats["queue"]
+
     def flush_queue(self):
         """
-        Flush all records from the queue.
+        Flush all records from the queue. Call this as you close the ingestion operation.
+        https://github.com/xataio/xata-py/issues/184
         """
         self.logger.debug("flushing queue with %d records .." % (self.records.size()))
 
         # force flush the records queue and shorten the processing times
         self.records.force_queue_flush()
-        self.processing_timeout = 0.001
 
-        wait = 0.005 * len(self.thread_workers)
-        while self.records.size() > 0:
-            self.logger.debug("flushing queue with %d records." % self.stats["queue"])
-            time.sleep(wait)
+        # back off to let every worker thread run at least once
+        if self.records.size() == 0 or self.get_queue_size() == 0:
+            time.sleep(self.processing_timeout * len(self.thread_workers))
 
-        # Last poor mans check if queue is fully flushed
-        if self.records.size() > 0 or self.stats["queue"] > 0:
-            self.logger.debug("one more flush interval necessary with queue at %d records." % self.stats["queue"])
-            time.sleep(wait)
+        # ensure the full records queue is flushed first
+        while self.records.size() != 0:
+            time.sleep(self.processing_timeout * len(self.thread_workers))
+            self.logger.debug("flushing %d records to processing queue." % self.records.size())
+
+        # make really sure the processor queue is empty
+        while self.get_queue_size() > 0:
+            time.sleep(self.processing_timeout * len(self.thread_workers))
+            self.logger.debug("flushing processor queue with %d records." % self.get_queue_size())
% self.stats["queue"]) + + self.worker_active = False + for worker in self.thread_workers: + worker.join() class Records(object): """ @@ -212,7 +233,6 @@ def __init__(self, batch_size: int, flush_interval: int, logger): """ self.batch_size = batch_size self.flush_interval = flush_interval - self.force_flush = False self.logger = logger self.store = dict() @@ -224,10 +244,8 @@ def force_queue_flush(self): Force next batch to be available https://github.com/xataio/xata-py/issues/184 """ - with self.lock: - self.force_flush = True - self.flush_interval = 0.001 - self.batch_size = 1 + # push for immediate flushes + self.flush_interval = 0 def put(self, table_name: str, records: list[dict]): """ @@ -250,38 +268,30 @@ def next_batch(self) -> dict: :returns dict """ + if self.size() == 0: + return {} table_name = "" with self.lock: names = list(self.store.keys()) - if len(names) == 0: - return {} - self.store_ptr += 1 if len(names) <= self.store_ptr: self.store_ptr = 0 table_name = names[self.store_ptr] rs = [] + if self.length(table_name) == 0: + return {"table": table_name, "records": rs} + with self.store[table_name]["lock"]: # flush interval exceeded time_elapsed = time.time() - self.store[table_name]["flushed"] - flush_needed = time_elapsed > self.flush_interval - if flush_needed and len(self.store[table_name]["records"]) > 0: - self.logger.debug( - "flushing table '%s' with %d records after interval %s > %d" - % ( - table_name, - len(self.store[table_name]["records"]), - time_elapsed, - self.flush_interval, - ) - ) + flush_needed = time_elapsed >= self.flush_interval # force flush table, batch size reached or timer exceeded - if self.force_flush or len(self.store[table_name]["records"]) >= self.batch_size or flush_needed: + if len(self.store[table_name]["records"]) >= self.batch_size or flush_needed: self.store[table_name]["flushed"] = time.time() rs = self.store[table_name]["records"][0 : self.batch_size] del self.store[table_name]["records"][0 : self.batch_size] - return {"table": table_name, "records": rs} + return {"table": table_name, "records": rs} def length(self, table_name: str) -> int: """ @@ -296,8 +306,7 @@ def size(self) -> int: """ Get total size of stored records """ - with self.lock: - return sum([len(self.store[n]["records"]) for n in self.store.keys()]) + return sum([self.length(n) for n in self.store.keys()]) def to_rfc339(dt: datetime, tz=timezone.utc) -> str: