Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inspect tasks instead of assuming call signatures and return types #121

Merged
merged 2 commits into from
May 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions src/ezmsg/core/backendprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,32 +301,28 @@ async def perf_publish(stream: Stream, obj: Any) -> None:
)

pub_fn = perf_publish if hasattr(task, TIMEIT_ATTR) else publish

call_fn = lambda _: task(unit)
signature = inspect.signature(task)
if len(signature.parameters) == 1:
call_fn = lambda _: task(unit)
elif len(signature.parameters) == 2:
call_fn = lambda msg: task(unit, msg)
else:
logger.error(f'Incompatible call signature on task: {task.__name__}')

@wraps(task)
async def wrapped_task(msg: Any = None) -> None:
try:
# If we don't sub or pub anything, we are a simple task
if not hasattr(task, SUBSCRIBES_ATTR) and not hasattr(
task, PUBLISHES_ATTR
):
await task(unit)

# No subscriptions; only publications...
elif not hasattr(task, SUBSCRIBES_ATTR):
async for stream, obj in task(unit):
result = call_fn(msg)
if inspect.isasyncgen(result):
async for stream, obj in result:
if obj and getattr(task, ZERO_COPY_ATTR, False) and obj is msg:
obj = deepcopy(obj)
await pub_fn(stream, obj)

# Subscribers need to be called with a message
else:
if not getattr(task, ZERO_COPY_ATTR):
msg = deepcopy(msg)
if hasattr(task, PUBLISHES_ATTR):
async for stream, obj in task(unit, msg):
if getattr(task, ZERO_COPY_ATTR) and obj is msg:
obj = deepcopy(obj)
await pub_fn(stream, obj)
else:
await task(unit, msg)
elif asyncio.iscoroutine(result):
await result

except Complete:
logger.info(f"{task_address} Complete")
Expand Down