Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve issues with discord.Thread and discord.ThreadOption when used in discord.Option #1427

Merged
merged 9 commits into from
Jun 24, 2022
4 changes: 3 additions & 1 deletion discord/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,14 @@ def __repr__(self) -> str:
return f"<{self.__class__.__name__} {joined}>"

def _update(self, guild: Guild, data: Union[TextChannelPayload, ForumChannelPayload]) -> None:
# This data will always exist
self.guild: Guild = guild
self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self._type: int = data["type"]

if not data.get("_invoke_flag"):
# This data may be missing depending on how this object is being created/updated
if data.get("position", None) is not None:
self.topic: Optional[str] = data.get("topic")
self.position: int = data.get("position")
self.nsfw: bool = data.get("nsfw", False)
Expand Down
25 changes: 15 additions & 10 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Union,
)

from ..channel import _guild_channel_factory
from ..channel import _threaded_guild_channel_factory
from ..enums import MessageType, SlashCommandOptionType, try_enum, Enum as DiscordEnum
from ..errors import (
ApplicationCommandError,
Expand All @@ -61,6 +61,7 @@
from ..message import Attachment, Message
from ..object import Object
from ..role import Role
from ..threads import Thread
from ..user import User
from ..utils import async_all, find, utcnow
from .context import ApplicationContext, AutocompleteContext
Expand Down Expand Up @@ -812,10 +813,11 @@ async def _invoke(self, ctx: ApplicationContext) -> None:
else:
arg = Object(id=int(arg))
elif (_data := resolved.get(f"{op.input_type.name}s", {}).get(arg)) is not None:
if op.input_type is SlashCommandOptionType.channel and int(arg) in ctx.guild._channels:
arg = ctx.guild.get_channel(int(arg))
_data["_invoke_flag"] = True
arg._update(ctx.guild, _data)
if op.input_type is SlashCommandOptionType.channel and (
int(arg) in ctx.guild._channels or int(arg) in ctx.guild._threads
):
arg = ctx.guild.get_channel_or_thread(int(arg))
arg._update(_data) if isinstance(arg, Thread) else arg._update(ctx.guild, _data)
else:
obj_type = None
kw = {}
Expand All @@ -826,11 +828,14 @@ async def _invoke(self, ctx: ApplicationContext) -> None:
kw["guild"] = ctx.guild
elif op.input_type is SlashCommandOptionType.channel:
# NOTE:
# This is a fallback in case the channel is not found in the guild's channels.
# If this fallback occurs, at the very minimum, permissions will be incorrect
# due to a lack of permission_overwrite data.
obj_type = _guild_channel_factory(_data["type"])[0]
kw["guild"] = ctx.guild
# This is a fallback in case the channel/thread is not found in the
# guild's channels/threads. For channels, if this fallback occurs, at the very minimum,
# permissions will be incorrect due to a lack of permission_overwrite data.
# For threads, if this fallback occurs, info like thread owner id, message count,
# flags, and more will be missing due to a lack of data sent by Discord.
obj_type = _threaded_guild_channel_factory(_data["type"])[0]
if op._raw_type is not Thread:
baronkobama marked this conversation as resolved.
Show resolved Hide resolved
kw["guild"] = ctx.guild
elif op.input_type is SlashCommandOptionType.attachment:
obj_type = Attachment
arg = obj_type(state=ctx.interaction._state, data=_data, **kw)
Expand Down
4 changes: 1 addition & 3 deletions discord/commands/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"StageChannel": ChannelType.stage_voice,
"CategoryChannel": ChannelType.category,
"Thread": ChannelType.public_thread,
"ThreadOption": ChannelType.public_thread,
baronkobama marked this conversation as resolved.
Show resolved Hide resolved
}


Expand Down Expand Up @@ -158,9 +159,6 @@ def __init__(self, input_type: Any = str, /, description: Optional[str] = None,
for i in input_type:
if i.__name__ == "GuildChannel":
continue
if isinstance(i, ThreadOption):
self.channel_types.append(i._type)
continue
baronkobama marked this conversation as resolved.
Show resolved Hide resolved

channel_type = channel_type_map[i.__name__]
self.channel_types.append(channel_type)
Expand Down
4 changes: 2 additions & 2 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def _add_member(self, member: Member, /) -> None:
self._members[member.id] = member

def _get_and_update_member(self, payload: MemberPayload, user_id: int, cache_flag: bool, /) -> Member:
# we always get the member, and we only update if the cache_flag (this cache flag should
# always be MemberCacheFlag.interaction or MemberCacheFlag.option) is set to True
# we always get the member, and we only update if the cache_flag (this cache
# flag should always be MemberCacheFlag.interaction) is set to True
if user_id in self._members:
member = self.get_member(user_id)
member._update(payload) if cache_flag else None
Expand Down
19 changes: 13 additions & 6 deletions discord/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,23 @@ def __str__(self) -> str:
return self.name

def _from_data(self, data: ThreadPayload):
# This data will always exist
self.id = int(data["id"])
self.parent_id = int(data["parent_id"])
self.owner_id = int(data["owner_id"])
self.name = data["name"]
self._type = try_enum(ChannelType, data["type"])
self.last_message_id = _get_as_snowflake(data, "last_message_id")
self.slowmode_delay = data.get("rate_limit_per_user", 0)
self.message_count = data["message_count"]
self.member_count = data["member_count"]
self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0))

# This data may be missing depending on how this object is being created
try:
self.owner_id = int(data["owner_id"])
self.last_message_id = _get_as_snowflake(data, "last_message_id")
self.slowmode_delay = data.get("rate_limit_per_user", 0)
self.message_count = data["message_count"]
self.member_count = data["member_count"]
self.flags: ChannelFlags = ChannelFlags._from_value(data.get("flags", 0))
except KeyError:
pass
baronkobama marked this conversation as resolved.
Show resolved Hide resolved

self._unroll_metadata(data["thread_metadata"])

try:
Expand Down