diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 9708097..e185cc2 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -301,34 +301,27 @@ async def perf_publish(stream: Stream, obj: Any) -> None: ) pub_fn = perf_publish if hasattr(task, TIMEIT_ATTR) else publish + + call_fn = lambda _: task(unit) + signature = inspect.signature(task) + if len(signature.parameters) == 1: + call_fn = lambda _: task(unit) + elif len(signature.parameters) == 2: + call_fn = lambda msg: task(unit, msg) + else: + logger.error(f'Incompatible call signature on task: {task.__name__}') @wraps(task) async def wrapped_task(msg: Any = None) -> None: try: - - result = None - signature = inspect.signature(task) - if len(signature.parameters) == 1: - # Task does not accept incoming messages - result = task(unit) - elif len(signature.parameters) == 2: - # Task requires an incoming message - if not getattr(task, ZERO_COPY_ATTR, False): - msg = deepcopy(msg) - result = task(unit, msg) - else: - logger.error(f'Incompatible call signature on task: {task.__name__}') - + result = call_fn(msg) if inspect.isasyncgen(result): - # Task returned an async generator - # it must want to publish stuff async for stream, obj in result: if obj and getattr(task, ZERO_COPY_ATTR, False) and obj is msg: obj = deepcopy(obj) await pub_fn(stream, obj) elif asyncio.iscoroutine(result): - # Task returned a simple coroutine await result except Complete: