Skip to content

Commit

Permalink
Use partial messages in Streams cog to avoid potential leakage (#4742)
Browse files Browse the repository at this point in the history
* Use partial messages in Streams cog to avoid leakage

* Stop trying to save bot object to Config...

* Put guild id as part of message data

* Fix AttributeError

* Pass bot object to stream classes in commands

* ugh

* Another place we use this class in

* more...
  • Loading branch information
Jackenmen committed Apr 5, 2021
1 parent c25095b commit 67fa735
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 42 deletions.
73 changes: 43 additions & 30 deletions redbot/cogs/streams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ async def twitchstream(self, ctx: commands.Context, channel_name: str):
await self.maybe_renew_twitch_bearer_token()
token = (await self.bot.get_shared_api_tokens("twitch")).get("client_id")
stream = TwitchStream(
_bot=self.bot,
name=channel_name,
token=token,
bearer=self.ttv_bearer_cache.get("access_token", None),
Expand All @@ -224,21 +225,25 @@ async def youtubestream(self, ctx: commands.Context, channel_id_or_name: str):
apikey = await self.bot.get_shared_api_tokens("youtube")
is_name = self.check_name_or_id(channel_id_or_name)
if is_name:
stream = YoutubeStream(name=channel_id_or_name, token=apikey, config=self.config)
stream = YoutubeStream(
_bot=self.bot, name=channel_id_or_name, token=apikey, config=self.config
)
else:
stream = YoutubeStream(id=channel_id_or_name, token=apikey, config=self.config)
stream = YoutubeStream(
_bot=self.bot, id=channel_id_or_name, token=apikey, config=self.config
)
await self.check_online(ctx, stream)

@commands.command()
async def smashcast(self, ctx: commands.Context, channel_name: str):
"""Check if a smashcast channel is live."""
stream = HitboxStream(name=channel_name)
stream = HitboxStream(_bot=self.bot, name=channel_name)
await self.check_online(ctx, stream)

@commands.command()
async def picarto(self, ctx: commands.Context, channel_name: str):
"""Check if a Picarto channel is live."""
stream = PicartoStream(name=channel_name)
stream = PicartoStream(_bot=self.bot, name=channel_name)
await self.check_online(ctx, stream)

async def check_online(
Expand Down Expand Up @@ -396,19 +401,22 @@ async def stream_alert(self, ctx: commands.Context, _class, channel_name):
is_yt = _class.__name__ == "YoutubeStream"
is_twitch = _class.__name__ == "TwitchStream"
if is_yt and not self.check_name_or_id(channel_name):
stream = _class(id=channel_name, token=token, config=self.config)
stream = _class(_bot=self.bot, id=channel_name, token=token, config=self.config)
elif is_twitch:
await self.maybe_renew_twitch_bearer_token()
stream = _class(
_bot=self.bot,
name=channel_name,
token=token.get("client_id"),
bearer=self.ttv_bearer_cache.get("access_token", None),
)
else:
if is_yt:
stream = _class(name=channel_name, token=token, config=self.config)
stream = _class(
_bot=self.bot, name=channel_name, token=token, config=self.config
)
else:
stream = _class(name=channel_name, token=token)
stream = _class(_bot=self.bot, name=channel_name, token=token)
try:
exists = await self.check_exists(stream)
except InvalidTwitchCredentials:
Expand Down Expand Up @@ -714,14 +722,23 @@ async def _stream_alerts(self):
await asyncio.sleep(await self.config.refresh_timer())

async def _send_stream_alert(
self, stream, channel: discord.TextChannel, embed: discord.Embed, content: str = None
self,
stream,
channel: discord.TextChannel,
embed: discord.Embed,
content: str = None,
*,
is_schedule: bool = False,
):
m = await channel.send(
content,
embed=embed,
allowed_mentions=discord.AllowedMentions(roles=True, everyone=True),
)
stream._messages_cache.append(m)
message_data = {"guild": m.guild.id, "channel": m.channel.id, "message": m.id}
if is_schedule:
message_data["is_schedule"] = True
stream.messages.append(message_data)

async def check_streams(self):
to_remove = []
Expand All @@ -744,19 +761,25 @@ async def check_streams(self):
to_remove.append(stream)
continue
except OfflineStream:
if not stream._messages_cache:
if not stream.messages:
continue
for message in stream._messages_cache:
if await self.bot.cog_disabled_in_guild(self, message.guild):

for msg_data in stream.iter_messages():
partial_msg = msg_data["partial_message"]
if partial_msg is None:
continue
if await self.bot.cog_disabled_in_guild(self, partial_msg.guild):
continue
autodelete = await self.config.guild(message.guild).autodelete()
if autodelete:
with contextlib.suppress(discord.NotFound):
await message.delete()
stream._messages_cache.clear()
if not await self.config.guild(partial_msg.guild).autodelete():
continue

with contextlib.suppress(discord.NotFound):
await partial_msg.delete()

stream.messages.clear()
await self.save_streams()
else:
if stream._messages_cache:
if stream.messages:
continue
for channel_id in stream.channels:
channel = self.bot.get_channel(channel_id)
Expand All @@ -772,7 +795,7 @@ async def check_streams(self):
continue
if is_schedule:
# skip messages and mentions
await self._send_stream_alert(stream, channel, embed)
await self._send_stream_alert(stream, channel, embed, is_schedule=True)
await self.save_streams()
continue
await set_contextual_locales_from_guild(self.bot, channel.guild)
Expand Down Expand Up @@ -874,17 +897,6 @@ async def load_streams(self):
_class = getattr(_streamtypes, raw_stream["type"], None)
if not _class:
continue
raw_msg_cache = raw_stream["messages"]
raw_stream["_messages_cache"] = []
for raw_msg in raw_msg_cache:
chn = self.bot.get_channel(raw_msg["channel"])
if chn is not None:
try:
msg = await chn.fetch_message(raw_msg["message"])
except discord.HTTPException:
pass
else:
raw_stream["_messages_cache"].append(msg)
token = await self.bot.get_shared_api_tokens(_class.token_name)
if token:
if _class.__name__ == "TwitchStream":
Expand All @@ -894,6 +906,7 @@ async def load_streams(self):
if _class.__name__ == "YoutubeStream":
raw_stream["config"] = self.config
raw_stream["token"] = token
raw_stream["_bot"] = self.bot
streams.append(_class(**raw_stream))

return streams
Expand Down
39 changes: 27 additions & 12 deletions redbot/cogs/streams/streamtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ class Stream:
token_name: ClassVar[Optional[str]] = None

def __init__(self, **kwargs):
self._bot = kwargs.pop("_bot")
self.name = kwargs.pop("name", None)
self.channels = kwargs.pop("channels", [])
# self.already_online = kwargs.pop("already_online", False)
self._messages_cache = kwargs.pop("_messages_cache", [])
self.messages = kwargs.pop("messages", [])
self.type = self.__class__.__name__

async def is_online(self):
Expand All @@ -70,14 +71,24 @@ async def is_online(self):
def make_embed(self):
raise NotImplementedError()

def iter_messages(self):
for msg_data in self.messages:
data = msg_data.copy()
# "guild" key might not exist for old config data (available since GH-4742)
if guild_id := msg_data.get("guild"):
guild = self._bot.get_guild(guild_id)
channel = guild and guild.get_channel(msg_data["channel"])
else:
channel = self._bot.get_channel(msg_data["channel"])
if channel is not None:
data["partial_message"] = channel.get_partial_message(data["message"])
yield data

def export(self):
data = {}
for k, v in self.__dict__.items():
if not k.startswith("_"):
data[k] = v
data["messages"] = []
for m in self._messages_cache:
data["messages"].append({"channel": m.channel.id, "message": m.id})
return data

def __repr__(self):
Expand Down Expand Up @@ -211,17 +222,21 @@ async def make_embed(self, data):
embed.timestamp = start_time
is_schedule = True
else:
# repost message
# delete the message(s) about the stream schedule
to_remove = []
for message in self._messages_cache:
if message.embeds[0].description is discord.Embed.Empty:
for msg_data in self.iter_messages():
if not msg_data.get("is_schedule", False):
continue
with contextlib.suppress(Exception):
autodelete = await self._config.guild(message.guild).autodelete()
partial_msg = msg_data["partial_message"]
if partial_msg is not None:
autodelete = await self._config.guild(partial_msg.guild).autodelete()
if autodelete:
await message.delete()
to_remove.append(message.id)
self._messages_cache = [x for x in self._messages_cache if x.id not in to_remove]
with contextlib.suppress(discord.NotFound):
await partial_msg.delete()
to_remove.append(msg_data["message"])
self.messages = [
data for data in self.messages if data["message"] not in to_remove
]
embed.set_author(name=channel_title)
embed.set_image(url=rnd(thumbnail))
embed.colour = 0x9255A5
Expand Down

0 comments on commit 67fa735

Please sign in to comment.