Skip to content

Commit

Permalink
[PR aio-libs#7819/dfc3f899 backport][3.9] Skip filtering `CookieJar` when the jar is empty or all cookies have expired (aio-libs#7822)
Browse files Browse the repository at this point in the history

**This is a backport of PR aio-libs#7819 as merged into master (dfc3f89).**

<!-- Thank you for your contribution! -->

The filtering itself and its preparation in `CookieJar.filter_cookies()`
are expensive. Sometimes there are no cookies in the jar, or all cookies
have expired. Skip the filtering and its preparation in these cases.

Because the emptiness check is much cheaper than `_do_expiration()`, I
think it is worth performing it both before and after calling
`_do_expiration()`.

```console
$ python3.11 -m timeit -s 'from collections import defaultdict; d=defaultdict(foo="bar")' \
> 'if not d: pass'
50000000 loops, best of 5: 8.3 nsec per loop
$ python3.11 -m timeit -s 'from collections import defaultdict; d=defaultdict()' \
> 'if not d: pass'
50000000 loops, best of 5: 8.74 nsec per loop
$ python3.11 -m timeit -s 'from aiohttp import CookieJar; cj = CookieJar()' \
> 'cj._do_expiration()'
200000 loops, best of 5: 1.86 usec per loop
```

<!-- Please give a short brief about these changes. -->

No.

<!-- Outline any notable behaviour for the end users. -->

aio-libs#7583 (comment)

<!-- Are there any issues opened that will be resolved by merging this
change? -->

- [x] I think the code is well written
- [ ] Unit tests for the changes exist
- [ ] Documentation reflects the changes
- [x] If you provide code modification, please add yourself to
  `CONTRIBUTORS.txt`
  * The format is `<Name> <Surname>`.
  * Please keep alphabetical order; the file is sorted by names.
- [x] Add a new news fragment into the `CHANGES` folder
  * Name it `<issue_id>.<type>`, for example `588.bugfix`.
  * If you don't have an `issue_id`, change it to the PR id after creating
    the PR.
  * Ensure the type is one of the following:
    * `.feature`: Signifying a new feature.
    * `.bugfix`: Signifying a bug fix.
    * `.doc`: Signifying a documentation improvement.
    * `.removal`: Signifying a deprecation or removal of public API.
    * `.misc`: A ticket has been closed, but it is not of interest to users.
  * Make sure to use full sentences with correct case and punctuation, for
    example: "Fix issue with non-ascii contents in doctest text files."
  • Loading branch information
Rongronggg9 authored and Xiang Li committed Dec 4, 2023
1 parent a96fc47 commit 8b1e777
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 54 deletions.
1 change: 1 addition & 0 deletions CHANGES/7819.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Skip filtering ``CookieJar`` when the jar is empty or all cookies have expired.
110 changes: 56 additions & 54 deletions aiohttp/cookiejar.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import calendar
import asyncio
import contextlib
import datetime
import os # noqa
import pathlib
import pickle
import re
import time
import warnings
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
from math import ceil
from typing import ( # noqa
DefaultDict,
Dict,
Expand All @@ -27,7 +24,7 @@
from yarl import URL

from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address
from .helpers import is_ip_address, next_whole_second
from .typedefs import LooseCookies, PathLike, StrOrURL

__all__ = ("CookieJar", "DummyCookieJar")
Expand Down Expand Up @@ -55,32 +52,20 @@ class CookieJar(AbstractCookieJar):

DATE_YEAR_RE = re.compile(r"(\d{2,4})")

# calendar.timegm() fails for timestamps after datetime.datetime.max
# Minus one as a loss of precision occurs when timestamp() is called.
MAX_TIME = (
int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
)
try:
calendar.timegm(time.gmtime(MAX_TIME))
except (OSError, ValueError):
# Hit the maximum representable time on Windows
# https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
# Throws ValueError on PyPy 3.8 and 3.9, OSError elsewhere
MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
except OverflowError:
# #4515: datetime.max may not be representable on 32-bit platforms
MAX_TIME = 2**31 - 1
# Avoid minuses in the future, 3x faster
SUB_MAX_TIME = MAX_TIME - 1
MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)

MAX_32BIT_TIME = datetime.datetime.fromtimestamp(2**31 - 1, datetime.timezone.utc)

def __init__(
self,
*,
unsafe: bool = False,
quote_cookie: bool = True,
treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None
treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
super().__init__(loop=loop)
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie[str]] = defaultdict(
SimpleCookie
)
self._host_only_cookies: Set[Tuple[str, str]] = set()
Expand All @@ -98,8 +83,14 @@ def __init__(
for url in treat_as_secure_origin
]
self._treat_as_secure_origin = treat_as_secure_origin
self._next_expiration: float = ceil(time.time())
self._expirations: Dict[Tuple[str, str, str], float] = {}
self._next_expiration = next_whole_second()
self._expirations: Dict[Tuple[str, str, str], datetime.datetime] = {}
# #4515: datetime.max may not be representable on 32-bit platforms
self._max_time = self.MAX_TIME
try:
self._max_time.timestamp()
except OverflowError:
self._max_time = self.MAX_32BIT_TIME

def save(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
Expand All @@ -113,14 +104,14 @@ def load(self, file_path: PathLike) -> None:

def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
if predicate is None:
self._next_expiration = ceil(time.time())
self._next_expiration = next_whole_second()
self._cookies.clear()
self._host_only_cookies.clear()
self._expirations.clear()
return

to_del = []
now = time.time()
now = datetime.datetime.now(datetime.timezone.utc)
for (domain, path), cookie in self._cookies.items():
for name, morsel in cookie.items():
key = (domain, path, name)
Expand All @@ -136,11 +127,13 @@ def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
del self._expirations[(domain, path, name)]
self._cookies[(domain, path)].pop(name, None)

self._next_expiration = (
min(*self._expirations.values(), self.SUB_MAX_TIME) + 1
if self._expirations
else self.MAX_TIME
)
next_expiration = min(self._expirations.values(), default=self._max_time)
try:
self._next_expiration = next_expiration.replace(
microsecond=0
) + datetime.timedelta(seconds=1)
except OverflowError:
self._next_expiration = self._max_time

def clear_domain(self, domain: str) -> None:
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
Expand All @@ -156,7 +149,9 @@ def __len__(self) -> int:
def _do_expiration(self) -> None:
self.clear(lambda x: False)

def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
def _expire_cookie(
self, when: datetime.datetime, domain: str, path: str, name: str
) -> None:
self._next_expiration = min(self._next_expiration, when)
self._expirations[(domain, path, name)] = when

Expand All @@ -173,7 +168,7 @@ def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> No

for name, cookie in cookies:
if not isinstance(cookie, Morsel):
tmp = SimpleCookie()
tmp: SimpleCookie[str] = SimpleCookie()
tmp[name] = cookie # type: ignore[assignment]
cookie = tmp[name]

Expand Down Expand Up @@ -214,7 +209,12 @@ def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> No
if max_age:
try:
delta_seconds = int(max_age)
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
try:
max_age_expiration = datetime.datetime.now(
datetime.timezone.utc
) + datetime.timedelta(seconds=delta_seconds)
except OverflowError:
max_age_expiration = self._max_time
self._expire_cookie(max_age_expiration, domain, path, name)
except ValueError:
cookie["max-age"] = ""
Expand All @@ -232,17 +232,11 @@ def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> No

self._do_expiration()

def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
def filter_cookies(
self, request_url: URL = URL()
) -> Union["BaseCookie[str]", "SimpleCookie[str]"]:
"""Returns this jar's cookies filtered by their attributes."""
if not isinstance(request_url, URL):
warnings.warn(
"The method accepts yarl.URL instances only, got {}".format(
type(request_url)
),
DeprecationWarning,
)
request_url = URL(request_url)
filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = (
SimpleCookie() if self._quote_cookie else BaseCookie()
)
if not self._cookies:
Expand All @@ -252,14 +246,16 @@ def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
if not self._cookies:
# Skip rest of function if no non-expired cookies.
return filtered
request_url = URL(request_url)
hostname = request_url.raw_host or ""
request_origin = URL()
with contextlib.suppress(ValueError):
request_origin = request_url.origin()

is_not_secure = request_url.scheme not in ("https", "wss")
if is_not_secure and self._treat_as_secure_origin:
request_origin = URL()
with contextlib.suppress(ValueError):
request_origin = request_url.origin()
is_not_secure = request_origin not in self._treat_as_secure_origin
is_not_secure = (
request_url.scheme not in ("https", "wss")
and request_origin not in self._treat_as_secure_origin
)

# Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
for cookie in sorted(self, key=lambda c: len(c["path"])):
Expand Down Expand Up @@ -330,7 +326,7 @@ def _is_path_match(req_path: str, cookie_path: str) -> bool:
return non_matching.startswith("/")

@classmethod
def _parse_date(cls, date_str: str) -> Optional[int]:
def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return None
Expand All @@ -346,6 +342,7 @@ def _parse_date(cls, date_str: str) -> Optional[int]:
year = 0

for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

token = token_match.group("token")

if not found_time:
Expand Down Expand Up @@ -390,7 +387,9 @@ def _parse_date(cls, date_str: str) -> Optional[int]:
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return None

return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
return datetime.datetime(
year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
)


class DummyCookieJar(AbstractCookieJar):
Expand All @@ -400,6 +399,9 @@ class DummyCookieJar(AbstractCookieJar):
"""

def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(loop=loop)

def __iter__(self) -> "Iterator[Morsel[str]]":
while False:
yield None
Expand Down

0 comments on commit 8b1e777

Please sign in to comment.