Skip to content

Commit

Permalink
Merge pull request #89 from alan-turing-institute/llama2-query-queue
Browse files Browse the repository at this point in the history
Adding queuing system
  • Loading branch information
rchan26 authored Sep 15, 2023
2 parents 9d37254 + bfb9ea9 commit 72c186d
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 24 deletions.
24 changes: 15 additions & 9 deletions slack_bot/run.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import asyncio
import logging
import os
import pathlib
import sys
import threading

from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.web import WebClient
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.web.async_client import AsyncWebClient

from slack_bot import MODELS, Bot

Expand All @@ -17,7 +17,7 @@
DEFAULT_HF_MODEL = "StabilityAI/stablelm-tuned-alpha-3b"


if __name__ == "__main__":
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -214,19 +214,25 @@
logging.error("SLACK_APP_TOKEN is not set")
sys.exit(1)

# Initialize SocketModeClient with an app-level token + WebClient
# Initialize SocketModeClient with an app-level token + AsyncWebClient
client = SocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=os.environ.get("SLACK_APP_TOKEN"),
# You will be using this WebClient for performing Web API calls in listeners
web_client=WebClient(token=os.environ.get("SLACK_BOT_TOKEN")),
# You will be using this AsyncWebClient for performing Web API calls in listeners
web_client=AsyncWebClient(token=os.environ.get("SLACK_BOT_TOKEN")),
# To ensure connection doesn't go stale - we can adjust as needed.
ping_interval=60,
)

# Add a new listener to receive messages from Slack
client.socket_mode_request_listeners.append(slack_bot)
# Establish a WebSocket connection to the Socket Mode servers
client.connect()
await client.connect()

# Listen for events
logging.info("Listening for requests...")
threading.Event().wait()
await asyncio.sleep(float("inf"))


if __name__ == "__main__":
asyncio.run(main())
59 changes: 44 additions & 15 deletions slack_bot/slack_bot/bot/bot.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,53 @@
import asyncio
import logging

from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.listeners import SocketModeRequestListener
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_listeners import AsyncSocketModeRequestListener
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse

from ..models.base import ResponseModel


class Bot(SocketModeRequestListener):
class Bot(AsyncSocketModeRequestListener):
def __init__(self, model: ResponseModel) -> None:
self.model = model

def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None:
self.queue = asyncio.Queue(maxsize=10)

async def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None:
self.queue.put_nowait(self._process_request(client, req))
logging.info(f"There are currently {self.queue.qsize()} items in the queue.")

# Create three worker tasks to process the queue concurrently.
tasks = []
for i in range(3):
task = asyncio.create_task(self.worker(self.queue))
tasks.append(task)

# await self.queue.join()

# for task in tasks:
# task.cancel()

@staticmethod
async def worker(queue):
while True:
coro = await queue.get()
await coro
# Notify the queue that the "work item" has been processed.
queue.task_done()

async def _process_request(
self, client: SocketModeClient, req: SocketModeRequest
) -> None:
if req.type != "events_api":
logging.info(f"Received unexpected request of type '{req.type}'")
return None

# Acknowledge the request
logging.info(f"Received an events_api request")
logging.info("Received an events_api request")
response = SocketModeResponse(envelope_id=req.envelope_id)
client.send_socket_mode_response(response)
await client.send_socket_mode_response(response)

try:
# Extract event from payload
Expand All @@ -45,28 +72,28 @@ def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None:

# If this is a direct message to REGinald...
if event_type == "message" and event_subtype is None:
self.react(client, event["channel"], event["ts"])
await self.react(client, event["channel"], event["ts"])
model_response = self.model.direct_message(message, user_id)

# If @REGinald is mentioned in a channel
elif event_type == "app_mention":
self.react(client, event["channel"], event["ts"])
await self.react(client, event["channel"], event["ts"])
model_response = self.model.channel_mention(message, user_id)

# Otherwise
else:
logging.info(f"Received unexpected event of type '{event['type']}'.")
return None

# Add an emoji and a reply as required
# Add a reply as required
if model_response and model_response.message:
logging.info(f"Posting reply {model_response.message}.")
client.web_client.chat_postMessage(
await client.web_client.chat_postMessage(
channel=event["channel"],
text=f"<@{user_id}>, you asked me: '{message}'.\n{model_response.message}",
)
else:
logging.info(f"No reply was generated.")
logging.info("No reply was generated.")

except KeyError as exc:
logging.warning(f"Attempted to access key that does not exist.\n{str(exc)}")
Expand All @@ -77,14 +104,16 @@ def __call__(self, client: SocketModeClient, req: SocketModeRequest) -> None:
)
raise

def react(self, client: SocketModeClient, channel: str, timestamp: str) -> None:
async def react(
self, client: SocketModeClient, channel: str, timestamp: str
) -> None:
"""Emoji react to the input message"""
if self.model.emoji:
logging.info(f"Reacting with emoji {self.model.emoji}.")
client.web_client.reactions_add(
await client.web_client.reactions_add(
name=self.model.emoji,
channel=channel,
timestamp=timestamp,
)
else:
logging.info(f"No emoji defined for this model.")
logging.info("No emoji defined for this model.")
4 changes: 4 additions & 0 deletions slack_bot/slack_bot/models/hello.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

from .base import MessageResponse, ResponseModel


Expand All @@ -6,7 +8,9 @@ def __init__(self):
super().__init__(emoji="wave")

def direct_message(self, message: str, user_id: str) -> MessageResponse:
time.sleep(5)
return MessageResponse("Let's discuss this in a channel!")

def channel_mention(self, message: str, user_id: str) -> MessageResponse:
time.sleep(5)
return MessageResponse(f"Hello <@{user_id}>")

0 comments on commit 72c186d

Please sign in to comment.