Commit b627175
Merge pull request #17 from yohplala/pre_buffer
Pre buffer
yohplala authored Jun 16, 2024
2 parents f0649a2 + 1836b59
Showing 4 changed files with 141 additions and 86 deletions.
68 changes: 42 additions & 26 deletions oups/aggstream/aggstream.py
@@ -51,13 +51,13 @@
ACCEPTED_AGG_FUNC = {FIRST, LAST, MIN, MAX, SUM}
# List of keys.
KEY_AGGSTREAM = "aggstream"
KEY_PRE = "pre"
KEY_PRE_BUFFER = "pre_buffer"
KEY_SEGAGG_BUFFER = "segagg_buffer"
KEY_POST_BUFFER = "post_buffer"
KEY_BIN_ON_OUT = "bin_on_out"
KEY_AGG_RES_BUFFER = "agg_res_buffer"
KEY_BIN_RES_BUFFER = "bin_res_buffer"
KEY_PRE = "pre"
KEY_PRE_BUFFER = "pre_buffer"
KEY_FILTERS = "filters"
KEY_RESTART_INDEX = "restart_index"
KEY_AGG_RES = "agg_res"
@@ -233,12 +233,12 @@ def _init_keys_config(
return consolidated_keys_config, agg_pd


def _init_agg_buffers(
def _init_buffers(
store: ParquetSet,
keys: dict,
):
"""
Initialize aggregation buffers into ``agg_buffers``.
Initialize pre, aggregation and post buffers from existing results.
Also set ``seed_index_restart``.
@@ -256,6 +256,8 @@ def _init_agg_buffers(
- ``seed_index_restart``, int, float or pTimestamp, the index
from which (included) the next aggregation iteration should be
restarted.
- ``pre_buffer``, dict, user-defined buffer to keep track of intermediate
variables between successive pre-processing of individual seed chunks.
- ``agg_buffers``, dict of aggregation buffer variables specific for each
key, in the form:
``{key: {'agg_n_rows' : 0,
@@ -269,6 +271,7 @@
}``
"""
pre_buffer = {}
agg_buffers = {}
seed_index_restart_set = set()
for key in keys:
@@ -293,13 +296,16 @@
)
aggstream_md = prev_agg_res._oups_metadata[KEY_AGGSTREAM]
# - 'last_seed_index', to trim the head of seed data accordingly.
# - metadata related to pre-processing of individual seed chunk.
# - metadata related to binning process from past binnings
# on prior data. It is used in case 'bin_by' is a callable.
# If not used, it is an empty dict.
# - metadata related to post-processing of prior
# aggregation results, to be used by 'post'. If not used,
# it is an empty dict.
seed_index_restart_set.add(aggstream_md[KEY_RESTART_INDEX])
if KEY_PRE_BUFFER in aggstream_md:
pre_buffer = aggstream_md[KEY_PRE_BUFFER]
agg_buffers[key][KEY_SEGAGG_BUFFER] = (
aggstream_md[KEY_SEGAGG_BUFFER] if aggstream_md[KEY_SEGAGG_BUFFER] else {}
)
@@ -312,11 +318,12 @@

if len(seed_index_restart_set) > 1:
raise ValueError(
"not possible to aggregate on multiple keys with existing"
" aggregation results not aggregated up to the same seed index.",
"not possible to aggregate on multiple keys with existing "
"aggregation results not aggregated up to the same seed index.",
)
return (
None if not seed_index_restart_set else seed_index_restart_set.pop(),
pre_buffer,
agg_buffers,
)
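
The reworked `_init_buffers` (formerly `_init_agg_buffers`) now returns the restored pre-processing buffer alongside the aggregation buffers. A minimal sketch of the new call contract, assuming a populated `store` and `keys` mapping:

    restart_index, pre_buffer, agg_buffers = _init_buffers(store, keys)
    # 'restart_index' is None on a first run, else the seed index recorded
    # under 'restart_index' in oups metadata.
    # 'pre_buffer' is {} unless a previous run recorded one under
    # 'pre_buffer'; a ValueError is raised if existing results for
    # different keys stop at different seed indexes.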

@@ -425,7 +432,7 @@ def _iter_data(
the next. Its initial value is that provided by `pre_buffer`.
In-place modifications of the seed dataframe have to be carried out here.
pre_buffer : dict or None
pre_buffer : dict
Buffer to keep track of intermediate data that can be required for
proceeding with pre-processing of individual seed items.
filters : dict or None
@@ -453,6 +460,9 @@
- 'last_seed_index', Union[int, float, pTimestamp], the last seed
index value (likely of an incomplete group) of the current seed
chunk, before filters are applied.
- 'pre_buffer' : dict, buffer to keep track of intermediate data that
can be required for proceeding with pre-processing of individual seed
chunks.
- 'filter_id', str, indicating which set of filters has been
applied for the seed chunk provided.
- 'filtered_chunk', pDataFrame, from the seed Iterable, with
@@ -466,8 +476,9 @@
Reasons to discard last seed row (or row group) may be twofold:
- last row is temporary (yet to get some final values, for instance
if seed data is some kind of aggregation stream itself),
- last rows are part of a single row group not yet complete itself
(new rows part of this row group to be expected).
- last rows are part of a single row group (same index value in
'ordered_on') not yet complete itself (new rows part of this row
group to be expected).
"""
if restart_index is None:
@@ -489,7 +500,7 @@
if pre:
# Apply user checks.
try:
pre(seed_chunk, pre_buffer)
pre(on=seed_chunk, buffer=pre_buffer)
except Exception as e:
# Stop iteration in case of failing pre.
# Aggregation has been run up to the last valid chunk.
@@ -531,10 +542,12 @@
continue
elif filter_array_loc.all():
# If filter only contains 1, simply return full seed chunk.
yield last_seed_index, filt_id, seed_chunk
yield last_seed_index, pre_buffer, filt_id, seed_chunk
else:
# Otherwise, filter.
yield last_seed_index, filt_id, seed_chunk.loc[filter_array_loc].reset_index(
yield last_seed_index, pre_buffer, filt_id, seed_chunk.loc[
filter_array_loc
].reset_index(
drop=True,
)
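
With the change above, `pre` is now called as `pre(on=seed_chunk, buffer=pre_buffer)`, and `pre_buffer` is yielded along with each filtered chunk. A hedged sketch of a user-defined `pre` callback, with a hypothetical 'ts' column standing in for the `ordered_on` column:

    def pre(on, buffer):
        # 'on' is the current seed chunk (a pandas DataFrame); any
        # modification has to be done in-place, nothing is returned.
        # 'buffer' persists across seed chunks and, per this PR, is now
        # recorded in oups metadata so that restarts resume from it.
        if "last_ts" in buffer and on["ts"].iloc[0] < buffer["last_ts"]:
            raise ValueError("seed chunk overlaps previous chunk")
        buffer["last_ts"] = on["ts"].iloc[-1]

An exception raised in `pre` stops iteration; aggregation has then been run up to the last valid chunk, and the error surfaces as the `SeedPreException` caught in `agg` further down.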

@@ -589,6 +602,7 @@ def _post_n_write_agg_chunks(
index_name: Optional[str] = None,
post: Optional[Callable] = None,
last_seed_index: Optional[Union[int, float, pTimestamp]] = None,
pre_buffer: Optional[dict] = None,
):
"""
Write list of aggregation row groups with optional post.
@@ -666,6 +680,9 @@
Last index in seed data. Can be numeric type, timestamp... (for
recording in metadata of aggregation results)
Writing metadata is triggered ONLY if ``last_seed_index`` is provided.
pre_buffer : dict or None
Buffer to keep track of intermediate data that can be required for
proceeding with pre-processing of individual seed chunks.
"""
if (agg_res := agg_buffers[KEY_AGG_RES]) is None:
@@ -674,10 +691,9 @@
# If 'last_seed_index', at least set it in oups metadata.
# It is possible new seed data has been streamed and taken into
# account, but not used for this key, because it has been filtered out.
# Also set 'pre_buffer' if not null.
OUPS_METADATA[key] = {
KEY_AGGSTREAM: {
KEY_RESTART_INDEX: last_seed_index,
},
KEY_AGGSTREAM: {KEY_RESTART_INDEX: last_seed_index, KEY_PRE_BUFFER: pre_buffer},
}
write_metadata(pf=store[key].pf, md_key=key)
return
@@ -703,9 +719,11 @@
)
if last_seed_index:
# If 'last_seed_index', set oups metadata.
# Also set 'pre_buffer' if not null.
OUPS_METADATA[key] = {
KEY_AGGSTREAM: {
KEY_RESTART_INDEX: last_seed_index,
KEY_PRE_BUFFER: pre_buffer,
KEY_SEGAGG_BUFFER: agg_buffers[KEY_SEGAGG_BUFFER],
KEY_POST_BUFFER: post_buffer,
},
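
For reference, a sketch (values are placeholders) of the per-key metadata written here, now carrying the pre-processing buffer next to the segmentation and post buffers:

    # OUPS_METADATA[key] = {
    #     "aggstream": {
    #         "restart_index": <last seed index>,
    #         "pre_buffer": {...},      # new in this PR
    #         "segagg_buffer": {...},
    #         "post_buffer": {...},
    #     },
    # }
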
@@ -816,7 +834,8 @@ class AggStream:
aggregation iteration.
'pre' : Callable, to apply user-defined pre-processing on seed.
'pre_buffer' : dict, to keep track of intermediate values for
`pre` function.
proceeding with pre-processing of individual seed
items (by `pre` function).
'filters' : dict, as per `filters` parameter.
}``
- ``self.store``, oups store, as per `store` parameter.
@@ -872,7 +891,6 @@ def __init__(
store: ParquetSet,
keys: Union[dataclass, dict],
pre: Optional[Callable] = None,
pre_buffer: Optional[dict] = None,
filters: Optional[dict] = None,
agg: Optional[dict] = None,
bin_by: Optional[Union[TimeGrouper, Callable[[Series, dict], tuple]]] = None,
@@ -958,12 +976,6 @@ def __init__(
Modification of the seed chunk, if any, has to be carried out in-place.
This function is not expected to return a DataFrame.
pre_buffer : dict, default None
Buffer to keep track of intermediate data that can be required for
proceeding with pre-processing of individual seed item.
Once aggregation stream is over, its value is not recorded.
User has to take care of this if needed. Its value can be
retrieved with ``self.pre_buffer`` object attribute.
filters : Union[dict, None], default None
Dict in the form
``{"filter_id":[[("col", op, val), ...], ...]}``
@@ -1209,12 +1221,13 @@ def __init__(
) = _init_keys_config(ordered_on, keys, keys_default)
(
restart_index,
pre_buffer,
self.agg_buffers,
) = _init_agg_buffers(store, keys)
) = _init_buffers(store, keys)
self.seed_config = {
KEY_ORDERED_ON: ordered_on,
KEY_PRE: pre,
KEY_PRE_BUFFER: {} if pre_buffer is None else pre_buffer,
KEY_PRE_BUFFER: pre_buffer,
KEY_FILTERS: filters,
KEY_RESTART_INDEX: restart_index,
}
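
An illustration of `self.seed_config` right after `__init__` on a first run (values are placeholders; note that `pre_buffer` is now restored by `_init_buffers` instead of being a user-supplied parameter):

    # self.seed_config = {
    #     "ordered_on": "ts",     # hypothetical column name
    #     "pre": pre,             # user callback, or None
    #     "pre_buffer": {},       # restored from metadata on restarts
    #     "filters": None,
    #     "restart_index": None,  # None on a first run
    # }
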
@@ -1350,7 +1363,7 @@ def agg(
seed = self._init_agg_cs(seed)
seed_check_exception = False
try:
for _last_seed_index, filter_id, filtered_chunk in _iter_data(
for _last_seed_index, _pre_buffer, filter_id, filtered_chunk in _iter_data(
seed=seed,
**self.seed_config,
trim_start=trim_start,
@@ -1374,6 +1387,8 @@
# Set 'seed_index_restart' to the 'last_seed_index' with
# which to restart the next aggregation iteration.
self.seed_config[KEY_RESTART_INDEX] = _last_seed_index
# Also keep track of last 'pre_buffer' value.
self.seed_config[KEY_PRE_BUFFER] = _pre_buffer
except SeedPreException as sce:
seed_check_exception = True
exception_message = str(sce)
@@ -1391,6 +1406,7 @@
index_name=self.keys_config[key][KEY_BIN_ON_OUT],
post=self.keys_config[key][KEY_POST],
last_seed_index=self.seed_config[KEY_RESTART_INDEX],
pre_buffer=self.seed_config[KEY_PRE_BUFFER],
)
for key, agg_res in self.agg_buffers.items()
)
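
Because `last_seed_index` and `pre_buffer` are flushed to metadata above for every key, a later run over the same store resumes consistently. A hedged usage sketch, with parameter values and seed construction kept illustrative:

    streamagg = AggStream(ordered_on="ts", store=store, keys=keys, pre=pre)
    streamagg.agg(seed=new_seed_chunks)  # resumes after recorded restart index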