diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 3a367ae..9708097 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -305,28 +305,31 @@ async def perf_publish(stream: Stream, obj: Any) -> None: @wraps(task) async def wrapped_task(msg: Any = None) -> None: try: - # If we don't sub or pub anything, we are a simple task - if not hasattr(task, SUBSCRIBES_ATTR) and not hasattr( - task, PUBLISHES_ATTR - ): - await task(unit) - - # No subscriptions; only publications... - elif not hasattr(task, SUBSCRIBES_ATTR): - async for stream, obj in task(unit): - await pub_fn(stream, obj) - # Subscribers need to be called with a message - else: - if not getattr(task, ZERO_COPY_ATTR): + result = None + signature = inspect.signature(task) + if len(signature.parameters) == 1: + # Task does not accept incoming messages + result = task(unit) + elif len(signature.parameters) == 2: + # Task requires an incoming message + if not getattr(task, ZERO_COPY_ATTR, False): msg = deepcopy(msg) - if hasattr(task, PUBLISHES_ATTR): - async for stream, obj in task(unit, msg): - if getattr(task, ZERO_COPY_ATTR) and obj is msg: - obj = deepcopy(obj) - await pub_fn(stream, obj) - else: - await task(unit, msg) + result = task(unit, msg) + else: + logger.error(f'Incompatible call signature on task: {task.__name__}') + + if inspect.isasyncgen(result): + # Task returned an async generator + # it must want to publish stuff + async for stream, obj in result: + if obj and getattr(task, ZERO_COPY_ATTR, False) and obj is msg: + obj = deepcopy(obj) + await pub_fn(stream, obj) + + elif asyncio.iscoroutine(result): + # Task returned a simple coroutine + await result except Complete: logger.info(f"{task_address} Complete")