First pass: update gather methods to prevent task leaks on error
This update replaces instances of gather() used with the default return_exceptions=False, since on error only the future that raised the exception is stopped while its sibling tasks keep running and leak. TaskGroup ensures all tasks complete or are cancelled before the with block is exited. The update is also applied to gather() calls wrapped in task-leak handling identical to TaskGroup's, to reduce the boilerplate added by that handling.
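
A minimal sketch of the difference, using hypothetical boom()/slow() coroutines (asyncio.TaskGroup requires Python 3.11+):

import asyncio

async def boom() -> None:
    raise RuntimeError("boom")

async def slow() -> str:
    await asyncio.sleep(10)
    return "done"

async def with_gather() -> None:
    # Default return_exceptions=False: the RuntimeError propagates as soon as
    # boom() fails, but the slow() task keeps running in the background unless
    # it is explicitly cancelled -- this is the leak.
    await asyncio.gather(boom(), slow())

async def with_taskgroup() -> None:
    # The sibling task is cancelled and awaited before the async with block
    # exits, so nothing leaks; the error surfaces as an ExceptionGroup.
    async with asyncio.TaskGroup() as tg:
        tg.create_task(boom())
        tg.create_task(slow())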

This update is not applied to gather() invocations that need to return (all) exceptions, since that is not supported by TaskGroup. It is also not applied to gather() calls used as shorthand for try/except/pass logic (except for one case where logging was improved).
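
For reference, a sketch of the return_exceptions=True pattern that stays on gather() (task_a/task_b are placeholders):

import asyncio

async def task_a() -> str:
    return "a"

async def task_b() -> None:
    raise ValueError("b")

async def collect_everything() -> list:
    # return_exceptions=True returns every result or exception in order
    # instead of raising on the first failure; TaskGroup cannot express this.
    return await asyncio.gather(task_a(), task_b(), return_exceptions=True)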

NOTE: we avoid calling create_task inside generator expressions in favor of list comprehensions, which invoke create_task immediately rather than lazily (the task results are the same either way).
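
A sketch of that distinction, assuming a placeholder fetch() coroutine:

import asyncio

async def fetch(i: int) -> int:
    return i * 2

async def demo() -> list[int]:
    async with asyncio.TaskGroup() as tg:
        # List comprehension: every create_task() call runs immediately, so
        # all three tasks are scheduled before execution continues.
        tasks = [tg.create_task(fetch(i)) for i in range(3)]
        # A generator expression -- (tg.create_task(fetch(i)) for i in range(3))
        # -- would only create tasks as it is iterated, or never if it is
        # never consumed.
    return [t.result() for t in tasks]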

NOTE: TaskGroup also prevents its tasks from being garbage collected, which allows removal of variables whose only purpose was ensuring the lifetime of the task reference matched the scope of the function.
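
A sketch of the pattern this makes unnecessary (do_work() is a placeholder):

import asyncio

async def do_work() -> None:
    await asyncio.sleep(1)

async def before() -> None:
    # The event loop keeps only a weak reference to tasks, so the caller must
    # hold `task` to keep it alive until the work completes.
    task = asyncio.create_task(do_work())
    await task

async def after() -> None:
    # TaskGroup holds strong references to its tasks, so no variable is needed
    # just to pin the task's lifetime to the function scope.
    async with asyncio.TaskGroup() as tg:
        tg.create_task(do_work())
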
linkous8 committed Apr 23, 2024
1 parent 3c1a3f1 commit 0629476
Showing 7 changed files with 105 additions and 136 deletions.
6 changes: 3 additions & 3 deletions servo/assembly.py
@@ -150,9 +150,9 @@ async def assemble(
)

# Attach all connectors to the servo
await asyncio.gather(
*list(map(lambda s: s.dispatch_event(servo.servo.Events.attach, s), servos))
)
async with asyncio.TaskGroup() as tg:
for s in servos:
_ = tg.create_task(s.dispatch_event(servo.servo.Events.attach, s))

return assembly

14 changes: 6 additions & 8 deletions servo/cli.py
@@ -1123,14 +1123,12 @@ def print_callback(input: str) -> None:
# gather() expects a loop to exist at invocation time which is not compatible with the run_async
# execution model. wrap the gather in a standard async function to work around this
async def gather_checks():
return await asyncio.gather(
*list(
map(
lambda s: s.check_servo(print_callback),
context.assembly.servos,
)
),
)
async with asyncio.TaskGroup() as tg:
tasks = [
tg.create_task(s.check_servo(print_callback))
for s in context.assembly.servos
]
return (t.result() for t in tasks)

results = run_async(gather_checks())
ready = functools.reduce(lambda x, y: x and y, results)
108 changes: 44 additions & 64 deletions servo/connectors/kubernetes.py
@@ -1141,27 +1141,15 @@ async def create_tuning_pod(self) -> V1Pod:
)
)
progress.start()

task = asyncio.create_task(PodHelper.wait_until_ready(tuning_pod))
task.add_done_callback(lambda _: progress.complete())
gather_task = asyncio.gather(
task,
progress.watch(progress_logger),
)

try:
await asyncio.wait_for(gather_task, timeout=self.timeout.total_seconds())
async with asyncio.timeout(delay=self.timeout.total_seconds()):
async with asyncio.TaskGroup() as tg:
task = tg.create_task(PodHelper.wait_until_ready(tuning_pod))
task.add_done_callback(lambda _: progress.complete())
_ = tg.create_task(progress.watch(progress_logger))

except asyncio.TimeoutError:
servo.logger.error(f"Timed out waiting for Tuning Pod to become ready...")
servo.logger.debug(f"Cancelling Task: {task}, progress: {progress}")
for t in {task, gather_task}:
t.cancel()
with contextlib.suppress(asyncio.CancelledError):
await t
servo.logger.debug(f"Cancelled Task: {t}, progress: {progress}")

# get latest status of tuning pod for raise_for_status
await self.raise_for_status()

# Hydrate local state
Expand Down Expand Up @@ -1631,49 +1619,46 @@ async def apply(self, adjustments: List[servo.Adjustment]) -> None:
# TODO: Run sanity checks to look for out of band changes

async def raise_for_status(self) -> None:
handle_error_tasks = []

def _raise_for_task(task: asyncio.Task, optimization: BaseOptimization) -> None:
if task.done() and not task.cancelled():
if exception := task.exception():
handle_error_tasks.append(
asyncio.create_task(optimization.handle_error(exception))
)

tasks = []
for optimization in self.optimizations:
task = asyncio.create_task(optimization.raise_for_status())
task.add_done_callback(
functools.partial(_raise_for_task, optimization=optimization)
)
tasks.append(task)

for future in asyncio.as_completed(
tasks, timeout=self.config.timeout.total_seconds()
):
try:
await future
except Exception as error:
servo.logger.exception(f"Optimization failed with error: {error}")
# TODO: first handle_error_task to raise will likely interrupt other tasks.
# Gather with return_exceptions=True and aggregate resulting exceptions into group before raising
async with asyncio.TaskGroup() as tg:

def _raise_for_task(
task: asyncio.Task, optimization: BaseOptimization
) -> None:
if task.done() and not task.cancelled():
if exception := task.exception():
_ = tg.create_task(optimization.handle_error(exception))

tasks = []
for optimization in self.optimizations:
task = asyncio.create_task(optimization.raise_for_status())
task.add_done_callback(
functools.partial(_raise_for_task, optimization=optimization)
)
tasks.append(task)

# TODO: first handler to raise will likely interrupt other tasks.
# Gather with return_exceptions=True and aggregate resulting exceptions before raising
await asyncio.gather(*handle_error_tasks)
for future in asyncio.as_completed(
tasks, timeout=self.config.timeout.total_seconds()
):
try:
await future
except Exception as error:
servo.logger.exception(f"Optimization failed with error: {error}")

async def is_ready(self):
if self.optimizations:
self.logger.debug(
f"Checking for readiness of {len(self.optimizations)} optimizations"
)
try:
results = await asyncio.wait_for(
asyncio.gather(
*list(map(lambda a: a.is_ready(), self.optimizations)),
),
timeout=self.config.timeout.total_seconds(),
)
async with asyncio.timeout(delay=self.config.timeout.total_seconds()):
async with asyncio.TaskGroup() as tg:
results = [
tg.create_task(o.is_ready()) for o in self.optimizations
]

return all(results)
return all(r.result() for r in results)

except asyncio.TimeoutError:
return False
@@ -2297,15 +2282,13 @@ async def adjust(
progress=p.progress,
)
progress = servo.EventProgress(timeout=self.config.timeout)
future = asyncio.create_task(state.apply(adjustments))
future.add_done_callback(lambda _: progress.trigger())

# Catch-all for spaghettified non-EventError usage
try:
await asyncio.gather(
future,
progress.watch(progress_logger),
)
async with asyncio.TaskGroup() as tg:
future = tg.create_task(state.apply(adjustments))
future.add_done_callback(lambda _: progress.trigger())
_ = tg.create_task(progress.watch(progress_logger))

# Handle settlement
settlement = control.settlement or self.config.settlement
@@ -2383,13 +2366,10 @@ async def _create_optimizations(self) -> KubernetesOptimizations:
)
progress = servo.EventProgress(timeout=self.config.timeout)
try:
future = asyncio.create_task(KubernetesOptimizations.create(self.config))
future.add_done_callback(lambda _: progress.trigger())

await asyncio.gather(
future,
progress.watch(progress_logger),
)
async with asyncio.TaskGroup() as tg:
future = tg.create_task(KubernetesOptimizations.create(self.config))
future.add_done_callback(lambda _: progress.trigger())
_ = tg.create_task(progress.watch(progress_logger))

return future.result()
except Exception as e:
35 changes: 17 additions & 18 deletions servo/connectors/prometheus.py
@@ -994,28 +994,25 @@ async def measure(
),
)
fast_fail_progress = servo.EventProgress(timeout=measurement_duration)
gather_tasks = [
asyncio.create_task(progress.watch(self.observe)),
asyncio.create_task(
async with asyncio.TaskGroup() as tg:
_ = tg.create_task(progress.watch(self.observe))
_ = tg.create_task(
fast_fail_progress.watch(
fast_fail_observer.observe, every=self.config.fast_fail.period
fast_fail_observer.observe,
every=self.config.fast_fail.period,
)
),
]
try:
await asyncio.gather(*gather_tasks)
except:
[task.cancel() for task in gather_tasks]
await asyncio.gather(*gather_tasks, return_exceptions=True)
raise
)

else:
await progress.watch(self.observe)

# Capture the measurements
self.logger.info(f"Querying Prometheus for {len(metrics__)} metrics...")
readings = await asyncio.gather(
*list(map(lambda m: self._query_prometheus(m, start, end), metrics__))
)
async with asyncio.TaskGroup() as tg:
q_tasks = [
tg.create_task(self._query_prometheus(m, start, end)) for m in metrics__
]
readings = [qt.result() for qt in q_tasks]
all_readings = (
functools.reduce(lambda x, y: x + y, readings) if readings else []
)
@@ -1077,9 +1074,11 @@ async def _query_slo_metrics(
self, start: datetime, end: datetime, metrics: List[PrometheusMetric]
) -> Dict[str, List[servo.TimeSeries]]:
"""Query prometheus for the provided metrics and return mapping of metric names to their corresponding readings"""
readings = await asyncio.gather(
*list(map(lambda m: self._query_prometheus(m, start, end), metrics))
)
async with asyncio.TaskGroup() as tg:
q_tasks = [
tg.create_task(self._query_prometheus(m, start, end)) for m in metrics
]
readings = (qt.result() for qt in q_tasks)
return dict(map(lambda tup: (tup[0].name, tup[1]), zip(metrics, readings)))


31 changes: 13 additions & 18 deletions servo/events.py
@@ -1038,36 +1038,31 @@ async def run(self) -> List[EventResult]:
if results:
break
else:
group = asyncio.gather(
*list(
map(
lambda c: c.run_event_handlers(
async with asyncio.TaskGroup() as tg:
ev_tasks = [
tg.create_task(
c.run_event_handlers(
self.event,
Preposition.on,
return_exceptions=self._return_exceptions,
*self._args,
**self._kwargs,
),
self._connectors,
)
)
),
)
results = await group
for c in self._connectors
]

results = (et.result() for et in ev_tasks)
results = list(filter(lambda r: r is not None, results))
results = functools.reduce(lambda x, y: x + y, results, [])

# Invoke the after event handlers
if self._prepositions & Preposition.after:
await asyncio.gather(
*list(
map(
lambda c: c.run_event_handlers(
self.event, Preposition.after, results
),
self._connectors,
async with asyncio.TaskGroup() as tg:
for c in self._connectors:
_ = tg.create_task(
c.run_event_handlers(self.event, Preposition.after, results)
)
)
)

if self.channel:
await self.channel.close()
6 changes: 5 additions & 1 deletion servo/runner.py
@@ -745,7 +745,11 @@ async def _shutdown(self, loop: asyncio.AbstractEventLoop, signal=None) -> None:
except Exception as error:
self.logger.critical(f"Failed assembly shutdown with error: {error}")

await asyncio.gather(self.progress_handler.shutdown(), return_exceptions=True)
try:
await self.progress_handler.shutdown()
except Exception as error:
self.logger.warning(f"Failed progress handler shutdown with error: {error}")

self.logger.remove(self.progress_handler_id)

# Cancel any outstanding tasks -- under a clean, graceful shutdown this list will be empty
41 changes: 17 additions & 24 deletions servo/utilities/subprocess.py
@@ -314,29 +314,28 @@ async def stream_subprocess_output(
:raises asyncio.TimeoutError: Raised if the timeout expires before the subprocess exits.
:return: The exit status of the subprocess.
"""
tasks = []
if process.stdout:
tasks.append(
asyncio.create_task(
_read_lines_from_output_stream(process.stdout, stdout_callback),
name="stdout",
)
)
if process.stderr:
tasks.append(
asyncio.create_task(
_read_lines_from_output_stream(process.stderr, stderr_callback),
name="stderr",
)
)

timeout_in_seconds = (
timeout.total_seconds() if isinstance(timeout, datetime.timedelta) else timeout
)
try:
# Gather the stream output tasks and the parent process
gather_task = asyncio.gather(*tasks, process.wait())
await asyncio.wait_for(gather_task, timeout=timeout_in_seconds)
async with asyncio.timeout(delay=timeout_in_seconds):
async with asyncio.TaskGroup() as tg:
if process.stdout:
tg.create_task(
_read_lines_from_output_stream(process.stdout, stdout_callback),
name="stdout",
)

if process.stderr:
tg.create_task(
_read_lines_from_output_stream(process.stderr, stderr_callback),
name="stderr",
)

tg.create_task(process.wait())

# Gather the stream output tasks and the parent process (with block does not exit until error or all complete)

except (asyncio.TimeoutError, asyncio.CancelledError):
with contextlib.suppress(ProcessLookupError):
@@ -351,12 +350,6 @@ async def stream_subprocess_output(
process.kill()
await process.wait()

with contextlib.suppress(asyncio.CancelledError):
await gather_task

[task.cancel() for task in tasks]
await asyncio.gather(*tasks, return_exceptions=True)

raise

return cast(int, process.returncode)
