Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/update type annotations #225

Merged
merged 4 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions aiohttp_client_cache/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import inspect
import pickle
from abc import ABCMeta, abstractmethod
from collections import UserDict
from datetime import datetime
from logging import getLogger
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, Union
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Union

from aiohttp import ClientResponse
from aiohttp.typedefs import StrOrURL
Expand Down Expand Up @@ -38,11 +40,11 @@ def __init__(
self,
cache_name: str = 'aiohttp-cache',
expire_after: ExpirationTime = -1,
urls_expire_after: Optional[ExpirationPatterns] = None,
allowed_codes: Tuple[int, ...] = (200,),
allowed_methods: Tuple[str, ...] = ('GET', 'HEAD'),
urls_expire_after: ExpirationPatterns | None = None,
allowed_codes: tuple[int, ...] = (200,),
allowed_methods: tuple[str, ...] = ('GET', 'HEAD'),
include_headers: bool = False,
ignored_params: Optional[Iterable[str]] = None,
ignored_params: Iterable[str] | None = None,
autoclose: bool = False,
cache_control: bool = False,
filter_fn: _FilterFn = lambda r: True,
Expand Down Expand Up @@ -82,7 +84,7 @@ def __init__(
self.ignored_params = set(ignored_params or [])

async def is_cacheable(
self, response: Optional[AnyResponse], actions: Optional[CacheActions] = None
self, response: AnyResponse | None, actions: CacheActions | None = None
) -> bool:
"""Perform all checks needed to determine if the given response should be cached"""
if not response:
Expand Down Expand Up @@ -110,7 +112,7 @@ async def request(
expire_after: ExpirationTime = None,
refresh: bool = False,
**kwargs,
) -> Tuple[Optional[CachedResponse], CacheActions]:
) -> tuple[CachedResponse | None, CacheActions]:
"""Fetch a cached response based on request info

Args:
Expand Down Expand Up @@ -139,7 +141,7 @@ async def request(
response = None if actions.skip_read else await self.get_response(actions.key)
return response, actions

async def get_response(self, key: str) -> Optional[CachedResponse]:
async def get_response(self, key: str) -> CachedResponse | None:
"""Fetch a cached response based on a cache key"""
# Attempt to fetch the cached response
logger.debug(f'Attempting to get cached response for key: {key}')
Expand All @@ -164,16 +166,16 @@ async def get_response(self, key: str) -> Optional[CachedResponse]:
# Response will be a CachedResponse or None by this point
return response # type: ignore

async def _get_redirect_response(self, key: str) -> Optional[CachedResponse]:
async def _get_redirect_response(self, key: str) -> CachedResponse | None:
"""Get the response referenced by a redirect key, if available"""
redirect_key = await self.redirects.read(key)
return await self.responses.read(redirect_key) if redirect_key else None # type: ignore

async def save_response(
self,
response: ClientResponse,
cache_key: Optional[str] = None,
expires: Optional[datetime] = None,
cache_key: str | None = None,
expires: datetime | None = None,
):
"""Save a response to the cache

Expand Down Expand Up @@ -223,7 +225,7 @@ async def delete_expired_responses(self):

async for key in self.responses.keys():
response = await self.responses.read(key)
if response and response.is_expired or not self.filter_fn(response):
if response and response.is_expired or not self.filter_fn(response): # type: ignore[union-attr,arg-type]
keys_to_delete.add(key)

logger.debug(f'Deleting {len(keys_to_delete)} expired cache entries')
Expand Down Expand Up @@ -279,21 +281,21 @@ class BaseCache(metaclass=ABCMeta):

def __init__(
self,
secret_key: Union[Iterable, str, bytes, None] = None,
salt: Union[str, bytes] = b'aiohttp-client-cache',
secret_key: Iterable | str | bytes | None = None,
salt: str | bytes = b'aiohttp-client-cache',
serializer=None,
**kwargs,
):
super().__init__()
self._serializer = serializer or self._get_serializer(secret_key, salt)

def serialize(self, item: ResponseOrKey = None) -> Optional[bytes]:
def serialize(self, item: ResponseOrKey = None) -> bytes | None:
"""Serialize a URL or response into bytes"""
if isinstance(item, bytes):
return item
return self._serializer.dumps(item) if item else None

def deserialize(self, item: ResponseOrKey) -> Union[CachedResponse, str, None]:
def deserialize(self, item: ResponseOrKey) -> CachedResponse | str | None:
"""Deserialize a cached URL or response"""
if isinstance(item, (CachedResponse, str)):
return item
Expand Down Expand Up @@ -385,7 +387,7 @@ async def keys(self) -> AsyncIterable[str]: # type: ignore
for key in self.data.keys():
yield key

async def read(self, key: str) -> Union[CachedResponse, str, None]:
async def read(self, key: str) -> CachedResponse | str | None:
"""An additional step is needed here for response data. The original response object
is still in memory, and hasn't gone through a serialize/deserialize loop. So, the file-like
response body has already been read, and needs to be reset.
Expand Down
10 changes: 6 additions & 4 deletions aiohttp_client_cache/backends/dynamodb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from contextlib import asynccontextmanager
from logging import getLogger
from typing import Any, AsyncIterable, Dict, Optional
from typing import Any, AsyncIterable

import aioboto3
from aioboto3.session import ResourceCreatorContext
Expand Down Expand Up @@ -41,7 +43,7 @@ def __init__(
key_attr_name: str = 'k',
val_attr_name: str = 'v',
create_if_not_exists: bool = False,
context: Optional[ResourceCreatorContext] = None,
context: ResourceCreatorContext | None = None,
**kwargs: Any,
):
super().__init__(cache_name=cache_name, **kwargs)
Expand Down Expand Up @@ -127,10 +129,10 @@ async def _create_table(self, conn):

return table

def _doc(self, key) -> Dict:
def _doc(self, key) -> dict:
return {self.key_attr_name: f'{self.namespace}:{key}'}

async def _scan(self) -> AsyncIterable[Dict]:
async def _scan(self) -> AsyncIterable[dict]:
table = await self.get_table()
paginator = table.meta.client.get_paginator('scan')
iterator = paginator.paginate(
Expand Down
8 changes: 5 additions & 3 deletions aiohttp_client_cache/backends/filesystem.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from contextlib import contextmanager
from os import listdir, makedirs
from os.path import abspath, expanduser, isabs, isfile, join
from pathlib import Path
from pickle import PickleError
from shutil import rmtree
from tempfile import gettempdir
from typing import Any, AsyncIterable, Union
from typing import Any, AsyncIterable

import aiofiles
import aiofiles.os
Expand All @@ -32,7 +34,7 @@ class FileBackend(CacheBackend):

def __init__(
self,
cache_name: Union[Path, str] = 'http_cache',
cache_name: Path | str = 'http_cache',
use_temp: bool = False,
autoclose: bool = True,
**kwargs: Any,
Expand Down Expand Up @@ -110,7 +112,7 @@ async def paths(self):
yield self._join(key)


def _get_cache_dir(cache_dir: Union[Path, str], use_temp: bool) -> str:
def _get_cache_dir(cache_dir: Path | str, use_temp: bool) -> str:
# Save to a temp directory, if specified
if use_temp and not isabs(cache_dir):
cache_dir = join(gettempdir(), cache_dir, 'responses')
Expand Down
2 changes: 2 additions & 0 deletions aiohttp_client_cache/backends/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, AsyncIterable

from motor.motor_asyncio import AsyncIOMotorClient
Expand Down
8 changes: 5 additions & 3 deletions aiohttp_client_cache/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, AsyncIterable, Optional
from __future__ import annotations

from typing import Any, AsyncIterable

from redis.asyncio import Redis, from_url

Expand Down Expand Up @@ -46,7 +48,7 @@ def __init__(
namespace: str,
collection_name: str,
address: str = DEFAULT_ADDRESS,
connection: Optional[Redis] = None,
connection: Redis | None = None,
**kwargs: Any,
):
# Pop off BaseCache kwargs and use the rest as Redis connection kwargs
Expand All @@ -64,7 +66,7 @@ async def get_connection(self):

async def close(self):
if self._connection:
await self._connection.aclose()
await self._connection.aclose() # type: ignore[attr-defined]
self._connection = None

async def clear(self):
Expand Down
22 changes: 12 additions & 10 deletions aiohttp_client_cache/backends/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import sqlite3
from contextlib import asynccontextmanager
Expand All @@ -6,7 +8,7 @@
from os.path import abspath, basename, dirname, expanduser, isabs, join
from pathlib import Path
from tempfile import gettempdir
from typing import Any, AsyncIterable, AsyncIterator, Optional, Type, Union
from typing import Any, AsyncIterable, AsyncIterator

import aiosqlite

Expand Down Expand Up @@ -81,7 +83,7 @@ def __init__(
self.filename = _get_cache_filename(filename, use_temp)
self.table_name = table_name

self._connection: Optional[aiosqlite.Connection] = None
self._connection: aiosqlite.Connection | None = None
self._lock = asyncio.Lock()

@asynccontextmanager
Expand All @@ -97,8 +99,8 @@ async def get_connection(self, commit: bool = False) -> AsyncIterator[aiosqlite.
async def _init_db(self):
"""Initialize the database, if it hasn't already been"""
if self.fast_save:
await self._connection.execute('PRAGMA synchronous = 0;')
await self._connection.execute(
await self._connection.execute('PRAGMA synchronous = 0;') # type: ignore[union-attr]
await self._connection.execute( # type: ignore[union-attr]
f'CREATE TABLE IF NOT EXISTS `{self.table_name}` (key PRIMARY KEY, value)'
)
return self._connection
Expand Down Expand Up @@ -132,7 +134,7 @@ async def bulk_commit(self):
bulk_commit_var.set(True)
try:
yield
await self._connection.commit()
await self._connection.commit() # type: ignore[union-attr]
finally:
bulk_commit_var.set(False)

Expand Down Expand Up @@ -192,7 +194,7 @@ async def values(self) -> AsyncIterable[ResponseOrKey]:
async for row in cursor:
yield row[0]

async def write(self, key: str, item: Union[ResponseOrKey, sqlite3.Binary]):
async def write(self, key: str, item: ResponseOrKey | sqlite3.Binary):
async with self.get_connection(commit=True) as db:
await db.execute(
f'INSERT OR REPLACE INTO `{self.table_name}` (key,value) VALUES (?,?)',
Expand All @@ -213,22 +215,22 @@ async def values(self) -> AsyncIterable[ResponseOrKey]:
yield self.deserialize(row[0])

async def write(self, key, item):
await super().write(key, sqlite3.Binary(self.serialize(item)))
await super().write(key, sqlite3.Binary(self.serialize(item))) # type: ignore[arg-type]


def sqlite_template(
timeout: float = 5.0,
detect_types: int = 0,
isolation_level: Optional[str] = None,
isolation_level: str | None = None,
check_same_thread: bool = True,
factory: Optional[Type] = None,
factory: type | None = None,
cached_statements: int = 100,
uri: bool = False,
):
"""Template function to get an accurate function signature for :py:func:`sqlite3.connect`"""


def _get_cache_filename(filename: Union[Path, str], use_temp: bool) -> str:
def _get_cache_filename(filename: Path | str, use_temp: bool) -> str:
"""Get resolved path for database file"""
# Save to a temp directory, if specified
if use_temp and not isabs(filename):
Expand Down
16 changes: 9 additions & 7 deletions aiohttp_client_cache/cache_keys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Functions for creating keys used for cache requests"""
from __future__ import annotations

import hashlib
from collections.abc import Mapping
from typing import Any, Dict, Iterable, Optional, Sequence, Union
from typing import Any, Iterable, Sequence, Union

from aiohttp.typedefs import StrOrURL
from multidict import MultiDict
Expand All @@ -14,12 +16,12 @@
def create_key(
method: str,
url: StrOrURL,
params: Optional[RequestParams] = None,
data: Optional[Dict] = None,
json: Optional[Dict] = None,
headers: Optional[Dict] = None,
params: RequestParams | None = None,
data: dict | None = None,
json: dict | None = None,
headers: dict | None = None,
include_headers: bool = False,
ignored_params: Optional[Iterable[str]] = None,
ignored_params: Iterable[str] | None = None,
**kwargs,
) -> str:
"""Create a unique cache key based on request details"""
Expand Down Expand Up @@ -50,7 +52,7 @@ def filter_ignored_params(data, ignored_params: Iterable[str]):
return MultiDict(((k, v) for k, v in data.items() if k not in ignored_params))


def normalize_url_params(url: StrOrURL, params: Optional[RequestParams] = None) -> URL:
def normalize_url_params(url: StrOrURL, params: RequestParams | None = None) -> URL:
"""Normalize any combination of request parameter formats that aiohttp accepts"""
if isinstance(url, str):
url = URL(url)
Expand Down
25 changes: 18 additions & 7 deletions aiohttp_client_cache/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from datetime import datetime
from http.cookies import SimpleCookie
from logging import getLogger
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from unittest.mock import Mock
from functools import singledispatch

import attr
from aiohttp import ClientResponse, ClientResponseError, hdrs, multipart
Expand All @@ -33,7 +34,7 @@
}

# Default attributes to add to ClientResponse objects
RESPONSE_DEFAULTS = {
CACHED_RESPONSE_DEFAULTS = {
'created_at': None,
'expires': None,
'from_cache': False,
Expand Down Expand Up @@ -69,7 +70,7 @@
expires: datetime | None = attr.ib(default=None)
raw_headers: RawHeaders = attr.ib(factory=tuple)
real_url: StrOrURL = attr.ib(default=None)
history: Iterable = attr.ib(factory=tuple)
history: tuple = attr.ib(factory=tuple)
last_used: datetime = attr.ib(factory=datetime.utcnow)

@classmethod
Expand Down Expand Up @@ -268,13 +269,23 @@
AnyResponse = Union[ClientResponse, CachedResponse]


def set_response_defaults(response: AnyResponse) -> AnyResponse:
@singledispatch
def set_response_defaults(response):
raise NotImplementedError

Check warning on line 274 in aiohttp_client_cache/response.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_client_cache/response.py#L274

Added line #L274 was not covered by tests


@set_response_defaults.register
def _(response: CachedResponse) -> CachedResponse:
return response


@set_response_defaults.register
def _(response: ClientResponse) -> ClientResponse:
"""Set some default CachedResponse values on a ClientResponse object, so they can be
expected to always be present
"""
if not isinstance(response, CachedResponse):
for k, v in RESPONSE_DEFAULTS.items():
setattr(response, k, v)
for k, v in CACHED_RESPONSE_DEFAULTS.items():
setattr(response, k, v)
return response


Expand Down
Loading
Loading