From 67fa7355559bebce8eda9729438b2f6e17f85214 Mon Sep 17 00:00:00 2001 From: jack1142 <6032823+jack1142@users.noreply.github.com> Date: Mon, 5 Apr 2021 21:39:33 +0200 Subject: [PATCH] Use partial messages in Streams cog to avoid potential leakage (#4742) * Use partial messages in Streams cog to avoid leakage * Stop trying to save bot object to Config... * Put guild id as part of message data * Fix AttributeError * Pass bot object to stream classes in commands * ugh * Another place we use this class in * more... --- redbot/cogs/streams/streams.py | 73 ++++++++++++++++++------------ redbot/cogs/streams/streamtypes.py | 39 +++++++++++----- 2 files changed, 70 insertions(+), 42 deletions(-) diff --git a/redbot/cogs/streams/streams.py b/redbot/cogs/streams/streams.py index 33d63a486d7..06d2f48fb53 100644 --- a/redbot/cogs/streams/streams.py +++ b/redbot/cogs/streams/streams.py @@ -209,6 +209,7 @@ async def twitchstream(self, ctx: commands.Context, channel_name: str): await self.maybe_renew_twitch_bearer_token() token = (await self.bot.get_shared_api_tokens("twitch")).get("client_id") stream = TwitchStream( + _bot=self.bot, name=channel_name, token=token, bearer=self.ttv_bearer_cache.get("access_token", None), @@ -224,21 +225,25 @@ async def youtubestream(self, ctx: commands.Context, channel_id_or_name: str): apikey = await self.bot.get_shared_api_tokens("youtube") is_name = self.check_name_or_id(channel_id_or_name) if is_name: - stream = YoutubeStream(name=channel_id_or_name, token=apikey, config=self.config) + stream = YoutubeStream( + _bot=self.bot, name=channel_id_or_name, token=apikey, config=self.config + ) else: - stream = YoutubeStream(id=channel_id_or_name, token=apikey, config=self.config) + stream = YoutubeStream( + _bot=self.bot, id=channel_id_or_name, token=apikey, config=self.config + ) await self.check_online(ctx, stream) @commands.command() async def smashcast(self, ctx: commands.Context, channel_name: str): """Check if a smashcast channel is live.""" - stream = HitboxStream(name=channel_name) + stream = HitboxStream(_bot=self.bot, name=channel_name) await self.check_online(ctx, stream) @commands.command() async def picarto(self, ctx: commands.Context, channel_name: str): """Check if a Picarto channel is live.""" - stream = PicartoStream(name=channel_name) + stream = PicartoStream(_bot=self.bot, name=channel_name) await self.check_online(ctx, stream) async def check_online( @@ -396,19 +401,22 @@ async def stream_alert(self, ctx: commands.Context, _class, channel_name): is_yt = _class.__name__ == "YoutubeStream" is_twitch = _class.__name__ == "TwitchStream" if is_yt and not self.check_name_or_id(channel_name): - stream = _class(id=channel_name, token=token, config=self.config) + stream = _class(_bot=self.bot, id=channel_name, token=token, config=self.config) elif is_twitch: await self.maybe_renew_twitch_bearer_token() stream = _class( + _bot=self.bot, name=channel_name, token=token.get("client_id"), bearer=self.ttv_bearer_cache.get("access_token", None), ) else: if is_yt: - stream = _class(name=channel_name, token=token, config=self.config) + stream = _class( + _bot=self.bot, name=channel_name, token=token, config=self.config + ) else: - stream = _class(name=channel_name, token=token) + stream = _class(_bot=self.bot, name=channel_name, token=token) try: exists = await self.check_exists(stream) except InvalidTwitchCredentials: @@ -714,14 +722,23 @@ async def _stream_alerts(self): await asyncio.sleep(await self.config.refresh_timer()) async def _send_stream_alert( - self, stream, channel: discord.TextChannel, embed: discord.Embed, content: str = None + self, + stream, + channel: discord.TextChannel, + embed: discord.Embed, + content: str = None, + *, + is_schedule: bool = False, ): m = await channel.send( content, embed=embed, allowed_mentions=discord.AllowedMentions(roles=True, everyone=True), ) - stream._messages_cache.append(m) + message_data = {"guild": m.guild.id, "channel": m.channel.id, "message": m.id} + if is_schedule: + message_data["is_schedule"] = True + stream.messages.append(message_data) async def check_streams(self): to_remove = [] @@ -744,19 +761,25 @@ async def check_streams(self): to_remove.append(stream) continue except OfflineStream: - if not stream._messages_cache: + if not stream.messages: continue - for message in stream._messages_cache: - if await self.bot.cog_disabled_in_guild(self, message.guild): + + for msg_data in stream.iter_messages(): + partial_msg = msg_data["partial_message"] + if partial_msg is None: + continue + if await self.bot.cog_disabled_in_guild(self, partial_msg.guild): continue - autodelete = await self.config.guild(message.guild).autodelete() - if autodelete: - with contextlib.suppress(discord.NotFound): - await message.delete() - stream._messages_cache.clear() + if not await self.config.guild(partial_msg.guild).autodelete(): + continue + + with contextlib.suppress(discord.NotFound): + await partial_msg.delete() + + stream.messages.clear() await self.save_streams() else: - if stream._messages_cache: + if stream.messages: continue for channel_id in stream.channels: channel = self.bot.get_channel(channel_id) @@ -772,7 +795,7 @@ async def check_streams(self): continue if is_schedule: # skip messages and mentions - await self._send_stream_alert(stream, channel, embed) + await self._send_stream_alert(stream, channel, embed, is_schedule=True) await self.save_streams() continue await set_contextual_locales_from_guild(self.bot, channel.guild) @@ -874,17 +897,6 @@ async def load_streams(self): _class = getattr(_streamtypes, raw_stream["type"], None) if not _class: continue - raw_msg_cache = raw_stream["messages"] - raw_stream["_messages_cache"] = [] - for raw_msg in raw_msg_cache: - chn = self.bot.get_channel(raw_msg["channel"]) - if chn is not None: - try: - msg = await chn.fetch_message(raw_msg["message"]) - except discord.HTTPException: - pass - else: - raw_stream["_messages_cache"].append(msg) token = await self.bot.get_shared_api_tokens(_class.token_name) if token: if _class.__name__ == "TwitchStream": @@ -894,6 +906,7 @@ async def load_streams(self): if _class.__name__ == "YoutubeStream": raw_stream["config"] = self.config raw_stream["token"] = token + raw_stream["_bot"] = self.bot streams.append(_class(**raw_stream)) return streams diff --git a/redbot/cogs/streams/streamtypes.py b/redbot/cogs/streams/streamtypes.py index 3e5ae817a3b..5a12d44e890 100644 --- a/redbot/cogs/streams/streamtypes.py +++ b/redbot/cogs/streams/streamtypes.py @@ -58,10 +58,11 @@ class Stream: token_name: ClassVar[Optional[str]] = None def __init__(self, **kwargs): + self._bot = kwargs.pop("_bot") self.name = kwargs.pop("name", None) self.channels = kwargs.pop("channels", []) # self.already_online = kwargs.pop("already_online", False) - self._messages_cache = kwargs.pop("_messages_cache", []) + self.messages = kwargs.pop("messages", []) self.type = self.__class__.__name__ async def is_online(self): @@ -70,14 +71,24 @@ async def is_online(self): def make_embed(self): raise NotImplementedError() + def iter_messages(self): + for msg_data in self.messages: + data = msg_data.copy() + # "guild" key might not exist for old config data (available since GH-4742) + if guild_id := msg_data.get("guild"): + guild = self._bot.get_guild(guild_id) + channel = guild and guild.get_channel(msg_data["channel"]) + else: + channel = self._bot.get_channel(msg_data["channel"]) + if channel is not None: + data["partial_message"] = channel.get_partial_message(data["message"]) + yield data + def export(self): data = {} for k, v in self.__dict__.items(): if not k.startswith("_"): data[k] = v - data["messages"] = [] - for m in self._messages_cache: - data["messages"].append({"channel": m.channel.id, "message": m.id}) return data def __repr__(self): @@ -211,17 +222,21 @@ async def make_embed(self, data): embed.timestamp = start_time is_schedule = True else: - # repost message + # delete the message(s) about the stream schedule to_remove = [] - for message in self._messages_cache: - if message.embeds[0].description is discord.Embed.Empty: + for msg_data in self.iter_messages(): + if not msg_data.get("is_schedule", False): continue - with contextlib.suppress(Exception): - autodelete = await self._config.guild(message.guild).autodelete() + partial_msg = msg_data["partial_message"] + if partial_msg is not None: + autodelete = await self._config.guild(partial_msg.guild).autodelete() if autodelete: - await message.delete() - to_remove.append(message.id) - self._messages_cache = [x for x in self._messages_cache if x.id not in to_remove] + with contextlib.suppress(discord.NotFound): + await partial_msg.delete() + to_remove.append(msg_data["message"]) + self.messages = [ + data for data in self.messages if data["message"] not in to_remove + ] embed.set_author(name=channel_title) embed.set_image(url=rnd(thumbnail)) embed.colour = 0x9255A5