Skip to content

Commit

Permalink
feat: add SubscriptionIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
Snipy7374 committed Dec 16, 2024
1 parent 979d116 commit 66823ef
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 4 deletions.
100 changes: 100 additions & 0 deletions disnake/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .guild_scheduled_event import GuildScheduledEvent
from .integrations import PartialIntegration
from .object import Object
from .subscription import Subscription
from .threads import Thread
from .utils import maybe_coroutine, snowflake_time, time_snowflake

Expand All @@ -39,6 +40,7 @@
"MemberIterator",
"GuildScheduledEventUserIterator",
"EntitlementIterator",
"SubscriptionIterator",
"PollAnswerIterator",
)

Expand All @@ -60,6 +62,7 @@
GuildScheduledEventUser as GuildScheduledEventUserPayload,
)
from .types.message import Message as MessagePayload
from .types.subscription import Subscription as SubscriptionPayload
from .types.threads import Thread as ThreadPayload
from .types.user import PartialUser as PartialUserPayload
from .user import User
Expand Down Expand Up @@ -1147,6 +1150,103 @@ async def _after_strategy(self, retrieve: int) -> List[EntitlementPayload]:
return data


class SubscriptionIterator(_AsyncIterator["Subscription"]):
def __init__(
self,
sku_id: int,
*,
user_id: Optional[int] = None,
state: ConnectionState,
limit: Optional[int] = None,
before: Optional[Union[Snowflake, datetime.datetime]] = None,
after: Optional[Union[Snowflake, datetime.datetime]] = None,
) -> None:
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))

self.sku_id: int = sku_id
self.user_id: Optional[int] = user_id
self.limit: Optional[int] = limit
self.before: Optional[Snowflake] = before
self.after: Snowflake = after or OLDEST_OBJECT

self._state: ConnectionState = state
self.request = self._state.http.get_subscriptions
self.subscriptions: asyncio.Queue[Subscription] = asyncio.Queue()

self._filter: Optional[Callable[[SubscriptionPayload], bool]] = None
if self.before:
self._strategy = self._before_strategy
if self.after != OLDEST_OBJECT:
self._filter = lambda s: int(s["id"]) > self.after.id
else:
self._strategy = self._after_strategy

async def next(self) -> Subscription:
if self.subscriptions.empty():
await self._fill()

try:
return self.subscriptions.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems from None

def _get_retrieve(self) -> bool:
limit = self.limit
if limit is None or limit > 100:
retrieve = 100
else:
retrieve = limit
self.retrieve: int = retrieve
return retrieve > 0

async def _fill(self) -> None:
if not self._get_retrieve():
return

data = await self._strategy(self.retrieve)
if len(data) < 100:
self.limit = 0 # terminate loop

if self._filter:
data = filter(self._filter, data)

for subscription in data:
await self.subscriptions.put(Subscription(data=subscription, state=self._state))

async def _before_strategy(self, retrieve: int) -> List[SubscriptionPayload]:
before = self.before.id if self.before else None
data = await self.request(
self.sku_id,
before=before,
limit=retrieve,
user_id=self.user_id,
)

if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]["id"]))
return data

async def _after_strategy(self, retrieve: int) -> List[SubscriptionPayload]:
after = self.after.id
data = await self.request(
self.sku_id,
after=after,
limit=retrieve,
user_id=self.user_id,
)

if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[-1]["id"]))
return data


class PollAnswerIterator(_AsyncIterator[Union["User", "Member"]]):
def __init__(
self,
Expand Down
54 changes: 50 additions & 4 deletions disnake/sku.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from .enums import SKUType, try_enum
from .flags import SKUFlags
from .iterators import SubscriptionIterator
from .mixins import Hashable
from .subscription import Subscription
from .utils import snowflake_time

if TYPE_CHECKING:
from .abc import Snowflake, SnowflakeTime
from .state import ConnectionState
from .types.sku import SKU as SKUPayload

Expand Down Expand Up @@ -85,12 +87,56 @@ def flags(self) -> SKUFlags:
""":class:`SKUFlags`: Returns the SKU's flags."""
return SKUFlags._from_value(self._flags)

async def subscriptions(self):
async def subscriptions(
self,
*,
limit: Optional[int] = 50,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
user: Optional[Snowflake] = None,
) -> SubscriptionIterator:
"""|coro|
Retrieve all the subscriptions for this SKU.
Retrieves an :class:`.AsyncIterator` that enabled receiving subscriptions for the SKU.
All parameters are optional.
Parameters
----------
limit: Optional[:class:`int`]
The number of subscriptions to retrieve.
If ``None``, retrieves every subscription.
Note, however, that this would make it a slow operation.
Defaults to ``50``.
before: Union[:class:`.abc.Snowflake`, :class:`datetime.datetime`]
Retrieves subscriptions created before this date or object.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
after: Union[:class:`.abc.Snowflake`, :class:`datetime.datetime`]
Retrieve subscriptions created after this date or object.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
user: Optional[:class:`.abc.Snowflake`]
The user to retrieve subscriptions for.
Raises
------
HTTPException
Retrieving the subscriptions failed.
Yields
------
:class:`.Subscription`
The subscriptions for the given parameters.
"""
...
return SubscriptionIterator(
self.id,
state=self._state,
limit=limit,
before=before,
after=after,
user_id=user.id if user is not None else None,
)

async def fetch_subscription(self, subscription_id: int, /) -> Subscription:
"""|coro|
Expand Down

0 comments on commit 66823ef

Please sign in to comment.