Skip to content

Commit

Permalink
Fix new mypy errors and improve typing with Self
Browse files Browse the repository at this point in the history
  • Loading branch information
bayandin committed Nov 22, 2024
1 parent 825d662 commit 3697856
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 70 deletions.
6 changes: 4 additions & 2 deletions scripts/force_layer_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,11 @@ async def main_impl(args, report_out, client: Client):
tenant_ids = await client.get_tenant_ids()
get_timeline_id_coros = [client.get_timeline_ids(tenant_id) for tenant_id in tenant_ids]
gathered = await asyncio.gather(*get_timeline_id_coros, return_exceptions=True)
assert len(tenant_ids) == len(gathered)
tenant_and_timline_ids = []
for tid, tlids in zip(tenant_ids, gathered, strict=False):
for tid, tlids in zip(tenant_ids, gathered, strict=True):
# TODO: add error handling if tlids isinstance(Exception)
assert isinstance(tlids, list)

for tlid in tlids:
tenant_and_timline_ids.append((tid, tlid))
elif len(comps) == 1:
Expand Down
6 changes: 1 addition & 5 deletions test_runner/fixtures/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,6 @@ def from_json(cls, d: dict[str, Any]) -> TenantTimelineId:
)


# Workaround for compat with python 3.9, which does not have `typing.Self`
TTenantShardId = TypeVar("TTenantShardId", bound="TenantShardId")


class TenantShardId:
def __init__(self, tenant_id: TenantId, shard_number: int, shard_count: int):
self.tenant_id = tenant_id
Expand All @@ -202,7 +198,7 @@ def __init__(self, tenant_id: TenantId, shard_number: int, shard_count: int):
assert self.shard_number < self.shard_count or self.shard_count == 0

@classmethod
def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId:
def parse(cls: type[TenantShardId], input: str) -> TenantShardId:
if len(input) == 32:
return cls(
tenant_id=TenantId(input),
Expand Down
2 changes: 1 addition & 1 deletion test_runner/fixtures/compute_reconfigure.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def handler(request: Request) -> Response:
# This causes the endpoint to query storage controller for its location, which
# is redundant since we already have it here, but this avoids extending the
# neon_local CLI to take full lists of locations
reconfigure_threads.submit(lambda workload=workload: workload.reconfigure()) # type: ignore[no-any-return]
reconfigure_threads.submit(lambda workload=workload: workload.reconfigure()) # type: ignore[misc]

return Response(status=200)

Expand Down
3 changes: 0 additions & 3 deletions test_runner/fixtures/neon_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@
if TYPE_CHECKING:
from typing import (
Any,
TypeVar,
cast,
)

T = TypeVar("T")


# Used to be an ABC. abc.ABC removed due to linter without name change.
class AbstractNeonCli:
Expand Down
67 changes: 31 additions & 36 deletions test_runner/fixtures/neon_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@

if TYPE_CHECKING:
from collections.abc import Callable
from typing import (
Any,
TypeVar,
)
from typing import Any, Self, TypeVar

from fixtures.paths import SnapshotDirLocked

Expand Down Expand Up @@ -838,7 +835,7 @@ def cleanup_remote_storage(self):
if isinstance(x, S3Storage):
x.do_cleanup()

def __enter__(self) -> NeonEnvBuilder:
def __enter__(self) -> Self:
return self

def __exit__(
Expand Down Expand Up @@ -1148,21 +1145,19 @@ def start(self, timeout_in_seconds: int | None = None):
with concurrent.futures.ThreadPoolExecutor(
max_workers=2 + len(self.pageservers) + len(self.safekeepers)
) as executor:
futs.append(
executor.submit(lambda: self.broker.start() or None)
) # The `or None` is for the linter
futs.append(executor.submit(lambda: self.broker.start()))

for pageserver in self.pageservers:
futs.append(
executor.submit(
lambda ps=pageserver: ps.start(timeout_in_seconds=timeout_in_seconds)
lambda ps=pageserver: ps.start(timeout_in_seconds=timeout_in_seconds) # type: ignore[misc]
)
)

for safekeeper in self.safekeepers:
futs.append(
executor.submit(
lambda sk=safekeeper: sk.start(timeout_in_seconds=timeout_in_seconds)
lambda sk=safekeeper: sk.start(timeout_in_seconds=timeout_in_seconds) # type: ignore[misc]
)
)

Expand Down Expand Up @@ -1602,13 +1597,13 @@ def start(
timeout_in_seconds: int | None = None,
instance_id: int | None = None,
base_port: int | None = None,
):
) -> Self:
assert not self.running
self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port)
self.running = True
return self

def stop(self, immediate: bool = False) -> NeonStorageController:
def stop(self, immediate: bool = False) -> Self:
if self.running:
self.env.neon_cli.storage_controller_stop(immediate)
self.running = False
Expand Down Expand Up @@ -2282,7 +2277,7 @@ def set_preferred_azs(self, preferred_azs: dict[TenantShardId, str]) -> list[Ten
response.raise_for_status()
return [TenantShardId.parse(tid) for tid in response.json()["updated"]]

def __enter__(self) -> NeonStorageController:
def __enter__(self) -> Self:
return self

def __exit__(
Expand All @@ -2304,7 +2299,7 @@ def start(
timeout_in_seconds: int | None = None,
instance_id: int | None = None,
base_port: int | None = None,
):
) -> Self:
assert instance_id is not None and base_port is not None

self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port)
Expand All @@ -2324,7 +2319,7 @@ def stop_instance(
self.running = any(meta["running"] for meta in self.instances.values())
return self

def stop(self, immediate: bool = False) -> NeonStorageController:
def stop(self, immediate: bool = False) -> Self:
for iid, details in self.instances.items():
if details["running"]:
self.env.neon_cli.storage_controller_stop(immediate, iid)
Expand Down Expand Up @@ -2446,7 +2441,7 @@ def start(
self,
extra_env_vars: dict[str, str] | None = None,
timeout_in_seconds: int | None = None,
) -> NeonPageserver:
) -> Self:
"""
Start the page server.
`overrides` allows to add some config to this pageserver start.
Expand Down Expand Up @@ -2481,7 +2476,7 @@ def start(

return self

def stop(self, immediate: bool = False) -> NeonPageserver:
def stop(self, immediate: bool = False) -> Self:
"""
Stop the page server.
Returns self.
Expand Down Expand Up @@ -2529,7 +2524,7 @@ def complete():

wait_until(20, 0.5, complete)

def __enter__(self) -> NeonPageserver:
def __enter__(self) -> Self:
return self

def __exit__(
Expand Down Expand Up @@ -2957,7 +2952,7 @@ def get_subdir_size(self, subdir: Path) -> int:
"""Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(self.pgdatadir / subdir)

def __enter__(self) -> VanillaPostgres:
def __enter__(self) -> Self:
return self

def __exit__(
Expand Down Expand Up @@ -3006,7 +3001,7 @@ def get_subdir_size(self, subdir) -> int:
# See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE
raise Exception("cannot get size of a Postgres instance")

def __enter__(self) -> RemotePostgres:
def __enter__(self) -> Self:
return self

def __exit__(
Expand Down Expand Up @@ -3220,7 +3215,7 @@ def __init__(
self.http_timeout_seconds = 15
self._popen: subprocess.Popen[bytes] | None = None

def start(self) -> NeonProxy:
def start(self) -> Self:
assert self._popen is None

# generate key of it doesn't exist
Expand Down Expand Up @@ -3348,7 +3343,7 @@ async def find_auth_link(link_auth_uri, proc):
log.info(f"SUCCESS, found auth url: {line}")
return line

def __enter__(self) -> NeonProxy:
def __enter__(self) -> Self:
return self

def __exit__(
Expand Down Expand Up @@ -3438,7 +3433,7 @@ def __init__(
self.http_timeout_seconds = 15
self._popen: subprocess.Popen[bytes] | None = None

def start(self) -> NeonAuthBroker:
def start(self) -> Self:
assert self._popen is None

# generate key of it doesn't exist
Expand Down Expand Up @@ -3507,7 +3502,7 @@ def get_metrics(self) -> str:
request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics")
return request_result.text

def __enter__(self) -> NeonAuthBroker:
def __enter__(self) -> Self:
return self

def __exit__(
Expand Down Expand Up @@ -3704,7 +3699,7 @@ def create(
config_lines: list[str] | None = None,
pageserver_id: int | None = None,
allow_multiple: bool = False,
) -> Endpoint:
) -> Self:
"""
Create a new Postgres endpoint.
Returns self.
Expand Down Expand Up @@ -3750,7 +3745,7 @@ def start(
safekeepers: list[int] | None = None,
allow_multiple: bool = False,
basebackup_request_tries: int | None = None,
) -> Endpoint:
) -> Self:
"""
Start the Postgres instance.
Returns self.
Expand Down Expand Up @@ -3797,7 +3792,7 @@ def config_file_path(self) -> Path:
"""Path to the postgresql.conf in the endpoint directory (not the one in pgdata)"""
return self.endpoint_path() / "postgresql.conf"

def config(self, lines: list[str]) -> Endpoint:
def config(self, lines: list[str]) -> Self:
"""
Add lines to postgresql.conf.
Lines should be an array of valid postgresql.conf rows.
Expand Down Expand Up @@ -3873,7 +3868,7 @@ def stop(
self,
mode: str = "fast",
sks_wait_walreceiver_gone: tuple[list[Safekeeper], TimelineId] | None = None,
) -> Endpoint:
) -> Self:
"""
Stop the Postgres instance if it's running.
Expand Down Expand Up @@ -3907,7 +3902,7 @@ def stop(

return self

def stop_and_destroy(self, mode: str = "immediate") -> Endpoint:
def stop_and_destroy(self, mode: str = "immediate") -> Self:
"""
Stop the Postgres instance, then destroy the endpoint.
Returns self.
Expand All @@ -3934,7 +3929,7 @@ def create_start(
pageserver_id: int | None = None,
allow_multiple: bool = False,
basebackup_request_tries: int | None = None,
) -> Endpoint:
) -> Self:
"""
Create an endpoint, apply config, and start Postgres.
Returns self.
Expand All @@ -3957,7 +3952,7 @@ def create_start(

return self

def __enter__(self) -> Endpoint:
def __enter__(self) -> Self:
return self

def __exit__(
Expand Down Expand Up @@ -4058,7 +4053,7 @@ def create(
pageserver_id=pageserver_id,
)

def stop_all(self, fail_on_error=True) -> EndpointFactory:
def stop_all(self, fail_on_error=True) -> Self:
exception = None
for ep in self.endpoints:
try:
Expand Down Expand Up @@ -4154,7 +4149,7 @@ def __init__(

def start(
self, extra_opts: list[str] | None = None, timeout_in_seconds: int | None = None
) -> Safekeeper:
) -> Self:
if extra_opts is None:
# Apply either the extra_opts passed in, or the ones from our constructor: we do not merge the two.
extra_opts = self.extra_opts
Expand Down Expand Up @@ -4189,7 +4184,7 @@ def start(
break # success
return self

def stop(self, immediate: bool = False) -> Safekeeper:
def stop(self, immediate: bool = False) -> Self:
self.env.neon_cli.safekeeper_stop(self.id, immediate)
self.running = False
return self
Expand Down Expand Up @@ -4367,13 +4362,13 @@ def __init__(self, env: NeonEnv):
def start(
self,
timeout_in_seconds: int | None = None,
):
) -> Self:
assert not self.running
self.env.neon_cli.storage_broker_start(timeout_in_seconds)
self.running = True
return self

def stop(self):
def stop(self) -> Self:
if self.running:
self.env.neon_cli.storage_broker_stop()
self.running = False
Expand Down
1 change: 1 addition & 0 deletions test_runner/fixtures/parametrize.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def pytest_generate_tests(metafunc: Metafunc):

metafunc.parametrize("build_type", build_types)

pg_versions: list[PgVersion]
if (v := os.getenv("DEFAULT_PG_VERSION")) is None:
pg_versions = [version for version in PgVersion if version != PgVersion.NOT_SET]
else:
Expand Down
2 changes: 1 addition & 1 deletion test_runner/fixtures/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self._endpoint: Endpoint | None = None
self._endpoint_opts = endpoint_opts or {}

def reconfigure(self):
def reconfigure(self) -> None:
"""
Request the endpoint to reconfigure based on location reported by storage controller
"""
Expand Down
7 changes: 3 additions & 4 deletions test_runner/regress/test_compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

if TYPE_CHECKING:
from types import TracebackType
from typing import TypedDict
from typing import Self, TypedDict

from fixtures.neon_fixtures import NeonEnv
from fixtures.pg_version import PgVersion
Expand Down Expand Up @@ -185,7 +185,7 @@ def start(self) -> None:
def stop(self) -> None:
raise NotImplementedError()

def __enter__(self) -> SqlExporterRunner:
def __enter__(self) -> Self:
self.start()

return self
Expand Down Expand Up @@ -242,8 +242,7 @@ def __init__(
self.with_volume_mapping(str(config_file), container_config_file, "z")
self.with_volume_mapping(str(collector_file), container_collector_file, "z")

@override
def start(self) -> SqlExporterContainer:
def start(self) -> Self:
super().start()

log.info("Waiting for sql_exporter to be ready")
Expand Down
4 changes: 2 additions & 2 deletions test_runner/regress/test_ddl_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from werkzeug.wrappers.response import Response

if TYPE_CHECKING:
from typing import Any
from typing import Any, Self


def handle_db(dbs, roles, operation):
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(self, httpserver: HTTPServer, vanilla_pg: VanillaPostgres, host: st
lambda request: ddl_forward_handler(request, self.dbs, self.roles, self)
)

def __enter__(self):
def __enter__(self) -> Self:
self.pg.start()
return self

Expand Down
Loading

0 comments on commit 3697856

Please sign in to comment.