diff --git a/slack_bolt/app/async_app.py b/slack_bolt/app/async_app.py index 7e984c5d9..89a5124a3 100644 --- a/slack_bolt/app/async_app.py +++ b/slack_bolt/app/async_app.py @@ -24,7 +24,7 @@ AsyncMessageListenerMatches, ) from slack_bolt.oauth.async_internals import select_consistent_installation_store -from slack_bolt.util.utils import get_name_for_callable +from slack_bolt.util.utils import get_name_for_callable, is_coroutine_function from slack_bolt.workflows.step.async_step import ( AsyncWorkflowStep, AsyncWorkflowStepBuilder, @@ -778,7 +778,7 @@ async def custom_error_handler(error, body, logger): func: The function that is supposed to be executed when getting an unhandled error in Bolt app. """ - if not inspect.iscoroutinefunction(func): + if not is_coroutine_function(func): name = get_name_for_callable(func) raise BoltError(error_listener_function_must_be_coro_func(name)) self._async_listener_runner.listener_error_handler = AsyncCustomListenerErrorHandler( @@ -1410,7 +1410,7 @@ def _register_listener( value_to_return = functions[0] for func in functions: - if not inspect.iscoroutinefunction(func): + if not is_coroutine_function(func): name = get_name_for_callable(func) raise BoltError(error_listener_function_must_be_coro_func(name)) @@ -1422,7 +1422,7 @@ def _register_listener( for m in middleware or []: if isinstance(m, AsyncMiddleware): listener_middleware.append(m) - elif isinstance(m, Callable) and inspect.iscoroutinefunction(m): + elif isinstance(m, Callable) and is_coroutine_function(m): listener_middleware.append(AsyncCustomMiddleware(app_name=self.name, func=m, base_logger=self._base_logger)) else: raise ValueError(error_unexpected_listener_middleware(type(m))) diff --git a/slack_bolt/middleware/async_custom_middleware.py b/slack_bolt/middleware/async_custom_middleware.py index e2060b75c..a8f2a0f9d 100644 --- a/slack_bolt/middleware/async_custom_middleware.py +++ b/slack_bolt/middleware/async_custom_middleware.py @@ -1,4 +1,3 @@ -import inspect from logging import Logger from typing import Callable, Awaitable, Any, Sequence, Optional @@ -7,7 +6,7 @@ from slack_bolt.request.async_request import AsyncBoltRequest from slack_bolt.response import BoltResponse from .async_middleware import AsyncMiddleware -from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable +from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable, is_coroutine_function class AsyncCustomMiddleware(AsyncMiddleware): @@ -24,7 +23,7 @@ def __init__( base_logger: Optional[Logger] = None, ): self.app_name = app_name - if inspect.iscoroutinefunction(func): + if is_coroutine_function(func): self.func = func else: raise ValueError("Async middleware function must be an async function") diff --git a/slack_bolt/util/utils.py b/slack_bolt/util/utils.py index efb815399..a5bcdbe5f 100644 --- a/slack_bolt/util/utils.py +++ b/slack_bolt/util/utils.py @@ -88,3 +88,9 @@ def get_name_for_callable(func: Callable) -> str: def get_arg_names_of_callable(func: Callable) -> List[str]: return inspect.getfullargspec(inspect.unwrap(func)).args + + +def is_coroutine_function(func: Optional[Any]) -> bool: + return func is not None and ( + inspect.iscoroutinefunction(func) or (hasattr(func, "__call__") and inspect.iscoroutinefunction(func.__call__)) + ) diff --git a/tests/scenario_tests_async/test_app_using_methods_in_class.py b/tests/scenario_tests_async/test_app_using_methods_in_class.py index 806bdf7f1..a24fe9528 100644 --- a/tests/scenario_tests_async/test_app_using_methods_in_class.py +++ b/tests/scenario_tests_async/test_app_using_methods_in_class.py @@ -149,6 +149,14 @@ async def test_instance_methods(self): app.shortcut("test-shortcut")(awesome.instance_method) await self.run_app_and_verify(app) + @pytest.mark.asyncio + async def test_callable_class(self): + app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret) + instance = CallableClass("Slackbot") + app.use(instance) + app.shortcut("test-shortcut")(instance.event_handler) + await self.run_app_and_verify(app) + @pytest.mark.asyncio async def test_instance_methods_uncommon_name_1(self): app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret) @@ -225,6 +233,18 @@ async def static_method(context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck) await say(f"Hello <@{context.user_id}>!") +class CallableClass: + def __init__(self, name: str): + self.name = name + + async def __call__(self, next: Callable): + await next() + + async def event_handler(self, context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck): + await ack() + await say(f"Hello <@{context.user_id}>! My name is {self.name}") + + async def top_level_function(invalid_arg, ack, say): assert invalid_arg is None await ack()