Skip to content

Commit

Permalink
feat(polls): implement Polls (#1176)
Browse files Browse the repository at this point in the history
Co-authored-by: shiftinv <8530778+shiftinv@users.noreply.github.com>
Co-authored-by: Victor <67214928+Victorsitou@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 24, 2024
1 parent cebfb89 commit d4972ab
Show file tree
Hide file tree
Showing 23 changed files with 1,013 additions and 3 deletions.
6 changes: 6 additions & 0 deletions changelog/1175.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Add the new poll discord API feature. This includes the following new classes and events:

- New types: :class:`Poll`, :class:`PollAnswer`, :class:`PollMedia`, :class:`RawMessagePollVoteActionEvent` and :class:`PollLayoutType`.
- Edited :meth:`abc.Messageable.send`, :meth:`Webhook.send`, :meth:`ext.commands.Context.send` and :meth:`disnake.InteractionResponse.send_message` to be able to send polls.
- Edited :class:`Message` to store a new :attr:`Message.poll` attribute for polls.
- Edited :class:`Event` to contain the new :func:`on_message_poll_vote_add`, :func:`on_message_poll_vote_remove`, :func:`on_raw_message_poll_vote_add` and :func:`on_raw_message_poll_vote_remove`.
1 change: 1 addition & 0 deletions disnake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from .partial_emoji import *
from .permissions import *
from .player import *
from .poll import *
from .raw_models import *
from .reaction import *
from .role import *
Expand Down
21 changes: 20 additions & 1 deletion disnake/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from .iterators import HistoryIterator
from .member import Member
from .message import Message, MessageReference, PartialMessage
from .poll import Poll
from .state import ConnectionState
from .threads import AnyThreadArchiveDuration, ForumTag
from .types.channel import (
Expand Down Expand Up @@ -640,6 +641,7 @@ def _apply_implict_permissions(self, base: Permissions) -> None:
if not base.send_messages:
base.send_tts_messages = False
base.send_voice_messages = False
base.send_polls = False
base.mention_everyone = False
base.embed_links = False
base.attach_files = False
Expand Down Expand Up @@ -887,6 +889,7 @@ async def set_permissions(
request_to_speak: Optional[bool] = ...,
send_messages: Optional[bool] = ...,
send_messages_in_threads: Optional[bool] = ...,
send_polls: Optional[bool] = ...,
send_tts_messages: Optional[bool] = ...,
send_voice_messages: Optional[bool] = ...,
speak: Optional[bool] = ...,
Expand Down Expand Up @@ -1435,6 +1438,7 @@ async def send(
mention_author: bool = ...,
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1456,6 +1460,7 @@ async def send(
mention_author: bool = ...,
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1477,6 +1482,7 @@ async def send(
mention_author: bool = ...,
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1498,6 +1504,7 @@ async def send(
mention_author: bool = ...,
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1520,6 +1527,7 @@ async def send(
mention_author: Optional[bool] = None,
view: Optional[View] = None,
components: Optional[Components[MessageUIComponent]] = None,
poll: Optional[Poll] = None,
):
"""|coro|
Expand All @@ -1528,7 +1536,7 @@ async def send(
The content must be a type that can convert to a string through ``str(content)``.
At least one of ``content``, ``embed``/``embeds``, ``file``/``files``,
``stickers``, ``components``, or ``view`` must be provided.
``stickers``, ``components``, ``poll`` or ``view`` must be provided.
To upload a single file, the ``file`` parameter should be used with a
single :class:`.File` object. To upload multiple files, the ``files``
Expand Down Expand Up @@ -1624,6 +1632,11 @@ async def send(
.. versionadded:: 2.9
poll: :class:`.Poll`
The poll to send with the message.
.. versionadded:: 2.10
Raises
------
HTTPException
Expand Down Expand Up @@ -1676,6 +1689,10 @@ async def send(
if stickers is not None:
stickers_payload = [sticker.id for sticker in stickers]

poll_payload = None
if poll:
poll_payload = poll._to_dict()

allowed_mentions_payload = None
if allowed_mentions is None:
allowed_mentions_payload = state.allowed_mentions and state.allowed_mentions.to_dict()
Expand Down Expand Up @@ -1737,6 +1754,7 @@ async def send(
message_reference=reference_payload,
stickers=stickers_payload,
components=components_payload,
poll=poll_payload,
flags=flags_payload,
)
finally:
Expand All @@ -1753,6 +1771,7 @@ async def send(
message_reference=reference_payload,
stickers=stickers_payload,
components=components_payload,
poll=poll_payload,
flags=flags_payload,
)

Expand Down
21 changes: 21 additions & 0 deletions disnake/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"OnboardingPromptType",
"SKUType",
"EntitlementType",
"PollLayoutType",
)


Expand Down Expand Up @@ -1215,6 +1216,14 @@ class Event(Enum):
"""Called when messages are bulk deleted.
Represents the :func:`on_bulk_message_delete` event.
"""
poll_vote_add = "poll_vote_add"
"""Called when a vote is added on a `Poll`.
Represents the :func:`on_poll_vote_add` event.
"""
poll_vote_remove = "poll_vote_remove"
"""Called when a vote is removed from a `Poll`.
Represents the :func:`on_poll_vote_remove` event.
"""
raw_message_edit = "raw_message_edit"
"""Called when a message is edited regardless of the state of the internal message cache.
Represents the :func:`on_raw_message_edit` event.
Expand All @@ -1227,6 +1236,14 @@ class Event(Enum):
"""Called when a bulk delete is triggered regardless of the messages being in the internal message cache or not.
Represents the :func:`on_raw_bulk_message_delete` event.
"""
raw_poll_vote_add = "raw_poll_vote_add"
"""Called when a vote is added on a `Poll` regardless of the internal message cache.
Represents the :func:`on_raw_poll_vote_add` event.
"""
raw_poll_vote_remove = "raw_poll_vote_remove"
"""Called when a vote is removed from a `Poll` regardless of the internal message cache.
Represents the :func:`on_raw_poll_vote_remove` event.
"""
reaction_add = "reaction_add"
"""Called when a message has a reaction added to it.
Represents the :func:`on_reaction_add` event.
Expand Down Expand Up @@ -1364,6 +1381,10 @@ class EntitlementType(Enum):
application_subscription = 8


class PollLayoutType(Enum):
default = 1


T = TypeVar("T")


Expand Down
1 change: 1 addition & 0 deletions disnake/ext/commands/base_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ def default_member_permissions(
request_to_speak: bool = ...,
send_messages: bool = ...,
send_messages_in_threads: bool = ...,
send_polls: bool = ...,
send_tts_messages: bool = ...,
send_voice_messages: bool = ...,
speak: bool = ...,
Expand Down
4 changes: 4 additions & 0 deletions disnake/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2032,6 +2032,7 @@ def has_permissions(
request_to_speak: bool = ...,
send_messages: bool = ...,
send_messages_in_threads: bool = ...,
send_polls: bool = ...,
send_tts_messages: bool = ...,
send_voice_messages: bool = ...,
speak: bool = ...,
Expand Down Expand Up @@ -2157,6 +2158,7 @@ def bot_has_permissions(
request_to_speak: bool = ...,
send_messages: bool = ...,
send_messages_in_threads: bool = ...,
send_polls: bool = ...,
send_tts_messages: bool = ...,
send_voice_messages: bool = ...,
speak: bool = ...,
Expand Down Expand Up @@ -2260,6 +2262,7 @@ def has_guild_permissions(
request_to_speak: bool = ...,
send_messages: bool = ...,
send_messages_in_threads: bool = ...,
send_polls: bool = ...,
send_tts_messages: bool = ...,
send_voice_messages: bool = ...,
speak: bool = ...,
Expand Down Expand Up @@ -2360,6 +2363,7 @@ def bot_has_guild_permissions(
request_to_speak: bool = ...,
send_messages: bool = ...,
send_messages_in_threads: bool = ...,
send_polls: bool = ...,
send_tts_messages: bool = ...,
send_voice_messages: bool = ...,
speak: bool = ...,
Expand Down
58 changes: 58 additions & 0 deletions disnake/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,11 +1028,13 @@ def __init__(
automod_execution: bool = ...,
bans: bool = ...,
dm_messages: bool = ...,
dm_polls: bool = ...,
dm_reactions: bool = ...,
dm_typing: bool = ...,
emojis: bool = ...,
emojis_and_stickers: bool = ...,
guild_messages: bool = ...,
guild_polls: bool = ...,
guild_reactions: bool = ...,
guild_scheduled_events: bool = ...,
guild_typing: bool = ...,
Expand All @@ -1043,6 +1045,7 @@ def __init__(
message_content: bool = ...,
messages: bool = ...,
moderation: bool = ...,
polls: bool = ...,
presences: bool = ...,
reactions: bool = ...,
typing: bool = ...,
Expand Down Expand Up @@ -1598,6 +1601,61 @@ def automod(self):
"""
return (1 << 20) | (1 << 21)

@alias_flag_value
def polls(self):
""":class:`bool`: Whether guild and direct message polls related events are enabled.
This is a shortcut to set or get both :attr:`guild_polls` and :attr:`dm_polls`.
This corresponds to the following events:
- :func:`on_poll_vote_add` (both guilds and DMs)
- :func:`on_poll_vote_remove` (both guilds and DMs)
- :func:`on_raw_poll_vote_add` (both guilds and DMs)
- :func:`on_raw_poll_vote_remove` (both guilds and DMs)
"""
return (1 << 24) | (1 << 25)

@flag_value
def guild_polls(self):
""":class:`bool`: Whether guild polls related events are enabled.
.. versionadded:: 2.10
This corresponds to the following events:
- :func:`on_poll_vote_add` (only for guilds)
- :func:`on_poll_vote_remove` (only for guilds)
- :func:`on_raw_poll_vote_add` (only for guilds)
- :func:`on_raw_poll_vote_remove` (only for guilds)
This also corresponds to the following attributes and classes in terms of cache:
- :attr:`Message.poll` (only for guild messages)
- :class:`Poll` and all its attributes.
"""
return 1 << 24

@flag_value
def dm_polls(self):
""":class:`bool`: Whether direct message polls related events are enabled.
.. versionadded:: 2.10
This corresponds to the following events:
- :func:`on_poll_vote_add` (only for DMs)
- :func:`on_poll_vote_remove` (only for DMs)
- :func:`on_raw_poll_vote_add` (only for DMs)
- :func:`on_raw_poll_vote_remove` (only for DMs)
This also corresponds to the following attributes and classes in terms of cache:
- :attr:`Message.poll` (only for DM messages)
- :class:`Poll` and all its attributes.
"""
return 1 << 25


class MemberCacheFlags(BaseFlags):
"""Controls the library's cache policy when it comes to members.
Expand Down
49 changes: 49 additions & 0 deletions disnake/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
member,
message,
onboarding,
poll,
role,
sku,
sticker,
Expand Down Expand Up @@ -528,6 +529,7 @@ def send_message(
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[Sequence[Snowflake]] = None,
components: Optional[Sequence[components.Component]] = None,
poll: Optional[poll.PollCreatePayload] = None,
flags: Optional[int] = None,
) -> Response[message.Message]:
r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id)
Expand Down Expand Up @@ -563,8 +565,50 @@ def send_message(
if flags is not None:
payload["flags"] = flags

if poll is not None:
payload["poll"] = poll

return self.request(r, json=payload)

def get_poll_answer_voters(
self,
channel_id: Snowflake,
message_id: Snowflake,
answer_id: int,
*,
after: Optional[Snowflake] = None,
limit: Optional[int] = None,
) -> Response[poll.PollVoters]:
params: Dict[str, Any] = {}

if after is not None:
params["after"] = after
if limit is not None:
params["limit"] = limit

return self.request(
Route(
"GET",
"/channels/{channel_id}/polls/{message_id}/answers/{answer_id}",
channel_id=channel_id,
message_id=message_id,
answer_id=answer_id,
),
params=params,
)

def expire_poll(
self, channel_id: Snowflake, message_id: Snowflake
) -> Response[message.Message]:
return self.request(
Route(
"POST",
"/channels/{channel_id}/polls/{message_id}/expire",
channel_id=channel_id,
message_id=message_id,
)
)

def send_typing(self, channel_id: Snowflake) -> Response[None]:
return self.request(Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id))

Expand All @@ -582,6 +626,7 @@ def send_multipart_helper(
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[Sequence[Snowflake]] = None,
components: Optional[Sequence[components.Component]] = None,
poll: Optional[poll.PollCreatePayload] = None,
flags: Optional[int] = None,
) -> Response[message.Message]:
payload: Dict[str, Any] = {"tts": tts}
Expand All @@ -603,6 +648,8 @@ def send_multipart_helper(
payload["sticker_ids"] = stickers
if flags is not None:
payload["flags"] = flags
if poll:
payload["poll"] = poll

multipart = to_multipart_with_attachments(payload, files)

Expand All @@ -622,6 +669,7 @@ def send_files(
message_reference: Optional[message.MessageReference] = None,
stickers: Optional[Sequence[Snowflake]] = None,
components: Optional[Sequence[components.Component]] = None,
poll: Optional[poll.PollCreatePayload] = None,
flags: Optional[int] = None,
) -> Response[message.Message]:
r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id)
Expand All @@ -637,6 +685,7 @@ def send_files(
message_reference=message_reference,
stickers=stickers,
components=components,
poll=poll,
flags=flags,
)

Expand Down
Loading

0 comments on commit d4972ab

Please sign in to comment.