diff --git a/slack_bot/run.py b/slack_bot/run.py index eed1e9ef..4475e5de 100755 --- a/slack_bot/run.py +++ b/slack_bot/run.py @@ -1,12 +1,12 @@ import argparse +import asyncio import logging import os import pathlib import sys -import threading -from slack_sdk.socket_mode import SocketModeClient -from slack_sdk.web import WebClient +from slack_sdk.socket_mode.aiohttp import SocketModeClient +from slack_sdk.web.async_client import AsyncWebClient from slack_bot import MODELS, Bot @@ -17,7 +17,7 @@ DEFAULT_HF_MODEL = "StabilityAI/stablelm-tuned-alpha-3b" -if __name__ == "__main__": +async def main(): # Parse command line arguments parser = argparse.ArgumentParser() parser.add_argument( @@ -214,19 +214,25 @@ logging.error("SLACK_APP_TOKEN is not set") sys.exit(1) - # Initialize SocketModeClient with an app-level token + WebClient + # Initialize SocketModeClient with an app-level token + AsyncWebClient client = SocketModeClient( # This app-level token will be used only for establishing a connection app_token=os.environ.get("SLACK_APP_TOKEN"), - # You will be using this WebClient for performing Web API calls in listeners - web_client=WebClient(token=os.environ.get("SLACK_BOT_TOKEN")), + # You will be using this AsyncWebClient for performing Web API calls in listeners + web_client=AsyncWebClient(token=os.environ.get("SLACK_BOT_TOKEN")), + # To ensure connection doesn't go stale - we can adjust as needed. + ping_interval=60, ) # Add a new listener to receive messages from Slack client.socket_mode_request_listeners.append(slack_bot) # Establish a WebSocket connection to the Socket Mode servers - client.connect() + await client.connect() # Listen for events logging.info("Listening for requests...") - threading.Event().wait() + await asyncio.sleep(float("inf")) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/slack_bot/slack_bot/bot/bot.py b/slack_bot/slack_bot/bot/bot.py index c4934e5b..f1aedde9 100644 --- a/slack_bot/slack_bot/bot/bot.py +++ b/slack_bot/slack_bot/bot/bot.py @@ -1,26 +1,53 @@ +import asyncio import logging -from slack_sdk.socket_mode import SocketModeClient -from slack_sdk.socket_mode.listeners import SocketModeRequestListener +from slack_sdk.socket_mode.aiohttp import SocketModeClient +from slack_sdk.socket_mode.async_listeners import AsyncSocketModeRequestListener from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse from ..models.base import ResponseModel -class Bot(SocketModeRequestListener): +class Bot(AsyncSocketModeRequestListener): def __init__(self, model: ResponseModel) -> None: self.model = model - - def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None: + self.queue = asyncio.Queue(maxsize=10) + + async def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None: + self.queue.put_nowait(self._process_request(client, req)) + logging.info(f"There are currently {self.queue.qsize()} items in the queue.") + + # Create three worker tasks to process the queue concurrently. + tasks = [] + for i in range(3): + task = asyncio.create_task(self.worker(self.queue)) + tasks.append(task) + + # await self.queue.join() + + # for task in tasks: + # task.cancel() + + @staticmethod + async def worker(queue): + while True: + coro = await queue.get() + await coro + # Notify the queue that the "work item" has been processed. + queue.task_done() + + async def _process_request( + self, client: SocketModeClient, req: SocketModeRequest + ) -> None: if req.type != "events_api": logging.info(f"Received unexpected request of type '{req.type}'") return None # Acknowledge the request - logging.info(f"Received an events_api request") + logging.info("Received an events_api request") response = SocketModeResponse(envelope_id=req.envelope_id) - client.send_socket_mode_response(response) + await client.send_socket_mode_response(response) try: # Extract event from payload @@ -45,12 +72,12 @@ def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None: # If this is a direct message to REGinald... if event_type == "message" and event_subtype is None: - self.react(client, event["channel"], event["ts"]) + await self.react(client, event["channel"], event["ts"]) model_response = self.model.direct_message(message, user_id) # If @REGinald is mentioned in a channel elif event_type == "app_mention": - self.react(client, event["channel"], event["ts"]) + await self.react(client, event["channel"], event["ts"]) model_response = self.model.channel_mention(message, user_id) # Otherwise @@ -58,15 +85,15 @@ def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None: logging.info(f"Received unexpected event of type '{event['type']}'.") return None - # Add an emoji and a reply as required + # Add a reply as required if model_response and model_response.message: logging.info(f"Posting reply {model_response.message}.") - client.web_client.chat_postMessage( + await client.web_client.chat_postMessage( channel=event["channel"], text=f"<@{user_id}>, you asked me: '{message}'.\n{model_response.message}", ) else: - logging.info(f"No reply was generated.") + logging.info("No reply was generated.") except KeyError as exc: logging.warning(f"Attempted to access key that does not exist.\n{str(exc)}") @@ -77,14 +104,16 @@ def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None: ) raise - def react(self, client: SocketModeClient, channel: str, timestamp: str) -> None: + async def react( + self, client: SocketModeClient, channel: str, timestamp: str + ) -> None: """Emoji react to the input message""" if self.model.emoji: logging.info(f"Reacting with emoji {self.model.emoji}.") - client.web_client.reactions_add( + await client.web_client.reactions_add( name=self.model.emoji, channel=channel, timestamp=timestamp, ) else: - logging.info(f"No emoji defined for this model.") + logging.info("No emoji defined for this model.") diff --git a/slack_bot/slack_bot/models/hello.py b/slack_bot/slack_bot/models/hello.py index 9c9e54fa..85e91fef 100644 --- a/slack_bot/slack_bot/models/hello.py +++ b/slack_bot/slack_bot/models/hello.py @@ -1,3 +1,5 @@ +import time + from .base import MessageResponse, ResponseModel @@ -6,7 +8,9 @@ def __init__(self): super().__init__(emoji="wave") def direct_message(self, message: str, user_id: str) -> MessageResponse: + time.sleep(5) return MessageResponse("Let's discuss this in a channel!") def channel_mention(self, message: str, user_id: str) -> MessageResponse: + time.sleep(5) return MessageResponse(f"Hello <@{user_id}>")