Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/channel join #44

Merged
merged 2 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions ronnia/bots/bot_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,15 @@ def get_streams(self, user_ids: List[int]):


class TwitchProcess(Process):
def __init__(self, user_list: List[int], join_lock: Lock):
def __init__(self, user_list: List[str], join_lock: Lock, max_users: int):
super().__init__()
self.join_lock = join_lock
self.user_list = user_list
self.max_users = max_users
self.bot = None

def initialize(self):
self.bot = TwitchBot(initial_channel_ids=self.user_list, join_lock=self.join_lock)
self.bot = TwitchBot(initial_channel_names=self.user_list, join_lock=self.join_lock, max_users=self.max_users)

def run(self) -> None:
self.initialize()
Expand All @@ -84,8 +85,8 @@ def __init__(self, ):
self.twitch_client = TwitchAPI(os.getenv('TWITCH_CLIENT_ID'), os.getenv('TWITCH_CLIENT_SECRET'))
self._loop = asyncio.get_event_loop()

self.user_per_instance = 100
self.sleep_after_instance = (self.user_per_instance // 20 + 1) * 10
self.user_per_instance = 150
self.sleep_after_instance = (self.user_per_instance // 20 + 1) * 11

self.servicebus_connection_string = os.getenv('SERVICE_BUS_CONNECTION_STR')
self.servicebus_webserver_queue_name = 'webserver-signups'
Expand Down Expand Up @@ -120,19 +121,20 @@ def start(self):
self._loop.run_until_complete(self.initialize_queues())
logger.info("Queues initialized")
all_users = self.users_db.execute('SELECT * FROM users;').fetchall()
all_user_twitch_names = [user[2] for user in all_users]
all_user_twitch_ids = [user[4] for user in all_users]
streaming_user_ids = [user['user_id'] for user in self.twitch_client.get_streams(all_user_twitch_ids)]
streaming_user_names = [user['user_login'] for user in self.twitch_client.get_streams(all_user_twitch_ids)]

for user_id in all_user_twitch_ids:
if user_id not in streaming_user_ids:
streaming_user_ids.append(user_id)
for user_id in all_user_twitch_names:
if user_id not in streaming_user_names:
streaming_user_names.append(user_id)

logger.info(f"Collected users: {len(streaming_user_ids)}")
for user_id_list in batcher(streaming_user_ids, self.user_per_instance):
p = TwitchProcess(user_id_list, self.join_lock)
logger.info(f"Collected users: {len(streaming_user_names)}")
for user_names_list in batcher(streaming_user_names, self.user_per_instance):
p = TwitchProcess(user_list=user_names_list, join_lock=self.join_lock, max_users=self.user_per_instance)
p.start()
logger.info(f"Started Twitch bot instance for {len(user_id_list)} users")
self.bot_processes[p] = user_id_list
logger.info(f"Started Twitch bot instance for {len(user_names_list)} users")
self.bot_processes[p] = user_names_list
# 20 join rate per 10 seconds
time.sleep(self.sleep_after_instance)

Expand Down Expand Up @@ -231,7 +233,7 @@ async def parse_and_send_message(self, message):
"""
if self.create_new_instance:
logger.info(f"Started a new bot instance. This is the {len(self.bot_processes)}th instance.")
p = TwitchProcess([], self.join_lock)
p = TwitchProcess([], self.join_lock, max_users=self.user_per_instance)
p.start()
self.create_new_instance = False

Expand Down
88 changes: 22 additions & 66 deletions ronnia/bots/twitch_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import os
import sqlite3
import time
import traceback
from abc import ABC
from multiprocessing import Lock
Expand All @@ -21,15 +20,14 @@
from ronnia.helpers.beatmap_link_parser import parse_beatmap_link
from ronnia.helpers.database_helper import UserDatabase, StatisticsDatabase
from ronnia.helpers.utils import convert_seconds_to_readable
from websocket.ws import RetryableWSConnection

logger = logging.getLogger(__name__)


class TwitchBot(commands.Bot, ABC):
PER_REQUEST_COOLDOWN = 30 # each request has 30 seconds cooldown

def __init__(self, initial_channel_ids: List[int], join_lock: Lock):
def __init__(self, initial_channel_names: List[str], join_lock: Lock, max_users: int):
self.users_db = UserDatabase()
self.messages_db = StatisticsDatabase()
self.osu_api = OsuApiV2(os.getenv('OSU_CLIENT_ID'), os.getenv('OSU_CLIENT_SECRET'))
Expand All @@ -42,24 +40,13 @@ def __init__(self, initial_channel_ids: List[int], join_lock: Lock):
'client_id': os.getenv('TWITCH_CLIENT_ID'),
'client_secret': os.getenv('TWITCH_CLIENT_SECRET'),
'prefix': os.getenv('BOT_PREFIX'),
'heartbeat': 20
'initial_channels': [os.getenv('BOT_NICK'), *initial_channel_names]
}
logger.debug(f'Sending args to super().__init__: {args}')
super().__init__(**args)

conn_args = {
'token': token,
'initial_channels': [os.getenv('BOT_NICK')],
'heartbeat': 30
}
self._connection = RetryableWSConnection(
client=self,
loop=self.loop,
**conn_args
)

self.initial_channel_names = initial_channel_names
self.environment = os.getenv('ENVIRONMENT')
self.connected_channel_ids = initial_channel_ids
self.servicebus_connection_string = os.getenv('SERVICE_BUS_CONNECTION_STR')
self.servicebus_client = ServiceBusClient.from_connection_string(conn_str=self.servicebus_connection_string)
self.signup_queue_name = 'bot-signups'
Expand All @@ -70,19 +57,14 @@ def __init__(self, initial_channel_ids: List[int], join_lock: Lock):
self.main_prefix = None
self.user_last_request = {}

self.join_channels_first_time = True
self.max_users = 100

async def join_channels(self, channels: Union[List[str], Tuple[str]]):
with self._join_lock:
await super(TwitchBot, self).join_channels(channels)
self.max_users = max_users

async def servicebus_message_receiver(self):
"""
Start a queue listener for messages from the website sign-up.
"""
# Each instance of bot can only have one 50 users.
if len(self.connected_channel_ids) == self.max_users:
# Each instance of bot can only have 100 users.
if len(self.initial_channel_names) >= self.max_users:
logger.info(f'Reached {self.max_users} members, stopped listening to sign-up queue.')
return

Expand All @@ -102,7 +84,7 @@ async def servicebus_message_receiver(self):
logger.info(f'Sending reply message to sign-up queue: {reply_message}')
await sender.send_messages(reply_message)

if len(self.connected_channel_ids) == 100:
if len(self.initial_channel_names) == 100:
logger.warning(
'Reached 100 members, sending manager signal to create a new process.')
bot_full_message = ServiceBusMessage("bot-full")
Expand All @@ -128,7 +110,7 @@ async def receive_and_parse_message(self, message):
osu_id = message_dict['osu_id']
twitch_id = message_dict['twitch_id']

self.connected_channel_ids.append(twitch_id)
self.initial_channel_names.append(twitch_username)
await self.users_db.add_user(twitch_username=twitch_username,
twitch_id=twitch_id,
osu_username=osu_username,
Expand Down Expand Up @@ -163,6 +145,9 @@ async def event_message(self, message: Message):
return
logger.info(f"{message.channel.name} - {message.author.name}: {message.content}")

if self.environment == "testing":
return

await self.handle_commands(message)
try:
await self.check_channel_enabled(message.channel.name)
Expand All @@ -187,10 +172,10 @@ async def handle_request(self, message: Message):
beatmap_info=beatmap_info,
beatmapset_info=beatmapset_info)

await self._send_beatmap_to_irc(message=message,
beatmap_info=beatmap_info,
beatmapset_info=beatmapset_info,
given_mods=given_mods)
await self._send_beatmap_to_in_game(message=message,
beatmap_info=beatmap_info,
beatmapset_info=beatmapset_info,
given_mods=given_mods)
await self.messages_db.add_request(requested_beatmap_id=int(beatmap_info['id']),
requested_channel_name=message.channel.name,
requester_channel_name=message.author.name,
Expand Down Expand Up @@ -326,7 +311,8 @@ async def _prune_cooldowns(self, time_right_now: datetime.datetime):

return

async def _send_beatmap_to_irc(self, message: Message, beatmap_info: dict, beatmapset_info: dict, given_mods: str):
async def _send_beatmap_to_in_game(self, message: Message, beatmap_info: dict, beatmapset_info: dict,
given_mods: str):
"""
Sends the beatmap request message to osu!irc bot
:param message: Twitch Message object
Expand Down Expand Up @@ -419,17 +405,7 @@ async def event_ready(self):
await self.messages_db.initialize()

logger.debug(f'Successfully initialized databases!')

logger.debug(f'Populating users: {self.connected_channel_ids}')
channel_names = await self.fetch_users(ids=self.connected_channel_ids)
channels_to_join = [ch.name for ch in channel_names]

logger.info(f'Joining channels: {channels_to_join}')
# Join channels
channel_join_start = time.time()
await self.join_channels(channels_to_join)

logger.info(f'Joined {len(self.connected_channels)} after {time.time() - channel_join_start:.2f}s')
logger.debug(f'Started bot instance with: {self.initial_channel_names}')
logger.info(f'Connected channels: {self.connected_channels}')

initial_extensions = ['cogs.admin_cog']
Expand All @@ -440,44 +416,24 @@ async def event_ready(self):
self.loop.create_task(self.servicebus_message_receiver())
self.routine_update_user_information.start(stop_on_error=False)
self.routine_show_connected_channels.start(stop_on_error=False)
self.routine_join_channels.start(stop_on_error=False)

logger.info(f'Successfully initialized bot!')
logger.info(f'Ready | {self.nick}')

@routines.routine(minutes=1)
async def routine_show_connected_channels(self):
connected_channel_names = [channel.name for channel in self.connected_channels]
connected_channel_names = [channel.name for channel in list(filter(None, self.connected_channels))]
logger.info(f'Connected channels: {connected_channel_names}')

@routines.routine(hours=1)
async def routine_join_channels(self):
logger.info('Started join channels routine')
if self.join_channels_first_time:
self.join_channels_first_time = False
return
all_user_details = await self.users_db.get_multiple_users(twitch_ids=self.connected_channel_ids)
twitch_users = {user['twitch_username'] for user in all_user_details}
connected_channels = {chan.name for chan in self.connected_channels}
unconnected_channels = (twitch_users - connected_channels)
unconnected_channels.update(set(self.channels_join_failed))
channels_to_join = list(unconnected_channels)
logger.info(f'Users from database: {twitch_users}')
logger.info(f'self.connected_channels: {connected_channels}')
logger.info(f'Failed connections: {self.channels_join_failed}')
logger.info(f'Joining channels: {channels_to_join}')
self.channels_join_failed = []
await self.join_channels(channels_to_join)

@routines.routine(hours=1)
async def routine_update_user_information(self):
"""
Checks and updates user information changes. This routine runs every hour.
:return:
"""
logger.info('Started user information update routine')
connected_users = await self.users_db.get_multiple_users(self.connected_channel_ids)
twitch_users = await self.fetch_users(ids=self.connected_channel_ids)
connected_users = await self.users_db.get_multiple_users_by_username(self.initial_channel_names)
twitch_users = await self.fetch_users(names=self.initial_channel_names)
twitch_users_by_id = {user.id: user for user in twitch_users}

if len(twitch_users) != len(connected_users):
Expand Down Expand Up @@ -526,7 +482,7 @@ async def handle_banned_users(self, connected_users: List[sqlite3.Row], twitch_u
for connected_user in connected_users:
if connected_user['twitch_id'] not in twitch_user_ids:
logger.info(f'{connected_user["twitch_username"]} does not exist anymore!')
self.connected_channel_ids.remove(connected_user['twitch_id'])
self.initial_channel_names.remove(connected_user['twitch_username'])
await self.users_db.remove_user(connected_user['twitch_username'])
else:
existing_users.append(connected_user)
Expand Down
13 changes: 12 additions & 1 deletion ronnia/helpers/database_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def remove_user(self, twitch_username: str) -> None:
await self.c.execute(f"DELETE FROM users WHERE twitch_username=?", (twitch_username,))
await self.conn.commit()

async def get_multiple_users(self, twitch_ids: List[int]) -> List[sqlite3.Row]:
async def get_multiple_users_by_ids(self, twitch_ids: List[int]) -> List[sqlite3.Row]:
"""
Gets multiple users from database
:param twitch_ids: List of twitch ids
Expand All @@ -177,6 +177,17 @@ async def get_multiple_users(self, twitch_ids: List[int]) -> List[sqlite3.Row]:
users = await result.fetchall()
return users

async def get_multiple_users_by_username(self, twitch_names: List[str]) -> List[sqlite3.Row]:
"""
Gets multiple users from database
:param twitch_ids: List of twitch ids
:return: List of users
"""
query = f"SELECT * FROM users WHERE twitch_username IN ({','.join('?' for i in twitch_names)})"
result = await self.c.execute(query, twitch_names)
users = await result.fetchall()
return users

async def get_user_from_osu_username(self, osu_username: str) -> sqlite3.Row:
"""
Gets the user details from database using osu username
Expand Down
1 change: 0 additions & 1 deletion ronnia/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
irc>=19.0.0,<21.0.0
twitchio==2.1.5
aiosqlite==0.17.0
azure-servicebus==7.6.0
24 changes: 0 additions & 24 deletions ronnia/websocket/ws.py

This file was deleted.