Skip to content

Commit

Permalink
inspect instead of assume
Browse files Browse the repository at this point in the history
  • Loading branch information
Griffin Milsap committed May 10, 2024
1 parent 2682137 commit 1163e5d
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions src/ezmsg/core/backendprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,28 +305,31 @@ async def perf_publish(stream: Stream, obj: Any) -> None:
@wraps(task)
async def wrapped_task(msg: Any = None) -> None:
try:
# If we don't sub or pub anything, we are a simple task
if not hasattr(task, SUBSCRIBES_ATTR) and not hasattr(
task, PUBLISHES_ATTR
):
await task(unit)

# No subscriptions; only publications...
elif not hasattr(task, SUBSCRIBES_ATTR):
async for stream, obj in task(unit):
await pub_fn(stream, obj)

# Subscribers need to be called with a message
else:
if not getattr(task, ZERO_COPY_ATTR):
result = None
signature = inspect.signature(task)
if len(signature.parameters) == 1:
# Task does not accept incoming messages
result = task(unit)
elif len(signature.parameters) == 2:
# Task requires an incoming message
if not getattr(task, ZERO_COPY_ATTR, False):
msg = deepcopy(msg)
if hasattr(task, PUBLISHES_ATTR):
async for stream, obj in task(unit, msg):
if getattr(task, ZERO_COPY_ATTR) and obj is msg:
obj = deepcopy(obj)
await pub_fn(stream, obj)
else:
await task(unit, msg)
result = task(unit, msg)
else:
logger.error(f'Incompatible call signature on task: {task.__name__}')

if inspect.isasyncgen(result):
# Task returned an async generator
# it must want to publish stuff
async for stream, obj in result:
if obj and getattr(task, ZERO_COPY_ATTR, False) and obj is msg:
obj = deepcopy(obj)
await pub_fn(stream, obj)

elif asyncio.iscoroutine(result):
# Task returned a simple coroutine
await result

except Complete:
logger.info(f"{task_address} Complete")
Expand Down

0 comments on commit 1163e5d

Please sign in to comment.