diff --git a/changelog/4811.improvement.rst b/changelog/4811.improvement.rst new file mode 100644 index 000000000000..344f6cd7b10f --- /dev/null +++ b/changelog/4811.improvement.rst @@ -0,0 +1 @@ +Support invoking a ``SlackBot`` by direct messaging or ``@`` mentions. \ No newline at end of file diff --git a/docs/user-guide/connectors/slack.rst b/docs/user-guide/connectors/slack.rst index 9bd828cb4d68..fe2ccf4ff001 100644 --- a/docs/user-guide/connectors/slack.rst +++ b/docs/user-guide/connectors/slack.rst @@ -53,9 +53,9 @@ e.g. using: You need to supply a ``credentials.yml`` with the following content: -- The ``slack_channel`` is the target your bot posts to. - This can be a channel or an individual person. You can leave out - the argument to post DMs to the bot. +- The ``slack_channel`` can be a channel or an individual person that the bot should listen to for communications, in + addition to the default behavior of listening for direct messages and app mentions, i.e. "@app_name". + - Use the entry for ``Bot User OAuth Access Token`` in the "OAuth & Permissions" tab as your ``slack_token``. It should start @@ -75,4 +75,4 @@ The endpoint for receiving slack messages is ``http://localhost:5005/webhooks/slack/webhook``, replacing the host and port with the appropriate values. This is the URL you should add in the "OAuth & Permissions" section as well as -the "Event Subscriptions". +the "Event Subscriptions". \ No newline at end of file diff --git a/rasa/core/channels/slack.py b/rasa/core/channels/slack.py index 73bf823ef2bc..ac97f94b7e87 100644 --- a/rasa/core/channels/slack.py +++ b/rasa/core/channels/slack.py @@ -169,6 +169,20 @@ def __init__( self.retry_reason_header = slack_retry_reason_header self.retry_num_header = slack_retry_number_header + @staticmethod + def _is_app_mention(slack_event: Dict) -> bool: + try: + return slack_event["event"]["type"] == "app_mention" + except KeyError: + return False + + @staticmethod + def _is_direct_message(slack_event: Dict) -> bool: + try: + return slack_event["event"]["channel_type"] == "im" + except KeyError: + return False + @staticmethod def _is_user_message(slack_event: Dict) -> bool: return ( @@ -293,11 +307,15 @@ async def process_message( return response.text(None, status=201, headers={"X-Slack-No-Retry": 1}) + if metadata is not None: + output_channel = metadata.get("out_channel") + else: + output_channel = None + try: - out_channel = self.get_output_channel() user_msg = UserMessage( text, - out_channel, + self.get_output_channel(output_channel), sender_id, input_channel=self.name(), metadata=metadata, @@ -310,6 +328,24 @@ async def process_message( return response.text("") + def get_metadata(self, request: Request) -> Dict[Text, Any]: + """Extracts the metadata from a slack API event (https://api.slack.com/types/event). + + Args: + request: A `Request` object that contains a slack API event in the body. + + Returns: + Metadata extracted from the sent event payload. This includes the output channel for the response, + and users that have installed the bot. + """ + slack_event = request.json + event = slack_event.get("event", {}) + + return { + "out_channel": event.get("channel"), + "users": slack_event.get("authed_users"), + } + def blueprint( self, on_new_message: Callable[[UserMessage], Awaitable[Any]] ) -> Blueprint: @@ -342,24 +378,45 @@ async def webhook(request: Request) -> HTTPResponse: elif request.json: output = request.json + event = output.get("event", {}) + user_message = event.get("text", "") + sender_id = event.get("user", "") + metadata = self.get_metadata(request) + if "challenge" in output: return response.json(output.get("challenge")) - elif self._is_user_message(output): - metadata = self.get_metadata(request) + elif self._is_user_message(output) and self._is_supported_channel( + output, metadata + ): return await self.process_message( request, on_new_message, - self._sanitize_user_message( - output["event"]["text"], output["authed_users"] + text=self._sanitize_user_message( + user_message, metadata["users"] ), - output.get("event").get("user"), - metadata, + sender_id=sender_id, + metadata=metadata, + ) + else: + logger.warning( + f"Received message on unsupported channel: {metadata['out_channel']}" ) - return response.text("Bot message delivered") + return response.text("Bot message delivered.") return slack_webhook - def get_output_channel(self) -> OutputChannel: - return SlackBot(self.slack_token, self.slack_channel) + def _is_supported_channel(self, slack_event: Dict, metadata: Dict) -> bool: + return ( + self._is_direct_message(slack_event) + or self._is_app_mention(slack_event) + or metadata["out_channel"] == self.slack_channel + ) + + def get_output_channel(self, channel: Optional[Text] = None) -> OutputChannel: + channel = channel or self.slack_channel + return SlackBot(self.slack_token, channel) + + def set_output_channel(self, channel: Text) -> None: + self.slack_channel = channel diff --git a/tests/core/test_channels.py b/tests/core/test_channels.py index 3473dfdbfb4e..32f61521e4d9 100644 --- a/tests/core/test_channels.py +++ b/tests/core/test_channels.py @@ -2,7 +2,7 @@ import logging import urllib.parse from typing import Dict -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock import pytest import responses @@ -461,6 +461,92 @@ def test_botframework_attachments(): assert ch.add_attachments_to_metadata(payload, metadata) == updated_metadata +def test_slack_metadata(): + from rasa.core.channels.slack import SlackInput + from sanic.request import Request + + user = "user1" + channel = "channel1" + authed_users = ["XXXXXXX", "YYYYYYY", "ZZZZZZZ"] + direct_message_event = { + "authed_users": authed_users, + "event": { + "client_msg_id": "XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX", + "type": "message", + "text": "hello world", + "user": user, + "ts": "1579802617.000800", + "team": "XXXXXXXXX", + "blocks": [ + { + "type": "rich_text", + "block_id": "XXXXX", + "elements": [ + { + "type": "rich_text_section", + "elements": [{"type": "text", "text": "hi"}], + } + ], + } + ], + "channel": channel, + "event_ts": "1579802617.000800", + "channel_type": "im", + }, + } + + input_channel = SlackInput( + slack_token="YOUR_SLACK_TOKEN", slack_channel="YOUR_SLACK_CHANNEL" + ) + + r = Mock() + r.json = direct_message_event + metadata = input_channel.get_metadata(request=r) + assert metadata["out_channel"] == channel + assert metadata["users"] == authed_users + + +def test_slack_metadata_missing_keys(): + from rasa.core.channels.slack import SlackInput + from sanic.request import Request + + channel = "channel1" + direct_message_event = { + "event": { + "client_msg_id": "XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX", + "type": "message", + "text": "hello world", + "ts": "1579802617.000800", + "team": "XXXXXXXXX", + "blocks": [ + { + "type": "rich_text", + "block_id": "XXXXX", + "elements": [ + { + "type": "rich_text_section", + "elements": [{"type": "text", "text": "hi"}], + } + ], + } + ], + "channel": channel, + "event_ts": "1579802617.000800", + "channel_type": "im", + }, + } + + input_channel = SlackInput( + slack_token="YOUR_SLACK_TOKEN", slack_channel="YOUR_SLACK_CHANNEL" + ) + + r = Mock() + r.json = direct_message_event + metadata = input_channel.get_metadata(request=r) + assert metadata["users"] is None + assert metadata["out_channel"] == channel + + def test_slack_message_sanitization(): from rasa.core.channels.slack import SlackInput