Skip to content

Commit

Permalink
Fix #1131 A class with async "__call__" method fails to work as a mid…
Browse files Browse the repository at this point in the history
…dleware
  • Loading branch information
seratch committed Aug 21, 2024
1 parent 40f6d1e commit 5cf5551
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
8 changes: 4 additions & 4 deletions slack_bolt/app/async_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
AsyncMessageListenerMatches,
)
from slack_bolt.oauth.async_internals import select_consistent_installation_store
from slack_bolt.util.utils import get_name_for_callable
from slack_bolt.util.utils import get_name_for_callable, is_coroutine_function
from slack_bolt.workflows.step.async_step import (
AsyncWorkflowStep,
AsyncWorkflowStepBuilder,
Expand Down Expand Up @@ -778,7 +778,7 @@ async def custom_error_handler(error, body, logger):
func: The function that is supposed to be executed
when getting an unhandled error in Bolt app.
"""
if not inspect.iscoroutinefunction(func):
if not is_coroutine_function(func):
name = get_name_for_callable(func)
raise BoltError(error_listener_function_must_be_coro_func(name))
self._async_listener_runner.listener_error_handler = AsyncCustomListenerErrorHandler(
Expand Down Expand Up @@ -1410,7 +1410,7 @@ def _register_listener(
value_to_return = functions[0]

for func in functions:
if not inspect.iscoroutinefunction(func):
if not is_coroutine_function(func):
name = get_name_for_callable(func)
raise BoltError(error_listener_function_must_be_coro_func(name))

Expand All @@ -1422,7 +1422,7 @@ def _register_listener(
for m in middleware or []:
if isinstance(m, AsyncMiddleware):
listener_middleware.append(m)
elif isinstance(m, Callable) and inspect.iscoroutinefunction(m):
elif isinstance(m, Callable) and is_coroutine_function(m):
listener_middleware.append(AsyncCustomMiddleware(app_name=self.name, func=m, base_logger=self._base_logger))
else:
raise ValueError(error_unexpected_listener_middleware(type(m)))
Expand Down
5 changes: 2 additions & 3 deletions slack_bolt/middleware/async_custom_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
from logging import Logger
from typing import Callable, Awaitable, Any, Sequence, Optional

Expand All @@ -7,7 +6,7 @@
from slack_bolt.request.async_request import AsyncBoltRequest
from slack_bolt.response import BoltResponse
from .async_middleware import AsyncMiddleware
from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable
from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable, is_coroutine_function


class AsyncCustomMiddleware(AsyncMiddleware):
Expand All @@ -24,7 +23,7 @@ def __init__(
base_logger: Optional[Logger] = None,
):
self.app_name = app_name
if inspect.iscoroutinefunction(func):
if is_coroutine_function(func):
self.func = func
else:
raise ValueError("Async middleware function must be an async function")
Expand Down
6 changes: 6 additions & 0 deletions slack_bolt/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,9 @@ def get_name_for_callable(func: Callable) -> str:

def get_arg_names_of_callable(func: Callable) -> List[str]:
return inspect.getfullargspec(inspect.unwrap(func)).args


def is_coroutine_function(func: Optional[Any]) -> bool:
return func is not None and (
inspect.iscoroutinefunction(func) or (hasattr(func, "__call__") and inspect.iscoroutinefunction(func.__call__))
)
20 changes: 20 additions & 0 deletions tests/scenario_tests_async/test_app_using_methods_in_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ async def test_instance_methods(self):
app.shortcut("test-shortcut")(awesome.instance_method)
await self.run_app_and_verify(app)

@pytest.mark.asyncio
async def test_callable_class(self):
app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret)
instance = CallableClass("Slackbot")
app.use(instance)
app.shortcut("test-shortcut")(instance.event_handler)
await self.run_app_and_verify(app)

@pytest.mark.asyncio
async def test_instance_methods_uncommon_name_1(self):
app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret)
Expand Down Expand Up @@ -225,6 +233,18 @@ async def static_method(context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck)
await say(f"Hello <@{context.user_id}>!")


class CallableClass:
def __init__(self, name: str):
self.name = name

async def __call__(self, next: Callable):
await next()

async def event_handler(self, context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck):
await ack()
await say(f"Hello <@{context.user_id}>! My name is {self.name}")


async def top_level_function(invalid_arg, ack, say):
assert invalid_arg is None
await ack()
Expand Down

0 comments on commit 5cf5551

Please sign in to comment.