Skip to content

Commit

Permalink
Use partial messages in Streams cog to avoid leakage
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackenmen committed Feb 1, 2021
1 parent 8139587 commit 655f41d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 35 deletions.
51 changes: 28 additions & 23 deletions redbot/cogs/streams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,14 +714,23 @@ async def _stream_alerts(self):
await asyncio.sleep(await self.config.refresh_timer())

async def _send_stream_alert(
self, stream, channel: discord.TextChannel, embed: discord.Embed, content: str = None
self,
stream,
channel: discord.TextChannel,
embed: discord.Embed,
content: str = None,
*,
is_schedule: bool = False,
):
m = await channel.send(
content,
embed=embed,
allowed_mentions=discord.AllowedMentions(roles=True, everyone=True),
)
stream._messages_cache.append(m)
message_data = {"channel": m.channel.id, "message": m.id}
if is_schedule:
message_data["is_schedule"] = True
stream.messages.append(message_data)

async def check_streams(self):
for stream in self.streams:
Expand All @@ -739,19 +748,25 @@ async def check_streams(self):
else:
embed = await stream.is_online()
except OfflineStream:
if not stream._messages_cache:
if not stream.messages:
continue
for message in stream._messages_cache:
if await self.bot.cog_disabled_in_guild(self, message.guild):

for msg_data in stream.iter_messages():
partial_msg = msg_data["partial_message"]
if partial_msg is None:
continue
if await self.bot.cog_disabled_in_guild(self, partial_msg.guild):
continue
autodelete = await self.config.guild(message.guild).autodelete()
if autodelete:
with contextlib.suppress(discord.NotFound):
await message.delete()
stream._messages_cache.clear()
if not await self._config.guild(partial_msg.guild).autodelete():
continue

with contextlib.suppress(discord.NotFound):
await partial_msg.delete()

stream.messages.clear()
await self.save_streams()
else:
if stream._messages_cache:
if stream.messages:
continue
for channel_id in stream.channels:
channel = self.bot.get_channel(channel_id)
Expand All @@ -767,7 +782,7 @@ async def check_streams(self):
continue
if is_schedule:
# skip messages and mentions
await self._send_stream_alert(stream, channel, embed)
await self._send_stream_alert(stream, channel, embed, is_schedule=True)
await self.save_streams()
continue
await set_contextual_locales_from_guild(self.bot, channel.guild)
Expand Down Expand Up @@ -864,17 +879,6 @@ async def load_streams(self):
_class = getattr(_streamtypes, raw_stream["type"], None)
if not _class:
continue
raw_msg_cache = raw_stream["messages"]
raw_stream["_messages_cache"] = []
for raw_msg in raw_msg_cache:
chn = self.bot.get_channel(raw_msg["channel"])
if chn is not None:
try:
msg = await chn.fetch_message(raw_msg["message"])
except discord.HTTPException:
pass
else:
raw_stream["_messages_cache"].append(msg)
token = await self.bot.get_shared_api_tokens(_class.token_name)
if token:
if _class.__name__ == "TwitchStream":
Expand All @@ -884,6 +888,7 @@ async def load_streams(self):
if _class.__name__ == "YoutubeStream":
raw_stream["config"] = self.config
raw_stream["token"] = token
raw_stream["bot"] = self.bot
streams.append(_class(**raw_stream))

return streams
Expand Down
34 changes: 22 additions & 12 deletions redbot/cogs/streams/streamtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ class Stream:
token_name: ClassVar[Optional[str]] = None

def __init__(self, **kwargs):
self.bot = kwargs.pop("bot")
self.name = kwargs.pop("name", None)
self.channels = kwargs.pop("channels", [])
# self.already_online = kwargs.pop("already_online", False)
self._messages_cache = kwargs.pop("_messages_cache", [])
self.messages = kwargs.pop("messages", [])
self.type = self.__class__.__name__

async def is_online(self):
Expand All @@ -68,14 +69,19 @@ async def is_online(self):
def make_embed(self):
raise NotImplementedError()

def iter_messages(self):
for msg_data in self.messages:
data = msg_data.copy()
channel = self.bot.get_channel(msg_data["channel"])
if channel is not None:
data["partial_message"] = channel.get_partial_message(data["message"])
yield data

def export(self):
data = {}
for k, v in self.__dict__.items():
if not k.startswith("_"):
data[k] = v
data["messages"] = []
for m in self._messages_cache:
data["messages"].append({"channel": m.channel.id, "message": m.id})
return data

def __repr__(self):
Expand Down Expand Up @@ -190,17 +196,21 @@ async def make_embed(self, data):
embed.timestamp = start_time
is_schedule = True
else:
# repost message
# delete the message(s) about the stream schedule
to_remove = []
for message in self._messages_cache:
if message.embeds[0].description is discord.Embed.Empty:
for msg_data in self.iter_messages():
if not msg_data.get("is_schedule", False):
continue
with contextlib.suppress(Exception):
autodelete = await self._config.guild(message.guild).autodelete()
partial_msg = msg_data["partial_message"]
if partial_msg is not None:
autodelete = await self._config.guild(partial_msg.guild).autodelete()
if autodelete:
await message.delete()
to_remove.append(message.id)
self._messages_cache = [x for x in self._messages_cache if x.id not in to_remove]
with contextlib.suppress(discord.NotFound):
await partial_msg.delete()
to_remove.append(msg_data["message"])
self.messages = [
data for data in self.messages if data["message"] not in to_remove
]
embed.set_author(name=channel_title)
embed.set_image(url=rnd(thumbnail))
embed.colour = 0x9255A5
Expand Down

0 comments on commit 655f41d

Please sign in to comment.