Skip to content

Commit

Permalink
fixed cancellation of task ala python 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg committed Jul 28, 2021
1 parent 32f2d94 commit 2b1fd4a
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 63 deletions.
23 changes: 12 additions & 11 deletions services/sidecar/src/simcore_service_sidecar/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,27 @@

from .__version__ import __version__
from .celery_configurator import create_celery_app
from .celery_task_utils import cancel_task
from .cli import run_sidecar
from .remote_debug import setup_remote_debugging
from .utils import cancel_task_by_fct_name

setup_remote_debugging()

app = create_celery_app()

log = logging.getLogger(__name__)

#
# SEE https://patorjk.com/software/taag/#p=display&h=0&f=Ogre&t=Celery-sidecar
#
WELCOME_MSG = r"""
.-') _ .-') _ ('-. ('-. _ .-')
( OO ). ( ( OO) ) _( OO) ( OO ).-.( \( -O )
(_)---\_) ,-.-') \ .'_ (,------. .-----. / . --. / ,------.
/ _ | | |OO),`'--..._) | .---' ' .--./ | \-. \ | /`. '
\ :` `. | | \| | \ ' | | | |('-..-'-' | | | / | |
'..`''.) | |(_/| | ' |(| '--. /_) |OO )\| |_.' | | |_.' |
.-._) \ ,| |_.'| | / : | .--' || |`-'| | .-. | | . '.'
\ /(_| | | '--' / | `---.(_' '--'\ | | | | | |\ \
`-----' `--' `-------' `------' `-----' `--' `--' `--' '--' {0} - {1}
___ _ _ _
/ __\ ___ | | ___ _ __ _ _ ___ (_) __| | ___ ___ __ _ _ __
/ / / _ \| | / _ \| '__|| | | | _____ / __|| | / _` | / _ \ / __| / _` || '__|
/ /___ | __/| || __/| | | |_| ||_____|\__ \| || (_| || __/| (__ | (_| || |
\____/ \___||_| \___||_| \__, | |___/|_| \__,_| \___| \___| \__,_||_|
|___/ {0} - {1}
""".format(
__version__, app.conf.osparc_sidecar_bootmode.value
)
Expand All @@ -39,7 +40,7 @@ def worker_shutting_down_handler(
):
# NOTE: this function shall be adapted when we switch to python 3.7+
log.warning("detected worker_shutting_down signal(%s, %s, %s)", sig, how, exitcode)
cancel_task(run_sidecar)
cancel_task_by_fct_name(run_sidecar.__name__)


@worker_ready.connect
Expand Down
4 changes: 2 additions & 2 deletions services/sidecar/src/simcore_service_sidecar/celery_task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import logging
from asyncio import CancelledError

from .cli import run_sidecar
from .utils import wrap_async_call

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -37,7 +37,7 @@ def _shared_task_dispatch(
celery_request.max_retries,
)
try:
wrap_async_call(
asyncio.run(
run_sidecar(
celery_request.request.id,
user_id,
Expand Down
11 changes: 0 additions & 11 deletions services/sidecar/src/simcore_service_sidecar/celery_task_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import asyncio
import logging
from pprint import pformat
from typing import Callable

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -39,12 +37,3 @@ def on_task_success_handler(
args if args else "none",
pformat(kwargs) if kwargs else "none",
)


def cancel_task(function: Callable) -> None:
    """Cancel the asyncio tasks whose coroutine shares its name with ``function``.

    Every task returned by ``asyncio.all_tasks()`` is inspected; a task is
    selected when the ``__name__`` of the coroutine it wraps equals
    ``function.__name__``. Each selected task is logged and then cancelled.

    :param function: callable whose ``__name__`` identifies the tasks to cancel
    """
    # The coroutine is read through the private ``_coro`` attribute of the task.
    # pylint: disable=protected-access
    matching = (
        candidate
        for candidate in asyncio.all_tasks()
        if candidate._coro.__name__ == function.__name__
    )
    for candidate in matching:
        log.warning("canceling task....................")
        candidate.cancel()
28 changes: 16 additions & 12 deletions services/sidecar/src/simcore_service_sidecar/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from servicelib.logging_utils import log_decorator

from .boot_mode import BootMode
from .celery_task_utils import cancel_task
from .config import SIDECAR_INTERVAL_TO_CHECK_TASK_ABORTED_S
from .core import run_computational_task
from .db import DBContextManager
from .rabbitmq import RabbitMQ
from .utils import wrap_async_call
from .utils import cancel_task

log = logging.getLogger(__name__)

Expand All @@ -24,20 +23,22 @@
def main(job_id: str, user_id: str, project_id: str, node_id: str) -> None:

try:
wrap_async_call(run_sidecar(job_id, user_id, project_id, node_id=node_id))
asyncio.run(run_sidecar(job_id, user_id, project_id, node_id=node_id))
except Exception: # pylint: disable=broad-except
log.exception("Unexpected problem while running sidecar")


@log_decorator(logger=log, level=logging.INFO)
async def perdiodicaly_check_if_aborted(is_aborted_cb: Callable[[], bool]) -> None:
async def perdiodicaly_check_if_aborted(
is_aborted_cb: Callable[[], bool], task_name: str
) -> None:
try:
while await asyncio.sleep(
SIDECAR_INTERVAL_TO_CHECK_TASK_ABORTED_S, result=True
):
if is_aborted_cb():
log.info("Task was aborted. Cancelling...")
asyncio.get_event_loop().call_soon(cancel_task(run_sidecar))
log.info("Task was aborted. Cancelling fct [%s]...", task_name)
asyncio.get_event_loop().call_soon(cancel_task, task_name)
except asyncio.CancelledError:
pass

Expand All @@ -53,13 +54,16 @@ async def run_sidecar( # pylint: disable=too-many-arguments
max_retries: int = 1,
sidecar_mode: BootMode = BootMode.CPU,
) -> None:
abortion_task = (
asyncio.get_event_loop().create_task(
perdiodicaly_check_if_aborted(is_aborted_cb)

abortion_task: Optional[asyncio.Task] = None
if current_task := asyncio.current_task():
abortion_task = (
asyncio.create_task(
perdiodicaly_check_if_aborted(is_aborted_cb, current_task.get_name())
)
if is_aborted_cb
else None
)
if is_aborted_cb
else None
)
try:
async with DBContextManager() as db_engine, RabbitMQ() as rabbit_mq:
await run_computational_task(
Expand Down
49 changes: 23 additions & 26 deletions services/sidecar/src/simcore_service_sidecar/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import aio_pika
import tenacity
from models_library.settings.celery import CeleryConfig
from pydantic import BaseModel # pylint: disable=no-name-in-module
from servicelib.logging_utils import log_decorator
from models_library.settings.rabbit import ( # pylint: disable=no-name-in-module
RabbitDsn,
)
from pydantic import BaseModel, PrivateAttr
from servicelib.rabbitmq_utils import RabbitMQRetryPolicyUponInitialization

from . import config
Expand All @@ -26,69 +28,64 @@ def _close_callback(sender: Any, exc: Optional[BaseException]):
sender,
exc_info=True,
)
else:
log.info("Rabbit connection closed from %s", sender)


def _channel_close_callback(sender: Any, exc: Optional[BaseException]):
    """Log the closing of a RabbitMQ channel.

    :param sender: object that reported the channel as closed
    :param exc: exception that caused the closing, or ``None`` when the
        channel was closed without error
    """
    if exc:
        # Pass the exception itself: ``exc_info=True`` would make logging read
        # ``sys.exc_info()``, which is empty when this callback is not invoked
        # from inside an ``except`` block, and the cause would not be logged.
        log.error(
            "Rabbit channel closed with exception from %s:", sender, exc_info=exc
        )
    else:
        log.info("Rabbit channel closed from %s", sender)


class RabbitMQ(BaseModel):
celery_config: CeleryConfig = None
connection: aio_pika.Connection = None
channel: aio_pika.Channel = None
logs_exchange: aio_pika.Exchange = None
instrumentation_exchange: aio_pika.Exchange = None
celery_config: Optional[CeleryConfig] = None
_connection: aio_pika.Connection = PrivateAttr()
_channel: aio_pika.Channel = PrivateAttr()
_logs_exchange: aio_pika.Exchange = PrivateAttr()
_instrumentation_exchange: aio_pika.Exchange = PrivateAttr()

class Config:
# see https://pydantic-docs.helpmanual.io/usage/types/#arbitrary-types-allowed
arbitrary_types_allowed = True

@log_decorator(logger=log)
async def connect(self):
if not self.celery_config:
self.celery_config = config.CELERY_CONFIG
url = self.celery_config.rabbit.dsn
if not url:
raise ValueError("Rabbit DSN not set")
log.debug("Connecting to %s", url)
await _wait_till_rabbit_responsive(url)

# NOTE: to show the connection name in the rabbitMQ UI see there [https://www.bountysource.com/issues/89342433-setting-custom-connection-name-via-client_properties-doesn-t-work-when-connecting-using-an-amqp-url]
self.connection = await aio_pika.connect(
self._connection = await aio_pika.connect(
url + f"?name={__name__}_{id(socket.gethostname())}",
client_properties={"connection_name": "sidecar connection"},
)
self.connection.add_close_callback(_close_callback)
self._connection.add_close_callback(_close_callback)

log.debug("Creating channel")
self.channel = await self.connection.channel(publisher_confirms=False)
self.channel.add_close_callback(_channel_close_callback)
self._channel = await self._connection.channel(publisher_confirms=False)
self._channel.add_close_callback(_channel_close_callback)

log.debug("Declaring %s exchange", self.celery_config.rabbit.channels["log"])
self.logs_exchange = await self.channel.declare_exchange(
self._logs_exchange = await self._channel.declare_exchange(
self.celery_config.rabbit.channels["log"], aio_pika.ExchangeType.FANOUT
)

log.debug(
"Declaring %s exchange",
self.celery_config.rabbit.channels["instrumentation"],
)
self.instrumentation_exchange = await self.channel.declare_exchange(
self._instrumentation_exchange = await self._channel.declare_exchange(
self.celery_config.rabbit.channels["instrumentation"],
aio_pika.ExchangeType.FANOUT,
)

@log_decorator(logger=log)
async def close(self):
await self.channel.close()
await self.connection.close()
await self._channel.close()
await self._connection.close()

@log_decorator(logger=log)
async def _post_message(
self, exchange: aio_pika.Exchange, data: Dict[str, Union[str, Any]]
):
Expand All @@ -104,7 +101,7 @@ async def post_log_message(
log_msg: Union[str, List[str]],
):
await self._post_message(
self.logs_exchange,
self._logs_exchange,
data={
"Channel": "Log",
"Node": node_id,
Expand All @@ -118,7 +115,7 @@ async def post_progress_message(
self, user_id: str, project_id: str, node_id: str, progress_msg: str
):
await self._post_message(
self.logs_exchange,
self._logs_exchange,
data={
"Channel": "Progress",
"Node": node_id,
Expand All @@ -133,7 +130,7 @@ async def post_instrumentation_message(
instrumentation_data: Dict,
):
await self._post_message(
self.instrumentation_exchange,
self._instrumentation_exchange,
data=instrumentation_data,
)

Expand All @@ -146,7 +143,7 @@ async def __aexit__(self, exc_type, exc, tb):


@tenacity.retry(**RabbitMQRetryPolicyUponInitialization().kwargs)
async def _wait_till_rabbit_responsive(url: str):
async def _wait_till_rabbit_responsive(url: RabbitDsn):
connection = await aio_pika.connect(url)
await connection.close()
return True
18 changes: 18 additions & 0 deletions services/sidecar/src/simcore_service_sidecar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,21 @@ def touch_tmpfile(extension=".dat") -> Path:
"""
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as file_handler:
return Path(file_handler.name)


def cancel_task(task_name: str) -> None:
    """Cancel every asyncio task whose name equals ``task_name``.

    The set of tasks returned by ``asyncio.all_tasks()`` is logged at debug
    level, then each task whose ``get_name()`` matches is logged at warning
    level and cancelled.

    NOTE: ``asyncio.all_tasks()`` is called without a loop argument, so this
    must run where an event loop is running.

    :param task_name: name of the task(s) to cancel
    """
    running = asyncio.all_tasks()
    logger.debug("running tasks: %s", running)
    # lazily filtered, so each match is logged and cancelled before the next
    # task is examined
    named_alike = (entry for entry in running if entry.get_name() == task_name)
    for entry in named_alike:
        logger.warning("canceling task %s....................", entry)
        entry.cancel()


def cancel_task_by_fct_name(fct_name: str) -> None:
    """Cancel every asyncio task that wraps a coroutine named ``fct_name``.

    The set of tasks returned by ``asyncio.all_tasks()`` is logged at debug
    level, then each task whose coroutine ``__name__`` matches is logged at
    warning level and cancelled.

    NOTE: ``asyncio.all_tasks()`` is called without a loop argument, so this
    must run where an event loop is running.

    :param fct_name: name of the coroutine function whose tasks are cancelled
    """
    active_tasks = asyncio.all_tasks()
    logger.debug("running tasks: %s", active_tasks)
    for active in active_tasks:
        if active.get_coro().__name__ != fct_name:
            # task runs some other coroutine: leave it alone
            continue
        logger.warning("canceling task %s....................", active)
        active.cancel()
2 changes: 1 addition & 1 deletion services/sidecar/tests/unit/test_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def rabbit_message_handler(message: aio_pika.IncomingMessage):
await rabbit_queue.consume(rabbit_message_handler, exclusive=True, no_ack=True)

async with RabbitMQ() as rabbitmq:
assert rabbitmq.connection.ready
assert rabbitmq._connection.ready # pylint: disable=protected-access

await rabbitmq.post_log_message(user_id, project_id, node_id, log_msg)
await rabbitmq.post_log_message(user_id, project_id, node_id, log_messages)
Expand Down

0 comments on commit 2b1fd4a

Please sign in to comment.