Skip to content

Commit

Permalink
fix: conform executors
Browse files Browse the repository at this point in the history
  • Loading branch information
andrzejnovak committed Mar 15, 2022
1 parent 2c2d537 commit 332d8cb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
69 changes: 35 additions & 34 deletions coffea/processor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,11 +497,14 @@ def __call__(
if self.x509_proxy is None:
self.x509_proxy = _get_x509_proxy()

return work_queue_main(
items,
function,
accumulator,
**self.__dict__,
return (
work_queue_main(
items,
function,
accumulator,
**self.__dict__,
),
0,
)


Expand Down Expand Up @@ -545,7 +548,7 @@ def __call__(
desc=self.desc,
)
gen = map(function, gen)
return accumulate(gen, accumulator)
return accumulate(gen, accumulator), 0


@dataclass
Expand All @@ -568,9 +571,9 @@ class FuturesExecutor(ExecutorBase):
status : bool, optional
If true (default), enable progress bar
desc : str, optional
Label of progress bar unit (default: 'Processing')
Label of progress description (default: 'Processing')
unit : str, optional
Label of progress bar description (default: 'items')
Label of progress bar unit (default: 'items')
compression : int, optional
Compress accumulator outputs in flight with LZ4, at level specified (default 1)
Set to ``None`` for no compression.
Expand Down Expand Up @@ -603,7 +606,7 @@ class FuturesExecutor(ExecutorBase):
] = None
merging: Optional[Union[bool, Tuple[int, int, int]]] = False
workers: int = 1
recoverable: bool = True
recoverable: bool = False
tailtimeout: Optional[int] = None

def __getstate__(self):
Expand Down Expand Up @@ -639,9 +642,9 @@ def processwith(pool, mergepool):
_mdesc = (
"Merging"
if self.merging
else "Merging (main process, by batch)"
else "Merging (local)"
)
prog_id_merge = progress.add_task(_mdesc, total=1, unit="merges")
prog_id_merge = progress.add_task(_mdesc, total=0, unit="merges")

merged = None
while len(FH.futures) + len(FH.merges) > 0:
Expand Down Expand Up @@ -683,21 +686,16 @@ def processwith(pool, mergepool):
total=progress._tasks[prog_id_merge].total
+ len(batch),
),
_decompress(accumulator),
_decompress(merged),
),
self.compression,
)

progress.refresh()

# Add checkpointing

if self.merging:
progress.update(prog_id_merge, advance=1)
progress.refresh()
merged = FH.completed.pop().result()
if len(FH.completed) > 0:
raise RuntimeError("Not all futures are added.")
if len(FH.completed) > 0:
raise RuntimeError("Not all futures are added.")
return accumulate([_decompress(merged), accumulator]), 0

except Exception as e:
Expand Down Expand Up @@ -921,13 +919,16 @@ def belongsto(heavy_input, workerindex, item):

# FIXME: fancy widget doesn't appear, have to live with boring pbar
progress(work, multi=True, notebook=False)
return accumulate(
[
work.result()
if self.compression is None
else _decompress(work.result())
],
accumulator,
return (
accumulate(
[
work.result()
if self.compression is None
else _decompress(work.result())
],
accumulator,
),
0,
)
except KilledWorker as ex:
baditem = key_to_item[ex.task]
Expand All @@ -941,7 +942,7 @@ def belongsto(heavy_input, workerindex, item):
from distributed import progress

progress(work, multi=True, notebook=False)
return {"out": dd.from_delayed(work)}
return {"out": dd.from_delayed(work)}, 0


@dataclass
Expand Down Expand Up @@ -1029,7 +1030,7 @@ def __call__(
parsl.dfk().cleanup()
parsl.clear()

return accumulator
return accumulator, 0


class ParquetFileContext:
Expand Down Expand Up @@ -1581,14 +1582,14 @@ def __call__(
)
else:
processor_instance.postprocess(wrapped_out["out"])

_return = (wrapped_out["out"],)
if hasattr(self.executor, "recoverable") and self.executor.recoverable:
_return = *_return, list(wrapped_out["processed"])
if self.savemetrics and not self.use_dataframes:
wrapped_out["metrics"]["chunks"] = len(chunks)
return (
wrapped_out["out"],
list(wrapped_out["processed"]),
wrapped_out["metrics"],
)
return wrapped_out["out"], list(wrapped_out["processed"])
_return = *_return, wrapped_out["metrics"]
return _return if len(_return) > 1 else _return[0]


def run_spark_job(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def postprocess(self, accumulator):
proc.process,
None,
)
assert out == {"itemsum": 45}
assert out == ({"itemsum": 45}, 0)

class TestOldStyle(ProcessorABC):
@property
Expand All @@ -140,4 +140,4 @@ def postprocess(self, accumulator):
proc.process,
proc.accumulator,
)
assert out["itemsum"] == 45
assert out[0]["itemsum"] == 45

0 comments on commit 332d8cb

Please sign in to comment.