diff --git a/coffea/processor/executor.py b/coffea/processor/executor.py index 5c64d7ced..32edad41d 100644 --- a/coffea/processor/executor.py +++ b/coffea/processor/executor.py @@ -151,7 +151,7 @@ def chunks(self, target_chunksize, align_clusters, dynamic_chunksize): return target_chunksize -@dataclass(unsafe_hash=True) +@dataclass(unsafe_hash=True, frozen=True) class WorkItem: dataset: str filename: str @@ -216,6 +216,42 @@ def __call__(self, items): return accumulate(items) +class FuturesHolder: + def __init__(self, futures, refresh=2): + self.futures = set(futures) + self.merges = set() + self.completed = set() + self.done = {"futures": 0, "merges": 0} + self.refresh = refresh + + def update(self, refresh=None): + if refresh is None: + refresh = self.refresh + if self.futures: + completed, self.futures = concurrent.futures.wait( + self.futures, + timeout=refresh, + return_when=concurrent.futures.FIRST_COMPLETED, + ) + self.completed.update(completed) + self.done["futures"] += len(completed) + + if self.merges: + completed, self.merges = concurrent.futures.wait( + self.merges, + timeout=refresh, + return_when=concurrent.futures.FIRST_COMPLETED, + ) + self.completed.update(completed) + self.done["merges"] += len(completed) + + def fetch(self, N): + return [ + self.completed.pop().result() for _ in range(N) if len(self.completed) > 0 + ] + + + def _futures_handler(futures, timeout): """Essentially the same as concurrent.futures.as_completed but makes sure not to hold references to futures any longer than strictly necessary, @@ -523,13 +559,18 @@ class FuturesExecutor(ExecutorBase): Number of parallel processes for futures (default 1) status : bool, optional If true (default), enable progress bar - unit : str, optional - Label of progress bar unit (default: 'Processing') desc : str, optional + Label of progress bar unit (default: 'Processing') + unit : str, optional Label of progress bar description (default: 'items') compression : int, optional Compress accumulator outputs in flight with LZ4, at level specified (default 1) Set to ``None`` for no compression. + recoverable : bool + merging : bool | tuple(int, int, int), optional + Enables intermediate merges in jobs. Format is (n_batches, min_batch_size, max_batch_size) + Passing ``True`` will use default: (5, 4, 100) + checkpoints : bool tailtimeout : int, optional Timeout requirement on job tails. Cancel all remaining jobs if none have finished in the timeout window. @@ -537,6 +578,8 @@ class FuturesExecutor(ExecutorBase): pool: Union[Callable[..., concurrent.futures.Executor], concurrent.futures.Executor] = concurrent.futures.ProcessPoolExecutor # fmt: skip workers: int = 1 + recoverable: bool = True + merging: bool = False tailtimeout: Optional[int] = None def __getstate__(self): @@ -553,23 +596,94 @@ def __call__( if self.compression is not None: function = _compression_wrapper(self.compression, function) + def merge_tqdm(chunks, accumulator, desc="Adding"): + gen = (c for c in chunks) + return accumulate( + tqdm( + gen if self.compression is None else map(_decompress, gen), + disable=not self.status, + unit=self.unit, + total=len(chunks), + desc=desc, + ), + accumulator, + ) + def processwith(pool): - gen = _futures_handler( - {pool.submit(function, item) for item in items}, self.tailtimeout + reducer = _reduce(self.compression) + + FH = FuturesHolder( + set(pool.submit(function, item) for item in items), refresh=2 ) + + if isinstance(self.merging, tuple) and len(self.merging) == 3: + nparts, minred, maxred = self.merging + else: + nparts, minred, maxred = 5, 4, 100 + try: - return accumulate( - tqdm( - gen if self.compression is None else map(_decompress, gen), - disable=not self.status, - unit=self.unit, - total=len(items), - desc=self.desc, - ), - accumulator, - ) - finally: - gen.close() + pbar = tqdm(disable=not self.status, unit=self.unit, total=len(items), + desc=self.desc, position=0, ascii=True) + if self.merging: + mbar = tqdm(disable=not self.status, total=1, desc="Merging", + position=1, ascii=True) + + while len(FH.futures) + len(FH.merges) > 0: + FH.update() + reduce = min(maxred, max(len(FH.completed) // nparts + 1, minred)) + + pbar.update(FH.done["futures"] - pbar.n) + pbar.refresh() + + if self.merging: + mbar.update(FH.done["merges"] - mbar.n) + mbar.refresh() + while len(FH.completed) > 1: + batch = [b for b in FH.fetch(reduce)] + FH.merges.add(pool.submit(reducer, batch)) + mbar.total += 1 + mbar.refresh() + + # Add checkpointing + # if FH.done["futures"]% 100 == 0: + # accumulate([future.result() for future in FH.completed]) + # dump to pickle + + pbar.close() + if self.merging: + mbar.update(1) # last one + mbar.refresh() + mbar.close() + merged = FH.completed.pop().result() + if len(FH.completed) > 0: + raise RuntimeError("Not all futures are added.") + else: + merged = reducer(FH.fetch(len(FH.completed))) + + return reducer([merged, accumulator]), 0 + + except Exception as e: + if self.recoverable: + print(f"Exception '{type(e)}' occured, recovering progress...") + for job in FH.futures: + job.cancel() + + if self.merging: + with tqdm(disable=not self.status, total=len(FH.merges), desc="Recovering finished jobs", position=1, ascii=True) as mbar: + while len(FH.merges) > 0: + FH.update() + mbar.update(FH.done["merges"] - mbar.n) + mbar.refresh() + + def is_good(future): + return future.done() and not future.cancelled() and future.exception() is None + + FH.update() + recovered = [future.result() for future in FH.completed if is_good(future)] + + return merge_tqdm(recovered, accumulator, desc="Merging finished jobs"), e + else: + raise type(e)(str(e)).with_traceback(sys.exc_info()[2]) from None if isinstance(self.pool, concurrent.futures.Executor): return processwith(pool=self.pool) @@ -1078,7 +1192,7 @@ def _preprocess_fileset(self, fileset: Dict) -> None: self.skipbadfiles, partial(self.metadata_fetcher, self.xrootdtimeout, self.align_clusters), ) - out = pre_executor(to_get, closure, out) + out, _ = pre_executor(to_get, closure, out) while out: item = out.pop() self.metadata_cache[item] = item.metadata @@ -1257,14 +1371,15 @@ def _work_function( metrics["columns"] = set(events.materialized) metrics["entries"] = events.size metrics["processtime"] = toc - tic - return {"out": out, "metrics": metrics} - return {"out": out} + return {"out": out, "metrics": metrics, "processed": set([item])} + return {"out": out, "processed": set([item])} def __call__( self, fileset: Dict, treename: str, processor_instance: ProcessorABC, + prepro_only=False, ) -> Accumulatable: """Run the processor_instance on a given fileset @@ -1281,26 +1396,35 @@ def __call__( An instance of a class deriving from ProcessorABC """ + meta = False if not isinstance(fileset, (Mapping, str)): - raise ValueError( - "Expected fileset to be a mapping dataset: list(files) or filename" - ) + if isinstance(fileset[0], WorkItem): + meta = True + else: + raise ValueError( + "Expected fileset to be a mapping dataset: list(files) or filename" + ) if not isinstance(processor_instance, ProcessorABC): raise ValueError("Expected processor_instance to derive from ProcessorABC") - if self.format == "root": - fileset = list(self._normalize_fileset(fileset, treename)) - for filemeta in fileset: - filemeta.maybe_populate(self.metadata_cache) + if meta: + chunks = fileset + else: + if self.format == "root": + fileset = list(self._normalize_fileset(fileset, treename)) + for filemeta in fileset: + filemeta.maybe_populate(self.metadata_cache) - self._preprocess_fileset(fileset) - fileset = self._filter_badfiles(fileset) + self._preprocess_fileset(fileset) + fileset = self._filter_badfiles(fileset) - # reverse fileset list to match the order of files as presented in version - # v0.7.4. This fixes tests using maxchunks. - fileset.reverse() + # reverse fileset list to match the order of files as presented in version + # v0.7.4. This fixes tests using maxchunks. + fileset.reverse() - chunks = self._chunk_generator(fileset, treename) + chunks = self._chunk_generator(fileset, treename) + if prepro_only: + return [c for c in chunks] if self.processor_compression is None: pi_to_send = processor_instance @@ -1363,13 +1487,20 @@ def __call__( ) executor = self.executor.copy(**exe_args) - wrapped_out = executor(chunks, closure, None) - processor_instance.postprocess(wrapped_out["out"]) + wrapped_out, e = executor(chunks, closure, None) + if e != 0: + # print(type(e[0])(str(e[0])).with_traceback(e[1])) + print(type(e)(str(e))) + + if wrapped_out is None: + raise ValueError("No chunks were processed, veryify ``processor`` instance structure.") + else: + processor_instance.postprocess(wrapped_out["out"]) if self.savemetrics and not self.use_dataframes: wrapped_out["metrics"]["chunks"] = len(chunks) - return wrapped_out["out"], wrapped_out["metrics"] - return wrapped_out["out"] + return wrapped_out["out"], list(wrapped_out['processed']), wrapped_out["metrics"] + return wrapped_out["out"], list(wrapped_out['processed']) def run_spark_job(