Skip to content

Commit

Permalink
Fix/channel join (#44)
Browse files Browse the repository at this point in the history
* Populate channels in `initial_channels` of each bot instance

* Increase max user per bot instance
Remove RetryableWSConnection
Add "testing" guards for handling requests
Remove irc package from requirements.txt
  • Loading branch information
aticie authored Jul 5, 2022
1 parent 1bf666e commit 6e40ff3
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 106 deletions.
30 changes: 16 additions & 14 deletions ronnia/bots/bot_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,15 @@ def get_streams(self, user_ids: List[int]):


class TwitchProcess(Process):
def __init__(self, user_list: List[int], join_lock: Lock):
def __init__(self, user_list: List[str], join_lock: Lock, max_users: int):
super().__init__()
self.join_lock = join_lock
self.user_list = user_list
self.max_users = max_users
self.bot = None

def initialize(self):
self.bot = TwitchBot(initial_channel_ids=self.user_list, join_lock=self.join_lock)
self.bot = TwitchBot(initial_channel_names=self.user_list, join_lock=self.join_lock, max_users=self.max_users)

def run(self) -> None:
self.initialize()
Expand All @@ -84,8 +85,8 @@ def __init__(self, ):
self.twitch_client = TwitchAPI(os.getenv('TWITCH_CLIENT_ID'), os.getenv('TWITCH_CLIENT_SECRET'))
self._loop = asyncio.get_event_loop()

self.user_per_instance = 100
self.sleep_after_instance = (self.user_per_instance // 20 + 1) * 10
self.user_per_instance = 150
self.sleep_after_instance = (self.user_per_instance // 20 + 1) * 11

self.servicebus_connection_string = os.getenv('SERVICE_BUS_CONNECTION_STR')
self.servicebus_webserver_queue_name = 'webserver-signups'
Expand Down Expand Up @@ -120,19 +121,20 @@ def start(self):
self._loop.run_until_complete(self.initialize_queues())
logger.info("Queues initialized")
all_users = self.users_db.execute('SELECT * FROM users;').fetchall()
all_user_twitch_names = [user[2] for user in all_users]
all_user_twitch_ids = [user[4] for user in all_users]
streaming_user_ids = [user['user_id'] for user in self.twitch_client.get_streams(all_user_twitch_ids)]
streaming_user_names = [user['user_login'] for user in self.twitch_client.get_streams(all_user_twitch_ids)]

for user_id in all_user_twitch_ids:
if user_id not in streaming_user_ids:
streaming_user_ids.append(user_id)
for user_id in all_user_twitch_names:
if user_id not in streaming_user_names:
streaming_user_names.append(user_id)

logger.info(f"Collected users: {len(streaming_user_ids)}")
for user_id_list in batcher(streaming_user_ids, self.user_per_instance):
p = TwitchProcess(user_id_list, self.join_lock)
logger.info(f"Collected users: {len(streaming_user_names)}")
for user_names_list in batcher(streaming_user_names, self.user_per_instance):
p = TwitchProcess(user_list=user_names_list, join_lock=self.join_lock, max_users=self.user_per_instance)
p.start()
logger.info(f"Started Twitch bot instance for {len(user_id_list)} users")
self.bot_processes[p] = user_id_list
logger.info(f"Started Twitch bot instance for {len(user_names_list)} users")
self.bot_processes[p] = user_names_list
# 20 join rate per 10 seconds
time.sleep(self.sleep_after_instance)

Expand Down Expand Up @@ -231,7 +233,7 @@ async def parse_and_send_message(self, message):
"""
if self.create_new_instance:
logger.info(f"Started a new bot instance. This is the {len(self.bot_processes)}th instance.")
p = TwitchProcess([], self.join_lock)
p = TwitchProcess([], self.join_lock, max_users=self.user_per_instance)
p.start()
self.create_new_instance = False

Expand Down
88 changes: 22 additions & 66 deletions ronnia/bots/twitch_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import os
import sqlite3
import time
import traceback
from abc import ABC
from multiprocessing import Lock
Expand All @@ -21,15 +20,14 @@
from ronnia.helpers.beatmap_link_parser import parse_beatmap_link
from ronnia.helpers.database_helper import UserDatabase, StatisticsDatabase
from ronnia.helpers.utils import convert_seconds_to_readable
from websocket.ws import RetryableWSConnection

logger = logging.getLogger(__name__)


class TwitchBot(commands.Bot, ABC):
PER_REQUEST_COOLDOWN = 30 # each request has 30 seconds cooldown

def __init__(self, initial_channel_ids: List[int], join_lock: Lock):
def __init__(self, initial_channel_names: List[str], join_lock: Lock, max_users: int):
self.users_db = UserDatabase()
self.messages_db = StatisticsDatabase()
self.osu_api = OsuApiV2(os.getenv('OSU_CLIENT_ID'), os.getenv('OSU_CLIENT_SECRET'))
Expand All @@ -42,24 +40,13 @@ def __init__(self, initial_channel_ids: List[int], join_lock: Lock):
'client_id': os.getenv('TWITCH_CLIENT_ID'),
'client_secret': os.getenv('TWITCH_CLIENT_SECRET'),
'prefix': os.getenv('BOT_PREFIX'),
'heartbeat': 20
'initial_channels': [os.getenv('BOT_NICK'), *initial_channel_names]
}
logger.debug(f'Sending args to super().__init__: {args}')
super().__init__(**args)

conn_args = {
'token': token,
'initial_channels': [os.getenv('BOT_NICK')],
'heartbeat': 30
}
self._connection = RetryableWSConnection(
client=self,
loop=self.loop,
**conn_args
)

self.initial_channel_names = initial_channel_names
self.environment = os.getenv('ENVIRONMENT')
self.connected_channel_ids = initial_channel_ids
self.servicebus_connection_string = os.getenv('SERVICE_BUS_CONNECTION_STR')
self.servicebus_client = ServiceBusClient.from_connection_string(conn_str=self.servicebus_connection_string)
self.signup_queue_name = 'bot-signups'
Expand All @@ -70,19 +57,14 @@ def __init__(self, initial_channel_ids: List[int], join_lock: Lock):
self.main_prefix = None
self.user_last_request = {}

self.join_channels_first_time = True
self.max_users = 100

async def join_channels(self, channels: Union[List[str], Tuple[str]]):
with self._join_lock:
await super(TwitchBot, self).join_channels(channels)
self.max_users = max_users

async def servicebus_message_receiver(self):
"""
Start a queue listener for messages from the website sign-up.
"""
# Each instance of bot can only have one 50 users.
if len(self.connected_channel_ids) == self.max_users:
# Each instance of bot can only have 100 users.
if len(self.initial_channel_names) >= self.max_users:
logger.info(f'Reached {self.max_users} members, stopped listening to sign-up queue.')
return

Expand All @@ -102,7 +84,7 @@ async def servicebus_message_receiver(self):
logger.info(f'Sending reply message to sign-up queue: {reply_message}')
await sender.send_messages(reply_message)

if len(self.connected_channel_ids) == 100:
if len(self.initial_channel_names) == 100:
logger.warning(
'Reached 100 members, sending manager signal to create a new process.')
bot_full_message = ServiceBusMessage("bot-full")
Expand All @@ -128,7 +110,7 @@ async def receive_and_parse_message(self, message):
osu_id = message_dict['osu_id']
twitch_id = message_dict['twitch_id']

self.connected_channel_ids.append(twitch_id)
self.initial_channel_names.append(twitch_username)
await self.users_db.add_user(twitch_username=twitch_username,
twitch_id=twitch_id,
osu_username=osu_username,
Expand Down Expand Up @@ -163,6 +145,9 @@ async def event_message(self, message: Message):
return
logger.info(f"{message.channel.name} - {message.author.name}: {message.content}")

if self.environment == "testing":
return

await self.handle_commands(message)
try:
await self.check_channel_enabled(message.channel.name)
Expand All @@ -187,10 +172,10 @@ async def handle_request(self, message: Message):
beatmap_info=beatmap_info,
beatmapset_info=beatmapset_info)

await self._send_beatmap_to_irc(message=message,
beatmap_info=beatmap_info,
beatmapset_info=beatmapset_info,
given_mods=given_mods)
await self._send_beatmap_to_in_game(message=message,
beatmap_info=beatmap_info,
beatmapset_info=beatmapset_info,
given_mods=given_mods)
await self.messages_db.add_request(requested_beatmap_id=int(beatmap_info['id']),
requested_channel_name=message.channel.name,
requester_channel_name=message.author.name,
Expand Down Expand Up @@ -326,7 +311,8 @@ async def _prune_cooldowns(self, time_right_now: datetime.datetime):

return

async def _send_beatmap_to_irc(self, message: Message, beatmap_info: dict, beatmapset_info: dict, given_mods: str):
async def _send_beatmap_to_in_game(self, message: Message, beatmap_info: dict, beatmapset_info: dict,
given_mods: str):
"""
Sends the beatmap request message to osu!irc bot
:param message: Twitch Message object
Expand Down Expand Up @@ -419,17 +405,7 @@ async def event_ready(self):
await self.messages_db.initialize()

logger.debug(f'Successfully initialized databases!')

logger.debug(f'Populating users: {self.connected_channel_ids}')
channel_names = await self.fetch_users(ids=self.connected_channel_ids)
channels_to_join = [ch.name for ch in channel_names]

logger.info(f'Joining channels: {channels_to_join}')
# Join channels
channel_join_start = time.time()
await self.join_channels(channels_to_join)

logger.info(f'Joined {len(self.connected_channels)} after {time.time() - channel_join_start:.2f}s')
logger.debug(f'Started bot instance with: {self.initial_channel_names}')
logger.info(f'Connected channels: {self.connected_channels}')

initial_extensions = ['cogs.admin_cog']
Expand All @@ -440,44 +416,24 @@ async def event_ready(self):
self.loop.create_task(self.servicebus_message_receiver())
self.routine_update_user_information.start(stop_on_error=False)
self.routine_show_connected_channels.start(stop_on_error=False)
self.routine_join_channels.start(stop_on_error=False)

logger.info(f'Successfully initialized bot!')
logger.info(f'Ready | {self.nick}')

@routines.routine(minutes=1)
async def routine_show_connected_channels(self):
connected_channel_names = [channel.name for channel in self.connected_channels]
connected_channel_names = [channel.name for channel in list(filter(None, self.connected_channels))]
logger.info(f'Connected channels: {connected_channel_names}')

@routines.routine(hours=1)
async def routine_join_channels(self):
logger.info('Started join channels routine')
if self.join_channels_first_time:
self.join_channels_first_time = False
return
all_user_details = await self.users_db.get_multiple_users(twitch_ids=self.connected_channel_ids)
twitch_users = {user['twitch_username'] for user in all_user_details}
connected_channels = {chan.name for chan in self.connected_channels}
unconnected_channels = (twitch_users - connected_channels)
unconnected_channels.update(set(self.channels_join_failed))
channels_to_join = list(unconnected_channels)
logger.info(f'Users from database: {twitch_users}')
logger.info(f'self.connected_channels: {connected_channels}')
logger.info(f'Failed connections: {self.channels_join_failed}')
logger.info(f'Joining channels: {channels_to_join}')
self.channels_join_failed = []
await self.join_channels(channels_to_join)

@routines.routine(hours=1)
async def routine_update_user_information(self):
"""
Checks and updates user information changes. This routine runs every hour.
:return:
"""
logger.info('Started user information update routine')
connected_users = await self.users_db.get_multiple_users(self.connected_channel_ids)
twitch_users = await self.fetch_users(ids=self.connected_channel_ids)
connected_users = await self.users_db.get_multiple_users_by_username(self.initial_channel_names)
twitch_users = await self.fetch_users(names=self.initial_channel_names)
twitch_users_by_id = {user.id: user for user in twitch_users}

if len(twitch_users) != len(connected_users):
Expand Down Expand Up @@ -526,7 +482,7 @@ async def handle_banned_users(self, connected_users: List[sqlite3.Row], twitch_u
for connected_user in connected_users:
if connected_user['twitch_id'] not in twitch_user_ids:
logger.info(f'{connected_user["twitch_username"]} does not exist anymore!')
self.connected_channel_ids.remove(connected_user['twitch_id'])
self.initial_channel_names.remove(connected_user['twitch_username'])
await self.users_db.remove_user(connected_user['twitch_username'])
else:
existing_users.append(connected_user)
Expand Down
13 changes: 12 additions & 1 deletion ronnia/helpers/database_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def remove_user(self, twitch_username: str) -> None:
await self.c.execute(f"DELETE FROM users WHERE twitch_username=?", (twitch_username,))
await self.conn.commit()

async def get_multiple_users(self, twitch_ids: List[int]) -> List[sqlite3.Row]:
async def get_multiple_users_by_ids(self, twitch_ids: List[int]) -> List[sqlite3.Row]:
"""
Gets multiple users from database
:param twitch_ids: List of twitch ids
Expand All @@ -177,6 +177,17 @@ async def get_multiple_users(self, twitch_ids: List[int]) -> List[sqlite3.Row]:
users = await result.fetchall()
return users

async def get_multiple_users_by_username(self, twitch_names: List[str]) -> List[sqlite3.Row]:
"""
Gets multiple users from database
:param twitch_ids: List of twitch ids
:return: List of users
"""
query = f"SELECT * FROM users WHERE twitch_username IN ({','.join('?' for i in twitch_names)})"
result = await self.c.execute(query, twitch_names)
users = await result.fetchall()
return users

async def get_user_from_osu_username(self, osu_username: str) -> sqlite3.Row:
"""
Gets the user details from database using osu username
Expand Down
1 change: 0 additions & 1 deletion ronnia/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
irc>=19.0.0,<21.0.0
twitchio==2.1.5
aiosqlite==0.17.0
azure-servicebus==7.6.0
24 changes: 0 additions & 24 deletions ronnia/websocket/ws.py

This file was deleted.

0 comments on commit 6e40ff3

Please sign in to comment.