From a387cb2bfd2a0329fe0799f5c9d9e7d5f295bb53 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Tue, 28 May 2024 11:53:49 +0800 Subject: [PATCH] Generate an event ID in `Batch` step (#525) --- storey/flow.py | 23 +++++++++++++++++++++-- storey/sources.py | 17 +---------------- tests/test_flow.py | 10 ++++++++-- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 461a8cbb..37bef0af 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -19,6 +19,7 @@ import pickle import time import traceback +import uuid from asyncio import Task from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor @@ -331,6 +332,20 @@ def _check_step_in_flow(self, type_to_check): return False +class WithUUID: + def __init__(self): + self._current_uuid_base = None + self._current_uuid_count = 0 + + def _get_uuid(self): + if not self._current_uuid_base or self._current_uuid_count == 1024: + self._current_uuid_base = uuid.uuid4().hex + self._current_uuid_count = 0 + result = f"{self._current_uuid_base}-{self._current_uuid_count:04}" + self._current_uuid_count += 1 + return result + + class Choice(Flow): """Redirects each input element into at most one of multiple downstreams. @@ -1153,7 +1168,7 @@ async def _emit_all(self): await self._emit_batch(key) -class Batch(_Batching): +class Batch(_Batching, WithUUID): """Batches events into lists of up to max_events events. Each emitted list contained max_events events, unless flush_after_seconds seconds have passed since the first event in the batch was received, at which the batch is emitted with potentially fewer than max_events event. @@ -1170,8 +1185,12 @@ class Batch(_Batching): _do_downstream_per_event = False + def __init__(self, *args, **kwargs): + _Batching.__init__(self, *args, **kwargs) + WithUUID.__init__(self) + async def _emit(self, batch, batch_key, batch_time, batch_events, last_event_time=None): - event = Event(batch) + event = Event(batch, id=self._get_uuid()) if not self._full_event: # Preserve reference to the original events to avoid early commit of offsets event._original_events = batch_events diff --git a/storey/sources.py b/storey/sources.py index b6c217da..8b642b60 100644 --- a/storey/sources.py +++ b/storey/sources.py @@ -19,7 +19,6 @@ import threading import time import traceback -import uuid import warnings import weakref from collections import defaultdict @@ -34,7 +33,7 @@ from nuclio_sdk import QualifiedOffset from .dtypes import Event, _termination_obj -from .flow import Complete, Flow +from .flow import Complete, Flow, WithUUID from .queue import SimpleAsyncQueue from .utils import find_filters, find_partitions, url_to_file_system @@ -94,20 +93,6 @@ def _convert_to_datetime(obj, time_format: Optional[str] = None): raise ValueError(f"Could not parse '{obj}' (of type {type(obj)}) as a time.") -class WithUUID: - def __init__(self): - self._current_uuid_base = None - self._current_uuid_count = 0 - - def _get_uuid(self): - if not self._current_uuid_base or self._current_uuid_count == 1024: - self._current_uuid_base = uuid.uuid4().hex - self._current_uuid_count = 0 - result = f"{self._current_uuid_base}-{self._current_uuid_count:04}" - self._current_uuid_count += 1 - return result - - class FlowControllerBase(WithUUID): def __init__( self, diff --git a/tests/test_flow.py b/tests/test_flow.py index fd8afc6c..bd484069 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1768,7 +1768,7 @@ def test_batch(): [ SyncEmitSource(), Batch(4, 100), - Reduce([], lambda acc, x: append_and_return(acc, x)), + Reduce([], lambda acc, x: append_and_return(acc, x), full_event=True), ] ).run() @@ -1776,7 +1776,13 @@ def test_batch(): controller.emit(i) controller.terminate() termination_result = controller.await_termination() - assert termination_result == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] + assert len(termination_result) == 3 + assert termination_result[0].id + assert termination_result[0].body == [0, 1, 2, 3] + assert termination_result[1].id + assert termination_result[1].body == [4, 5, 6, 7] + assert termination_result[2].id + assert termination_result[2].body == [8, 9] def test_batch_full_event():