From b78cf1c0cec941cfcbe69e819aba69daa809b310 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Tue, 18 Apr 2023 23:06:34 +0300 Subject: [PATCH] Improve typing Fix #1057 --- .pre-commit-config.yaml | 1 + changelog/1057.trivial | 1 + pyproject.toml | 16 +-- src/xdist/dsession.py | 182 +++++++++++++++++++++---------- src/xdist/looponfail.py | 61 ++++++----- src/xdist/newhooks.py | 53 +++++++-- src/xdist/plugin.py | 54 +++++---- src/xdist/remote.py | 138 +++++++++++++++-------- src/xdist/report.py | 10 +- src/xdist/scheduler/each.py | 47 +++++--- src/xdist/scheduler/load.py | 54 +++++---- src/xdist/scheduler/loadfile.py | 8 +- src/xdist/scheduler/loadgroup.py | 8 +- src/xdist/scheduler/loadscope.py | 57 ++++++---- src/xdist/scheduler/worksteal.py | 59 ++++++---- src/xdist/workermanage.py | 122 ++++++++++++++------- testing/acceptance_test.py | 77 ++++++------- testing/conftest.py | 18 +-- testing/test_dsession.py | 177 +++++++++++++++++------------- testing/test_looponfail.py | 6 +- testing/test_plugin.py | 10 +- testing/test_remote.py | 72 ++++++++---- testing/test_workermanage.py | 102 +++++++++++------ testing/util.py | 2 +- 24 files changed, 858 insertions(+), 477 deletions(-) create mode 100644 changelog/1057.trivial diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 280c65c8..74922831 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,3 +32,4 @@ repos: - pytest>=7.0.0 - execnet>=2.1.0 - types-psutil + - setproctitle diff --git a/changelog/1057.trivial b/changelog/1057.trivial new file mode 100644 index 00000000..df1b3c04 --- /dev/null +++ b/changelog/1057.trivial @@ -0,0 +1 @@ +The internals of pytest-xdist are now fully typed. The typing is not exposed yet. diff --git a/pyproject.toml b/pyproject.toml index 71d429db..d65a24dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,19 +137,11 @@ lines-after-imports = 2 [tool.mypy] mypy_path = ["src"] files = ["src", "testing"] -# TODO: Enable this & fix errors. -# check_untyped_defs = true -disallow_any_generics = true -ignore_missing_imports = true -no_implicit_optional = true -show_error_codes = true -strict_equality = true -warn_redundant_casts = true -warn_return_any = true +strict = true warn_unreachable = true -warn_unused_configs = true -# TODO: Enable this & fix errors. -# no_implicit_reexport = true +[[tool.mypy.overrides]] +module = ["xdist._version"] +ignore_missing_imports = true [tool.towncrier] diff --git a/src/xdist/dsession.py b/src/xdist/dsession.py index 9671b2fd..62079a28 100644 --- a/src/xdist/dsession.py +++ b/src/xdist/dsession.py @@ -5,11 +5,15 @@ from queue import Empty from queue import Queue import sys +from typing import Any from typing import Sequence +import warnings +import execnet import pytest from xdist.remote import Producer +from xdist.remote import WorkerInfo from xdist.scheduler import EachScheduling from xdist.scheduler import LoadFileScheduling from xdist.scheduler import LoadGroupScheduling @@ -18,6 +22,7 @@ from xdist.scheduler import Scheduling from xdist.scheduler import WorkStealingScheduling from xdist.workermanage import NodeManager +from xdist.workermanage import WorkerController class Interrupted(KeyboardInterrupt): @@ -38,29 +43,31 @@ class DSession: it will wait for instructions. """ - def __init__(self, config): + shouldstop: bool | str + + def __init__(self, config: pytest.Config) -> None: self.config = config self.log = Producer("dsession", enabled=config.option.debug) - self.nodemanager = None - self.sched = None + self.nodemanager: NodeManager | None = None + self.sched: Scheduling | None = None self.shuttingdown = False self.countfailures = 0 - self.maxfail = config.getvalue("maxfail") - self.queue = Queue() - self._session = None - self._failed_collection_errors = {} - self._active_nodes = set() + self.maxfail: int = config.getvalue("maxfail") + self.queue: Queue[tuple[str, dict[str, Any]]] = Queue() + self._session: pytest.Session | None = None + self._failed_collection_errors: dict[object, bool] = {} + self._active_nodes: set[WorkerController] = set() self._failed_nodes_count = 0 self._max_worker_restart = get_default_max_worker_restart(self.config) # summary message to print at the end of the session - self._summary_report = None + self._summary_report: str | None = None self.terminal = config.pluginmanager.getplugin("terminalreporter") if self.terminal: self.trdist = TerminalDistReporter(config) config.pluginmanager.register(self.trdist, "terminaldistreporter") @property - def session_finished(self): + def session_finished(self) -> bool: """Return True if the distributed session has finished. This means all nodes have executed all test items. This is @@ -68,12 +75,12 @@ def session_finished(self): """ return bool(self.shuttingdown and not self._active_nodes) - def report_line(self, line): + def report_line(self, line: str) -> None: if self.terminal and self.config.option.verbose >= 0: self.terminal.write_line(line) @pytest.hookimpl(trylast=True) - def pytest_sessionstart(self, session): + def pytest_sessionstart(self, session: pytest.Session) -> None: """Creates and starts the nodes. The nodes are setup to put their events onto self.queue. As @@ -85,7 +92,7 @@ def pytest_sessionstart(self, session): self._session = session @pytest.hookimpl - def pytest_sessionfinish(self, session): + def pytest_sessionfinish(self) -> None: """Shutdown all nodes.""" nm = getattr(self, "nodemanager", None) # if not fully initialized if nm is not None: @@ -93,12 +100,16 @@ def pytest_sessionfinish(self, session): self._session = None @pytest.hookimpl - def pytest_collection(self): + def pytest_collection(self) -> bool: # prohibit collection of test items in controller process return True @pytest.hookimpl(trylast=True) - def pytest_xdist_make_scheduler(self, config, log) -> Scheduling | None: + def pytest_xdist_make_scheduler( + self, + config: pytest.Config, + log: Producer, + ) -> Scheduling | None: dist = config.getvalue("dist") if dist == "each": return EachScheduling(config, log) @@ -115,7 +126,7 @@ def pytest_xdist_make_scheduler(self, config, log) -> Scheduling | None: return None @pytest.hookimpl - def pytest_runtestloop(self): + def pytest_runtestloop(self) -> bool: self.sched = self.config.hook.pytest_xdist_make_scheduler( config=self.config, log=self.log ) @@ -132,7 +143,7 @@ def pytest_runtestloop(self): raise pending_exception return True - def loop_once(self): + def loop_once(self) -> None: """Process one callback from one of the workers.""" while 1: if not self._active_nodes: @@ -150,6 +161,7 @@ def loop_once(self): call = getattr(self, method) self.log("calling method", method, kwargs) call(**kwargs) + assert self.sched is not None if self.sched.tests_finished: self.triggershutdown() @@ -157,7 +169,11 @@ def loop_once(self): # callbacks for processing events from workers # - def worker_workerready(self, node, workerinfo): + def worker_workerready( + self, + node: WorkerController, + workerinfo: WorkerInfo, + ) -> None: """Emitted when a node first starts up. This adds the node to the scheduler, nodes continue with @@ -171,9 +187,10 @@ def worker_workerready(self, node, workerinfo): if self.shuttingdown: node.shutdown() else: + assert self.sched is not None self.sched.add_node(node) - def worker_workerfinished(self, node): + def worker_workerfinished(self, node: WorkerController) -> None: """Emitted when node executes its pytest_sessionfinish hook. Removes the node from the scheduler. @@ -194,12 +211,15 @@ def worker_workerfinished(self, node): self.shouldstop = shouldx break else: + assert self.sched is not None if node in self.sched.nodes: crashitem = self.sched.remove_node(node) assert not crashitem, (crashitem, node) self._active_nodes.remove(node) - def worker_internal_error(self, node, formatted_error): + def worker_internal_error( + self, node: WorkerController, formatted_error: str + ) -> None: """ pytest_internalerror() was called on the worker. @@ -215,9 +235,10 @@ def worker_internal_error(self, node, formatted_error): excrepr = excinfo.getrepr() self.config.hook.pytest_internalerror(excrepr=excrepr, excinfo=excinfo) - def worker_errordown(self, node, error): + def worker_errordown(self, node: WorkerController, error: object | None) -> None: """Emitted by the WorkerController when a node dies.""" self.config.hook.pytest_testnodedown(node=node, error=error) + assert self.sched is not None try: crashitem = self.sched.remove_node(node) except KeyError: @@ -235,7 +256,7 @@ def worker_errordown(self, node, error): if self._max_worker_restart == 0: msg = f"worker {node.gateway.id} crashed and worker restarting disabled" else: - msg = "maximum crashed workers reached: %d" % self._max_worker_restart + msg = f"maximum crashed workers reached: {self._max_worker_restart}" self._summary_report = msg self.report_line("\n" + msg) self.triggershutdown() @@ -246,11 +267,13 @@ def worker_errordown(self, node, error): self._active_nodes.remove(node) @pytest.hookimpl - def pytest_terminal_summary(self, terminalreporter): + def pytest_terminal_summary(self, terminalreporter: Any) -> None: if self.config.option.verbose >= 0 and self._summary_report: terminalreporter.write_sep("=", f"xdist: {self._summary_report}") - def worker_collectionfinish(self, node, ids): + def worker_collectionfinish( + self, node: WorkerController, ids: Sequence[str] + ) -> None: """Worker has finished test collection. This adds the collection for this node to the scheduler. If @@ -264,7 +287,9 @@ def worker_collectionfinish(self, node, ids): self.config.hook.pytest_xdist_node_collection_finished(node=node, ids=ids) # tell session which items were effectively collected otherwise # the controller node will finish the session with EXIT_NOTESTSCOLLECTED + assert self._session is not None self._session.testscollected = len(ids) + assert self.sched is not None self.sched.add_node_collection(node, ids) if self.terminal: self.trdist.setstatus( @@ -280,29 +305,44 @@ def worker_collectionfinish(self, node, ids): ) self.sched.schedule() - def worker_logstart(self, node, nodeid, location): + def worker_logstart( + self, + node: WorkerController, + nodeid: str, + location: tuple[str, int | None, str], + ) -> None: """Emitted when a node calls the pytest_runtest_logstart hook.""" self.config.hook.pytest_runtest_logstart(nodeid=nodeid, location=location) - def worker_logfinish(self, node, nodeid, location): + def worker_logfinish( + self, + node: WorkerController, + nodeid: str, + location: tuple[str, int | None, str], + ) -> None: """Emitted when a node calls the pytest_runtest_logfinish hook.""" self.config.hook.pytest_runtest_logfinish(nodeid=nodeid, location=location) - def worker_testreport(self, node, rep): + def worker_testreport(self, node: WorkerController, rep: pytest.TestReport) -> None: """Emitted when a node calls the pytest_runtest_logreport hook.""" - rep.node = node + rep.node = node # type: ignore[attr-defined] self.config.hook.pytest_runtest_logreport(report=rep) self._handlefailures(rep) - def worker_runtest_protocol_complete(self, node, item_index, duration): + def worker_runtest_protocol_complete( + self, node: WorkerController, item_index: int, duration: float + ) -> None: """ Emitted when a node fires the 'runtest_protocol_complete' event, signalling that a test has completed the runtestprotocol and should be removed from the pending list in the scheduler. """ + assert self.sched is not None self.sched.mark_test_complete(node, item_index, duration) - def worker_unscheduled(self, node, indices): + def worker_unscheduled( + self, node: WorkerController, indices: Sequence[int] + ) -> None: """ Emitted when a node fires the 'unscheduled' event, signalling that some tests have been removed from the worker's queue and should be @@ -311,9 +351,14 @@ def worker_unscheduled(self, node, indices): This should happen only in response to 'steal' command, so schedulers not using 'steal' command don't have to implement it. """ + assert self.sched is not None self.sched.remove_pending_tests_from_node(node, indices) - def worker_collectreport(self, node, rep): + def worker_collectreport( + self, + node: WorkerController, + rep: pytest.CollectReport | pytest.TestReport, + ) -> None: """Emitted when a node calls the pytest_collectreport hook. Because we only need the report when there's a failure/skip, as optimization @@ -322,14 +367,20 @@ def worker_collectreport(self, node, rep): assert not rep.passed self._failed_worker_collectreport(node, rep) - def worker_warning_recorded(self, warning_message, when, nodeid, location): + def worker_warning_recorded( + self, + warning_message: warnings.WarningMessage, + when: str, + nodeid: str, + location: tuple[str, int, str] | None, + ) -> None: """Emitted when a node calls the pytest_warning_recorded hook.""" kwargs = dict( warning_message=warning_message, when=when, nodeid=nodeid, location=location ) self.config.hook.pytest_warning_recorded.call_historic(kwargs=kwargs) - def _clone_node(self, node): + def _clone_node(self, node: WorkerController) -> WorkerController: """Return new node based on an existing one. This is normally for when a node dies, this will copy the spec @@ -339,12 +390,17 @@ def _clone_node(self, node): """ spec = node.gateway.spec spec.id = None + assert self.nodemanager is not None self.nodemanager.group.allocate_id(spec) - node = self.nodemanager.setup_node(spec, self.queue.put) - self._active_nodes.add(node) - return node - - def _failed_worker_collectreport(self, node, rep): + clone = self.nodemanager.setup_node(spec, self.queue.put) + self._active_nodes.add(clone) + return clone + + def _failed_worker_collectreport( + self, + node: WorkerController, + rep: pytest.CollectReport | pytest.TestReport, + ) -> None: # Check we haven't already seen this report (from # another worker). if rep.longrepr not in self._failed_collection_errors: @@ -352,7 +408,10 @@ def _failed_worker_collectreport(self, node, rep): self.config.hook.pytest_collectreport(report=rep) self._handlefailures(rep) - def _handlefailures(self, rep): + def _handlefailures( + self, + rep: pytest.CollectReport | pytest.TestReport, + ) -> None: if rep.failed: self.countfailures += 1 if ( @@ -362,22 +421,28 @@ def _handlefailures(self, rep): ): self.shouldstop = f"stopping after {self.countfailures} failures" - def triggershutdown(self): + def triggershutdown(self) -> None: if not self.shuttingdown: self.log("triggering shutdown") self.shuttingdown = True + assert self.sched is not None for node in self.sched.nodes: node.shutdown() - def handle_crashitem(self, nodeid, worker): + def handle_crashitem(self, nodeid: str, worker: WorkerController) -> None: # XXX get more reporting info by recording pytest_runtest_logstart? # XXX count no of failures and retry N times fspath = nodeid.split("::")[0] msg = f"worker {worker.gateway.id!r} crashed while running {nodeid!r}" rep = pytest.TestReport( - nodeid, (fspath, None, fspath), (), "failed", msg, "???" + nodeid=nodeid, + location=(fspath, None, fspath), + keywords={}, + outcome="failed", + longrepr=msg, + when="???", # type: ignore[arg-type] ) - rep.node = worker + rep.node = worker # type: ignore[attr-defined] self.config.hook.pytest_handlecrashitem( crashitem=nodeid, @@ -404,10 +469,10 @@ class WorkerStatus(Enum): class TerminalDistReporter: - def __init__(self, config) -> None: + def __init__(self, config: pytest.Config) -> None: self.config = config self.tr = config.pluginmanager.getplugin("terminalreporter") - self._status: dict[str, tuple[WorkerStatus, int]] = {} + self._status: dict[object, tuple[WorkerStatus, int]] = {} self._lastlen = 0 self._isatty = getattr(self.tr, "isatty", self.tr.hasmarkup) @@ -419,7 +484,12 @@ def ensure_show_status(self) -> None: self.write_line(self.getstatus()) def setstatus( - self, spec, status: WorkerStatus, *, tests_collected: int, show: bool = True + self, + spec: execnet.XSpec, + status: WorkerStatus, + *, + tests_collected: int, + show: bool = True, ) -> None: self._status[spec.id] = (status, tests_collected) if show and self._isatty: @@ -433,7 +503,7 @@ def getstatus(self) -> str: return "bringing up nodes..." - def rewrite(self, line, newline=False): + def rewrite(self, line: str, newline: bool = False) -> None: pline = line + " " * max(self._lastlen - len(line), 0) if newline: self._lastlen = 0 @@ -443,7 +513,7 @@ def rewrite(self, line, newline=False): self.tr.rewrite(pline, bold=True) @pytest.hookimpl - def pytest_xdist_setupnodes(self, specs) -> None: + def pytest_xdist_setupnodes(self, specs: Sequence[execnet.XSpec]) -> None: self._specs = specs for spec in specs: self.setstatus(spec, WorkerStatus.Created, tests_collected=0, show=False) @@ -451,7 +521,7 @@ def pytest_xdist_setupnodes(self, specs) -> None: self.ensure_show_status() @pytest.hookimpl - def pytest_xdist_newgateway(self, gateway) -> None: + def pytest_xdist_newgateway(self, gateway: execnet.Gateway) -> None: if self.config.option.verbose > 0: rinfo = gateway._rinfo() different_interpreter = rinfo.executable != sys.executable @@ -464,7 +534,7 @@ def pytest_xdist_newgateway(self, gateway) -> None: self.setstatus(gateway.spec, WorkerStatus.Initialized, tests_collected=0) @pytest.hookimpl - def pytest_testnodeready(self, node) -> None: + def pytest_testnodeready(self, node: WorkerController) -> None: if self.config.option.verbose > 0: d = node.workerinfo different_interpreter = d.get("executable") != sys.executable @@ -476,23 +546,25 @@ def pytest_testnodeready(self, node) -> None: ) @pytest.hookimpl - def pytest_testnodedown(self, node, error) -> None: + def pytest_testnodedown(self, node: WorkerController, error: object) -> None: if not error: return self.write_line(f"[{node.gateway.id}] node down: {error}") -def get_default_max_worker_restart(config): +def get_default_max_worker_restart(config: pytest.Config) -> int | None: """Gets the default value of --max-worker-restart option if it is not provided. Use a reasonable default to avoid workers from restarting endlessly due to crashing collections (#226). """ - result = config.option.maxworkerrestart - if result is not None: - result = int(result) + result_str: str | None = config.option.maxworkerrestart + if result_str is not None: + result = int(result_str) elif config.option.numprocesses: # if --max-worker-restart was not provided, use a reasonable default (#226) result = config.option.numprocesses * 4 + else: + result = None return result diff --git a/src/xdist/looponfail.py b/src/xdist/looponfail.py index 8c2a60ab..7e30e4cf 100644 --- a/src/xdist/looponfail.py +++ b/src/xdist/looponfail.py @@ -13,6 +13,7 @@ from pathlib import Path import sys import time +from typing import Any from typing import Sequence from _pytest._io import TerminalWriter @@ -23,7 +24,7 @@ @pytest.hookimpl -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: group = parser.getgroup("xdist", "distributed and subprocess testing") group._addoption( "-f", @@ -37,13 +38,14 @@ def pytest_addoption(parser): @pytest.hookimpl -def pytest_cmdline_main(config): +def pytest_cmdline_main(config: pytest.Config) -> int | None: if config.getoption("looponfail"): usepdb = config.getoption("usepdb", False) # a core option if usepdb: raise pytest.UsageError("--pdb is incompatible with --looponfail.") looponfail_main(config) return 2 # looponfail only can get stop with ctrl-C anyway + return None def looponfail_main(config: pytest.Config) -> None: @@ -68,19 +70,21 @@ def looponfail_main(config: pytest.Config) -> None: class RemoteControl: - def __init__(self, config): + gateway: execnet.Gateway + + def __init__(self, config: pytest.Config) -> None: self.config = config - self.failures = [] + self.failures: list[str] = [] - def trace(self, *args): + def trace(self, *args: object) -> None: if self.config.option.debug: msg = " ".join(str(x) for x in args) print("RemoteControl:", msg) - def initgateway(self): + def initgateway(self) -> execnet.Gateway: return execnet.makegateway("popen") - def setup(self): + def setup(self) -> None: if hasattr(self, "gateway"): raise ValueError("already have gateway %r" % self.gateway) self.trace("setting up worker session") @@ -90,17 +94,17 @@ def setup(self): args=self.config.args, option_dict=vars(self.config.option), ) - remote_outchannel = channel.receive() + remote_outchannel: execnet.Channel = channel.receive() out = TerminalWriter() - def write(s): + def write(s: str) -> None: out._file.write(s) out._file.flush() remote_outchannel.setcallback(write) - def ensure_teardown(self): + def ensure_teardown(self) -> None: if hasattr(self, "channel"): if not self.channel.isclosed(): self.trace("closing", self.channel) @@ -111,12 +115,12 @@ def ensure_teardown(self): self.gateway.exit() del self.gateway - def runsession(self): + def runsession(self) -> tuple[list[str], list[str], bool]: try: self.trace("sending", self.failures) self.channel.send(self.failures) try: - return self.channel.receive() + return self.channel.receive() # type: ignore[no-any-return] except self.channel.RemoteError: e = sys.exc_info()[1] self.trace("ERROR", e) @@ -124,7 +128,7 @@ def runsession(self): finally: self.ensure_teardown() - def loop_once(self): + def loop_once(self) -> None: self.setup() self.wasfailing = self.failures and len(self.failures) result = self.runsession() @@ -139,7 +143,9 @@ def loop_once(self): self.failures = uniq_failures -def repr_pytest_looponfailinfo(failreports, rootdirs): +def repr_pytest_looponfailinfo( + failreports: Sequence[str], rootdirs: Sequence[Path] +) -> None: tr = TerminalWriter() if failreports: tr.sep("#", "LOOPONFAILING", bold=True) @@ -151,12 +157,16 @@ def repr_pytest_looponfailinfo(failreports, rootdirs): tr.line(f"### Watching: {rootdir}", bold=True) -def init_worker_session(channel, args, option_dict): +def init_worker_session( + channel: "execnet.Channel", # noqa: UP037 + args: list[str], + option_dict: dict[str, "Any"], # noqa: UP037 +) -> None: import os import sys outchannel = channel.gateway.newchannel() - sys.stdout = sys.stderr = outchannel.makefile("w") + sys.stdout = sys.stderr = outchannel.makefile("w") # type: ignore[assignment] channel.send(outchannel) # prune sys.path to not contain relative paths newpaths = [] @@ -179,21 +189,21 @@ def init_worker_session(channel, args, option_dict): class WorkerFailSession: - def __init__(self, config, channel): + def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None: self.config = config self.channel = channel - self.recorded_failures = [] + self.recorded_failures: list[pytest.CollectReport | pytest.TestReport] = [] self.collection_failed = False config.pluginmanager.register(self) config.option.looponfail = False config.option.usepdb = False - def DEBUG(self, *args): + def DEBUG(self, *args: object) -> None: if self.config.option.debug: print(" ".join(map(str, args))) @pytest.hookimpl - def pytest_collection(self, session): + def pytest_collection(self, session: pytest.Session) -> bool: self.session = session self.trails = self.current_command hook = self.session.ihook @@ -208,17 +218,17 @@ def pytest_collection(self, session): return True @pytest.hookimpl - def pytest_runtest_logreport(self, report): + def pytest_runtest_logreport(self, report: pytest.TestReport) -> None: if report.failed: self.recorded_failures.append(report) @pytest.hookimpl - def pytest_collectreport(self, report): + def pytest_collectreport(self, report: pytest.CollectReport) -> None: if report.failed: self.recorded_failures.append(report) self.collection_failed = True - def main(self): + def main(self) -> None: self.DEBUG("WORKER: received configuration, waiting for command trails") try: command = self.channel.receive() @@ -233,7 +243,8 @@ def main(self): loc = rep.longrepr loc = str(getattr(loc, "reprcrash", loc)) failreports.append(loc) - self.channel.send((trails, failreports, self.collection_failed)) + result = (trails, failreports, self.collection_failed) + self.channel.send(result) class StatRecorder: @@ -248,7 +259,7 @@ def fil(self, p: Path) -> bool: def rec(self, p: Path) -> bool: return not p.name.startswith(".") and p.exists() - def waitonchange(self, checkinterval=1.0): + def waitonchange(self, checkinterval: float = 1.0) -> None: while 1: changed = self.check() if changed: diff --git a/src/xdist/newhooks.py b/src/xdist/newhooks.py index ceac11ed..5bfce7c4 100644 --- a/src/xdist/newhooks.py +++ b/src/xdist/newhooks.py @@ -12,16 +12,32 @@ http://pytest.org/en/latest/writing_plugins.html#optionally-using-hooks-from-3rd-party-plugins """ +from __future__ import annotations + +import os +from typing import Any +from typing import Sequence +from typing import TYPE_CHECKING + +import execnet import pytest +if TYPE_CHECKING: + from xdist.remote import Producer + from xdist.scheduler.protocol import Scheduling + from xdist.workermanage import WorkerController + + @pytest.hookspec() -def pytest_xdist_setupnodes(config, specs): +def pytest_xdist_setupnodes( + config: pytest.Config, specs: Sequence[execnet.XSpec] +) -> None: """Called before any remote node is set up.""" @pytest.hookspec() -def pytest_xdist_newgateway(gateway): +def pytest_xdist_newgateway(gateway: execnet.Gateway) -> None: """Called on new raw gateway creation.""" @@ -30,7 +46,10 @@ def pytest_xdist_newgateway(gateway): "rsync feature is deprecated and will be removed in pytest-xdist 4.0" ) ) -def pytest_xdist_rsyncstart(source, gateways): +def pytest_xdist_rsyncstart( + source: str | os.PathLike[str], + gateways: Sequence[execnet.Gateway], +) -> None: """Called before rsyncing a directory to remote gateways takes place.""" @@ -39,52 +58,62 @@ def pytest_xdist_rsyncstart(source, gateways): "rsync feature is deprecated and will be removed in pytest-xdist 4.0" ) ) -def pytest_xdist_rsyncfinish(source, gateways): +def pytest_xdist_rsyncfinish( + source: str | os.PathLike[str], + gateways: Sequence[execnet.Gateway], +) -> None: """Called after rsyncing a directory to remote gateways takes place.""" @pytest.hookspec(firstresult=True) -def pytest_xdist_getremotemodule(): +def pytest_xdist_getremotemodule() -> Any: """Called when creating remote node.""" @pytest.hookspec() -def pytest_configure_node(node): +def pytest_configure_node(node: WorkerController) -> None: """Configure node information before it gets instantiated.""" @pytest.hookspec() -def pytest_testnodeready(node): +def pytest_testnodeready(node: WorkerController) -> None: """Test Node is ready to operate.""" @pytest.hookspec() -def pytest_testnodedown(node, error): +def pytest_testnodedown(node: WorkerController, error: object | None) -> None: """Test Node is down.""" @pytest.hookspec() -def pytest_xdist_node_collection_finished(node, ids): +def pytest_xdist_node_collection_finished( + node: WorkerController, ids: Sequence[str] +) -> None: """Called by the controller node when a worker node finishes collecting.""" @pytest.hookspec(firstresult=True) -def pytest_xdist_make_scheduler(config, log): +def pytest_xdist_make_scheduler( + config: pytest.Config, log: Producer +) -> Scheduling | None: """Return a node scheduler implementation.""" @pytest.hookspec(firstresult=True) -def pytest_xdist_auto_num_workers(config): +def pytest_xdist_auto_num_workers(config: pytest.Config) -> int: """ Return the number of workers to spawn when ``--numprocesses=auto`` is given in the command-line. .. versionadded:: 2.1 """ + raise NotImplementedError() @pytest.hookspec(firstresult=True) -def pytest_handlecrashitem(crashitem, report, sched): +def pytest_handlecrashitem( + crashitem: str, report: pytest.TestReport, sched: Scheduling +) -> None: """ Handle a crashitem, modifying the report if necessary. diff --git a/src/xdist/plugin.py b/src/xdist/plugin.py index 14d081f7..f670d9de 100644 --- a/src/xdist/plugin.py +++ b/src/xdist/plugin.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import os import sys +from typing import Literal import uuid import warnings @@ -10,7 +13,7 @@ @pytest.hookimpl -def pytest_xdist_auto_num_workers(config): +def pytest_xdist_auto_num_workers(config: pytest.Config) -> int: env_var = os.environ.get("PYTEST_XDIST_AUTO_NUM_WORKERS") if env_var: try: @@ -25,14 +28,14 @@ def pytest_xdist_auto_num_workers(config): except ImportError: pass else: - use_logical = config.option.numprocesses == "logical" + use_logical: bool = config.option.numprocesses == "logical" count = psutil.cpu_count(logical=use_logical) or psutil.cpu_count() if count: return count try: from os import sched_getaffinity - def cpu_count(): + def cpu_count() -> int: return len(sched_getaffinity(0)) except ImportError: @@ -40,7 +43,7 @@ def cpu_count(): # workaround https://bitbucket.org/pypy/pypy/issues/2375 return 2 try: - from os import cpu_count + from os import cpu_count # type: ignore[assignment] except ImportError: from multiprocessing import cpu_count try: @@ -50,15 +53,15 @@ def cpu_count(): return n if n else 1 -def parse_numprocesses(s): +def parse_numprocesses(s: str) -> int | Literal["auto", "logical"]: if s in ("auto", "logical"): - return s + return s # type: ignore[return-value] elif s is not None: return int(s) @pytest.hookimpl -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: # 'Help' formatting (same rules as pytest's): # Start with capitalized letters. # If a single phrase, do not end with period. If more than one phrase, all phrases end with periods. @@ -206,7 +209,7 @@ def pytest_addoption(parser): @pytest.hookimpl -def pytest_addhooks(pluginmanager): +def pytest_addhooks(pluginmanager: pytest.PytestPluginManager) -> None: from xdist import newhooks pluginmanager.add_hookspecs(newhooks) @@ -218,7 +221,7 @@ def pytest_addhooks(pluginmanager): @pytest.hookimpl(trylast=True) -def pytest_configure(config): +def pytest_configure(config: pytest.Config) -> None: config_line = ( "xdist_group: specify group for tests should run in same session." "in relation to one another. Provided by pytest-xdist." @@ -256,16 +259,13 @@ def pytest_configure(config): config.issue_config_time_warning(warning, 2) -def _is_distribution_mode(config): - """Return `True` if distribution mode is on, `False` otherwise. - - :param config: the `pytest` `config` object - """ - return config.getoption("dist") != "no" and config.getoption("tx") +def _is_distribution_mode(config: pytest.Config) -> bool: + """Whether distribution mode is on.""" + return config.getoption("dist") != "no" and bool(config.getoption("tx")) @pytest.hookimpl(tryfirst=True) -def pytest_cmdline_main(config): +def pytest_cmdline_main(config: pytest.Config) -> None: if config.option.distload: config.option.dist = "load" @@ -302,7 +302,9 @@ def pytest_cmdline_main(config): # ------------------------------------------------------------------------- -def is_xdist_worker(request_or_session) -> bool: +def is_xdist_worker( + request_or_session: pytest.FixtureRequest | pytest.Session, +) -> bool: """Return `True` if this is an xdist worker, `False` otherwise. :param request_or_session: the `pytest` `request` or `session` object @@ -310,7 +312,9 @@ def is_xdist_worker(request_or_session) -> bool: return hasattr(request_or_session.config, "workerinput") -def is_xdist_controller(request_or_session) -> bool: +def is_xdist_controller( + request_or_session: pytest.FixtureRequest | pytest.Session, +) -> bool: """Return `True` if this is the xdist controller, `False` otherwise. Note: this method also returns `False` when distribution has not been @@ -328,7 +332,9 @@ def is_xdist_controller(request_or_session) -> bool: is_xdist_master = is_xdist_controller -def get_xdist_worker_id(request_or_session): +def get_xdist_worker_id( + request_or_session: pytest.FixtureRequest | pytest.Session, +) -> str: """Return the id of the current worker ('gw0', 'gw1', etc) or 'master' if running on the controller node. @@ -338,14 +344,15 @@ def get_xdist_worker_id(request_or_session): :param request_or_session: the `pytest` `request` or `session` object """ if hasattr(request_or_session.config, "workerinput"): - return request_or_session.config.workerinput["workerid"] + workerid: str = request_or_session.config.workerinput["workerid"] + return workerid else: # TODO: remove "master", ideally for a None return "master" @pytest.fixture(scope="session") -def worker_id(request): +def worker_id(request: pytest.FixtureRequest) -> str: """Return the id of the current worker ('gw0', 'gw1', etc) or 'master' if running on the master node. """ @@ -354,9 +361,10 @@ def worker_id(request): @pytest.fixture(scope="session") -def testrun_uid(request): +def testrun_uid(request: pytest.FixtureRequest) -> str: """Return the unique id of the current test.""" if hasattr(request.config, "workerinput"): - return request.config.workerinput["testrunuid"] + testrunid: str = request.config.workerinput["testrunuid"] + return testrunid else: return uuid.uuid4().hex diff --git a/src/xdist/remote.py b/src/xdist/remote.py index ac1bf1ca..dd1f9883 100644 --- a/src/xdist/remote.py +++ b/src/xdist/remote.py @@ -6,16 +6,22 @@ needs not to be installed in remote environments. """ +from __future__ import annotations + import contextlib import enum import os import sys import time from typing import Any +from typing import Generator +from typing import Literal +from typing import Sequence +from typing import TypedDict +import warnings from _pytest.config import _prepareconfig -from execnet.gateway_base import DumpError -from execnet.gateway_base import dumps +import execnet import pytest @@ -23,7 +29,7 @@ from setproctitle import setproctitle except ImportError: - def setproctitle(title): + def setproctitle(title: str) -> None: pass @@ -35,7 +41,7 @@ class Producer: to have the other way around. """ - def __init__(self, name: str, *, enabled: bool = True): + def __init__(self, name: str, *, enabled: bool = True) -> None: self.name = name self.enabled = enabled @@ -46,11 +52,11 @@ def __call__(self, *a: Any, **k: Any) -> None: if self.enabled: print(f"[{self.name}]", *a, **k, file=sys.stderr) - def __getattr__(self, name: str) -> "Producer": + def __getattr__(self, name: str) -> Producer: return type(self)(name, enabled=self.enabled) -def worker_title(title): +def worker_title(title: str) -> None: try: setproctitle(title) except Exception: @@ -64,59 +70,63 @@ class Marker(enum.Enum): class WorkerInteractor: - def __init__(self, config, channel): + def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None: self.config = config - self.workerid = config.workerinput.get("workerid", "?") - self.testrunuid = config.workerinput["testrunuid"] + workerinput: dict[str, Any] = config.workerinput # type: ignore[attr-defined] + self.workerid = workerinput.get("workerid", "?") + self.testrunuid = workerinput["testrunuid"] self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug) self.channel = channel self.torun = self._make_queue() - self.nextitem_index = None + self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None config.pluginmanager.register(self) - def _make_queue(self): + def _make_queue(self) -> Any: return self.channel.gateway.execmodel.queue.Queue() - def _get_next_item_index(self): + def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]: """Gets the next item from test queue. Handles the case when the queue is replaced concurrently in another thread. """ result = self.torun.get() while result is Marker.QUEUE_REPLACED: result = self.torun.get() - return result + return result # type: ignore[no-any-return] - def sendevent(self, name, **kwargs): + def sendevent(self, name: str, **kwargs: object) -> None: self.log("sending", name, kwargs) self.channel.send((name, kwargs)) @pytest.hookimpl - def pytest_internalerror(self, excrepr): + def pytest_internalerror(self, excrepr: object) -> None: formatted_error = str(excrepr) for line in formatted_error.split("\n"): self.log("IERROR>", line) interactor.sendevent("internal_error", formatted_error=formatted_error) @pytest.hookimpl - def pytest_sessionstart(self, session): + def pytest_sessionstart(self, session: pytest.Session) -> None: self.session = session workerinfo = getinfodict() self.sendevent("workerready", workerinfo=workerinfo) @pytest.hookimpl(hookwrapper=True) - def pytest_sessionfinish(self, exitstatus): + def pytest_sessionfinish(self, exitstatus: int) -> Generator[None, object, None]: + workeroutput: dict[str, Any] = self.config.workeroutput # type: ignore[attr-defined] # in pytest 5.0+, exitstatus is an IntEnum object - self.config.workeroutput["exitstatus"] = int(exitstatus) - self.config.workeroutput["shouldfail"] = self.session.shouldfail - self.config.workeroutput["shouldstop"] = self.session.shouldstop + workeroutput["exitstatus"] = int(exitstatus) + workeroutput["shouldfail"] = self.session.shouldfail + workeroutput["shouldstop"] = self.session.shouldstop yield - self.sendevent("workerfinished", workeroutput=self.config.workeroutput) + self.sendevent("workerfinished", workeroutput=workeroutput) @pytest.hookimpl - def pytest_collection(self, session): + def pytest_collection(self) -> None: self.sendevent("collectionstart") - def handle_command(self, command): + def handle_command( + self, command: tuple[str, dict[str, Any]] | Literal[Marker.SHUTDOWN] + ) -> None: if command is Marker.SHUTDOWN: self.torun.put(Marker.SHUTDOWN) return @@ -135,18 +145,19 @@ def handle_command(self, command): elif name == "steal": self.steal(kwargs["indices"]) - def steal(self, indices): - indices = set(indices) + def steal(self, indices: Sequence[int]) -> None: + indices_set = set(indices) stolen = [] old_queue, self.torun = self.torun, self._make_queue() - def old_queue_get_nowait_noraise(): + def old_queue_get_nowait_noraise() -> int | None: with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty): - return old_queue.get_nowait() + return old_queue.get_nowait() # type: ignore[no-any-return] + return None for i in iter(old_queue_get_nowait_noraise, None): - if i in indices: + if i in indices_set: stolen.append(i) else: self.torun.put(i) @@ -155,7 +166,7 @@ def old_queue_get_nowait_noraise(): old_queue.put(Marker.QUEUE_REPLACED) @pytest.hookimpl - def pytest_runtestloop(self, session): + def pytest_runtestloop(self, session: pytest.Session) -> bool: self.log("entering main loop") self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN) self.nextitem_index = self._get_next_item_index() @@ -165,7 +176,8 @@ def pytest_runtestloop(self, session): break return True - def run_one_test(self): + def run_one_test(self) -> None: + assert isinstance(self.nextitem_index, int) self.item_index = self.nextitem_index self.nextitem_index = self._get_next_item_index() @@ -174,6 +186,7 @@ def run_one_test(self): if self.nextitem_index is Marker.SHUTDOWN: nextitem = None else: + assert self.nextitem_index is not None nextitem = items[self.nextitem_index] worker_title("[pytest-xdist running] %s" % item.nodeid) @@ -188,7 +201,11 @@ def run_one_test(self): "runtest_protocol_complete", item_index=self.item_index, duration=duration ) - def pytest_collection_modifyitems(self, session, config, items): + def pytest_collection_modifyitems( + self, + config: pytest.Config, + items: list[pytest.Item], + ) -> None: # add the group name to nodeid as suffix if --dist=loadgroup if config.getvalue("loadgroup"): for item in items: @@ -203,7 +220,7 @@ def pytest_collection_modifyitems(self, session, config, items): item._nodeid = f"{item.nodeid}@{gname}" @pytest.hookimpl - def pytest_collection_finish(self, session): + def pytest_collection_finish(self, session: pytest.Session) -> None: self.sendevent( "collectionfinish", topdir=str(self.config.rootpath), @@ -211,15 +228,23 @@ def pytest_collection_finish(self, session): ) @pytest.hookimpl - def pytest_runtest_logstart(self, nodeid, location): + def pytest_runtest_logstart( + self, + nodeid: str, + location: tuple[str, int | None, str], + ) -> None: self.sendevent("logstart", nodeid=nodeid, location=location) @pytest.hookimpl - def pytest_runtest_logfinish(self, nodeid, location): + def pytest_runtest_logfinish( + self, + nodeid: str, + location: tuple[str, int | None, str], + ) -> None: self.sendevent("logfinish", nodeid=nodeid, location=location) @pytest.hookimpl - def pytest_runtest_logreport(self, report): + def pytest_runtest_logreport(self, report: pytest.TestReport) -> None: data = self.config.hook.pytest_report_to_serializable( config=self.config, report=report ) @@ -230,7 +255,7 @@ def pytest_runtest_logreport(self, report): self.sendevent("testreport", data=data) @pytest.hookimpl - def pytest_collectreport(self, report): + def pytest_collectreport(self, report: pytest.CollectReport) -> None: # send only reports that have not passed to controller as optimization (#330) if not report.passed: data = self.config.hook.pytest_report_to_serializable( @@ -239,7 +264,13 @@ def pytest_collectreport(self, report): self.sendevent("collectreport", data=data) @pytest.hookimpl - def pytest_warning_recorded(self, warning_message, when, nodeid, location): + def pytest_warning_recorded( + self, + warning_message: warnings.WarningMessage, + when: str, + nodeid: str, + location: tuple[str, int, str] | None, + ) -> None: self.sendevent( "warning_recorded", warning_message_data=serialize_warning_message(warning_message), @@ -249,7 +280,9 @@ def pytest_warning_recorded(self, warning_message, when, nodeid, location): ) -def serialize_warning_message(warning_message): +def serialize_warning_message( + warning_message: warnings.WarningMessage, +) -> dict[str, Any]: if isinstance(warning_message.message, Warning): message_module = type(warning_message.message).__module__ message_class_name = type(warning_message.message).__name__ @@ -257,8 +290,8 @@ def serialize_warning_message(warning_message): # check now if we can serialize the warning arguments (#349) # if not, we will just use the exception message on the controller node try: - dumps(warning_message.message.args) - except DumpError: + execnet.dumps(warning_message.message.args) + except execnet.DumpError: message_args = None else: message_args = warning_message.message.args @@ -283,27 +316,38 @@ def serialize_warning_message(warning_message): "category_class_name": category_class_name, } # access private _WARNING_DETAILS because the attributes vary between Python versions - for attr_name in warning_message._WARNING_DETAILS: + for attr_name in warning_message._WARNING_DETAILS: # type: ignore[attr-defined] if attr_name in ("message", "category"): continue attr = getattr(warning_message, attr_name) # Check if we can serialize the warning detail, marking `None` otherwise # Note that we need to define the attr (even as `None`) to allow deserializing try: - dumps(attr) - except DumpError: + execnet.dumps(attr) + except execnet.DumpError: result[attr_name] = repr(attr) else: result[attr_name] = attr return result -def getinfodict(): +class WorkerInfo(TypedDict): + version: str + version_info: tuple[int, int, int, str, int] + sysplatform: str + platform: str + executable: str + cwd: str + id: str + spec: execnet.XSpec + + +def getinfodict() -> WorkerInfo: import platform return dict( version=sys.version, - version_info=tuple(sys.version_info), + version_info=tuple(sys.version_info), # type: ignore[typeddict-item] sysplatform=sys.platform, platform=platform.platform(), executable=sys.executable, @@ -311,7 +355,7 @@ def getinfodict(): ) -def setup_config(config, basetemp): +def setup_config(config: pytest.Config, basetemp: str | None) -> None: config.option.loadgroup = config.getvalue("dist") == "loadgroup" config.option.looponfail = False config.option.usepdb = False @@ -323,7 +367,7 @@ def setup_config(config, basetemp): if __name__ == "__channelexec__": - channel = channel # type: ignore[name-defined] # noqa: F821, PLW0127 + channel: execnet.Channel = channel # type: ignore[name-defined] # noqa: F821, PLW0127 workerinput, args, option_dict, change_sys_path = channel.receive() # type: ignore[name-defined] if change_sys_path is None: diff --git a/src/xdist/report.py b/src/xdist/report.py index d956577d..eb463abd 100644 --- a/src/xdist/report.py +++ b/src/xdist/report.py @@ -1,7 +1,15 @@ +from __future__ import annotations + from difflib import unified_diff +from typing import Sequence -def report_collection_diff(from_collection, to_collection, from_id, to_id): +def report_collection_diff( + from_collection: Sequence[str], + to_collection: Sequence[str], + from_id: str, + to_id: str, +) -> str | None: """Report the collected test difference between two nodes. :returns: detailed message describing the difference between the given diff --git a/src/xdist/scheduler/each.py b/src/xdist/scheduler/each.py index 47f7add3..aa4f7ba1 100644 --- a/src/xdist/scheduler/each.py +++ b/src/xdist/scheduler/each.py @@ -1,6 +1,13 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest + from xdist.remote import Producer from xdist.report import report_collection_diff from xdist.workermanage import parse_spec_config +from xdist.workermanage import WorkerController class EachScheduling: @@ -17,13 +24,13 @@ class EachScheduling: assigned the remaining items from the removed node. """ - def __init__(self, config, log=None): + def __init__(self, config: pytest.Config, log: Producer | None = None) -> None: self.config = config self.numnodes = len(parse_spec_config(config)) - self.node2collection = {} - self.node2pending = {} - self._started = [] - self._removed2pending = {} + self.node2collection: dict[WorkerController, list[str]] = {} + self.node2pending: dict[WorkerController, list[int]] = {} + self._started: list[WorkerController] = [] + self._removed2pending: dict[WorkerController, list[int]] = {} if log is None: self.log = Producer("eachsched") else: @@ -31,12 +38,12 @@ def __init__(self, config, log=None): self.collection_is_completed = False @property - def nodes(self): + def nodes(self) -> list[WorkerController]: """A list of all nodes in the scheduler.""" return list(self.node2pending.keys()) @property - def tests_finished(self): + def tests_finished(self) -> bool: if not self.collection_is_completed: return False if self._removed2pending: @@ -47,7 +54,7 @@ def tests_finished(self): return True @property - def has_pending(self): + def has_pending(self) -> bool: """Return True if there are pending test items. This indicates that collection has finished and nodes are @@ -59,11 +66,13 @@ def has_pending(self): return True return False - def add_node(self, node): + def add_node(self, node: WorkerController) -> None: assert node not in self.node2pending self.node2pending[node] = [] - def add_node_collection(self, node, collection): + def add_node_collection( + self, node: WorkerController, collection: Sequence[str] + ) -> None: """Add the collected test items from a node. Collection is complete once all nodes have submitted their @@ -97,26 +106,32 @@ def add_node_collection(self, node, collection): self.node2pending[node] = pending break - def mark_test_complete(self, node, item_index, duration=0): + def mark_test_complete( + self, node: WorkerController, item_index: int, duration: float = 0 + ) -> None: self.node2pending[node].remove(item_index) - def mark_test_pending(self, item): + def mark_test_pending(self, item: str) -> None: raise NotImplementedError() - def remove_pending_tests_from_node(self, node, indices): + def remove_pending_tests_from_node( + self, + node: WorkerController, + indices: Sequence[int], + ) -> None: raise NotImplementedError() - def remove_node(self, node): + def remove_node(self, node: WorkerController) -> str | None: # KeyError if we didn't get an add_node() yet pending = self.node2pending.pop(node) if not pending: - return + return None crashitem = self.node2collection[node][pending.pop(0)] if pending: self._removed2pending[node] = pending return crashitem - def schedule(self): + def schedule(self) -> None: """Schedule the test items on the nodes. If the node's pending list is empty it is a new node which diff --git a/src/xdist/scheduler/load.py b/src/xdist/scheduler/load.py index fb9bdfdc..9d153bb9 100644 --- a/src/xdist/scheduler/load.py +++ b/src/xdist/scheduler/load.py @@ -1,10 +1,14 @@ +from __future__ import annotations + from itertools import cycle +from typing import Sequence import pytest from xdist.remote import Producer from xdist.report import report_collection_diff from xdist.workermanage import parse_spec_config +from xdist.workermanage import WorkerController class LoadScheduling: @@ -53,12 +57,12 @@ class LoadScheduling: :config: Config object, used for handling hooks. """ - def __init__(self, config, log=None): + def __init__(self, config: pytest.Config, log: Producer | None = None) -> None: self.numnodes = len(parse_spec_config(config)) - self.node2collection = {} - self.node2pending = {} - self.pending = [] - self.collection = None + self.node2collection: dict[WorkerController, list[str]] = {} + self.node2pending: dict[WorkerController, list[int]] = {} + self.pending: list[int] = [] + self.collection: list[str] | None = None if log is None: self.log = Producer("loadsched") else: @@ -67,12 +71,12 @@ def __init__(self, config, log=None): self.maxschedchunk = self.config.getoption("maxschedchunk") @property - def nodes(self): + def nodes(self) -> list[WorkerController]: """A list of all nodes in the scheduler.""" return list(self.node2pending.keys()) @property - def collection_is_completed(self): + def collection_is_completed(self) -> bool: """Boolean indication initial test collection is complete. This is a boolean indicating all initial participating nodes @@ -82,7 +86,7 @@ def collection_is_completed(self): return len(self.node2collection) >= self.numnodes @property - def tests_finished(self): + def tests_finished(self) -> bool: """Return True if all tests have been executed by the nodes.""" if not self.collection_is_completed: return False @@ -94,7 +98,7 @@ def tests_finished(self): return True @property - def has_pending(self): + def has_pending(self) -> bool: """Return True if there are pending test items. This indicates that collection has finished and nodes are @@ -108,7 +112,7 @@ def has_pending(self): return True return False - def add_node(self, node): + def add_node(self, node: WorkerController) -> None: """Add a new node to the scheduler. From now on the node will be allocated chunks of tests to @@ -120,7 +124,9 @@ def add_node(self, node): assert node not in self.node2pending self.node2pending[node] = [] - def add_node_collection(self, node, collection): + def add_node_collection( + self, node: WorkerController, collection: Sequence[str] + ) -> None: """Add the collected test items from a node. The collection is stored in the ``.node2collection`` map. @@ -141,7 +147,9 @@ def add_node_collection(self, node, collection): return self.node2collection[node] = list(collection) - def mark_test_complete(self, node, item_index, duration=0): + def mark_test_complete( + self, node: WorkerController, item_index: int, duration: float = 0 + ) -> None: """Mark test item as completed by node. The duration it took to execute the item is used as a hint to @@ -152,7 +160,8 @@ def mark_test_complete(self, node, item_index, duration=0): self.node2pending[node].remove(item_index) self.check_schedule(node, duration=duration) - def mark_test_pending(self, item): + def mark_test_pending(self, item: str) -> None: + assert self.collection is not None self.pending.insert( 0, self.collection.index(item), @@ -160,10 +169,14 @@ def mark_test_pending(self, item): for node in self.node2pending: self.check_schedule(node) - def remove_pending_tests_from_node(self, node, indices): + def remove_pending_tests_from_node( + self, + node: WorkerController, + indices: Sequence[int], + ) -> None: raise NotImplementedError() - def check_schedule(self, node, duration=0): + def check_schedule(self, node: WorkerController, duration: float = 0) -> None: """Maybe schedule new items on the node. If there are any globally pending nodes left then this will @@ -197,7 +210,7 @@ def check_schedule(self, node, duration=0): self.log("num items waiting for node:", len(self.pending)) - def remove_node(self, node): + def remove_node(self, node: WorkerController) -> str | None: """Remove a node from the scheduler. This should be called either when the node crashed or at @@ -212,16 +225,17 @@ def remove_node(self, node): """ pending = self.node2pending.pop(node) if not pending: - return + return None # The node crashed, reassing pending items + assert self.collection is not None crashitem = self.collection[pending.pop(0)] self.pending.extend(pending) for node in self.node2pending: self.check_schedule(node) return crashitem - def schedule(self): + def schedule(self) -> None: """Initiate distribution of the test collection. Initiate scheduling of the items across the nodes. If this @@ -285,14 +299,14 @@ def schedule(self): for node in self.nodes: node.shutdown() - def _send_tests(self, node, num): + def _send_tests(self, node: WorkerController, num: int) -> None: tests_per_node = self.pending[:num] if tests_per_node: del self.pending[:num] self.node2pending[node].extend(tests_per_node) node.send_runtest_some(tests_per_node) - def _check_nodes_have_same_collection(self): + def _check_nodes_have_same_collection(self) -> bool: """Return True if all nodes have collected the same items. If collections differ, this method returns False while logging diff --git a/src/xdist/scheduler/loadfile.py b/src/xdist/scheduler/loadfile.py index 25b72da4..fb6f027f 100644 --- a/src/xdist/scheduler/loadfile.py +++ b/src/xdist/scheduler/loadfile.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import pytest + from xdist.remote import Producer from .loadscope import LoadScopeScheduling @@ -21,14 +25,14 @@ class LoadFileScheduling(LoadScopeScheduling): This class behaves very much like LoadScopeScheduling, but with a file-level scope. """ - def __init__(self, config, log=None): + def __init__(self, config: pytest.Config, log: Producer | None = None) -> None: super().__init__(config, log) if log is None: self.log = Producer("loadfilesched") else: self.log = log.loadfilesched - def _split_scope(self, nodeid): + def _split_scope(self, nodeid: str) -> str: """Determine the scope (grouping) of a nodeid. There are usually 3 cases for a nodeid:: diff --git a/src/xdist/scheduler/loadgroup.py b/src/xdist/scheduler/loadgroup.py index 1dee40e6..798c7128 100644 --- a/src/xdist/scheduler/loadgroup.py +++ b/src/xdist/scheduler/loadgroup.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import pytest + from xdist.remote import Producer from .loadscope import LoadScopeScheduling @@ -10,14 +14,14 @@ class LoadGroupScheduling(LoadScopeScheduling): instead of the module or class to which they belong to. """ - def __init__(self, config, log=None): + def __init__(self, config: pytest.Config, log: Producer | None = None) -> None: super().__init__(config, log) if log is None: self.log = Producer("loadgroupsched") else: self.log = log.loadgroupsched - def _split_scope(self, nodeid): + def _split_scope(self, nodeid: str) -> str: """Determine the scope (grouping) of a nodeid. There are usually 3 cases for a nodeid:: diff --git a/src/xdist/scheduler/loadscope.py b/src/xdist/scheduler/loadscope.py index 7c66ed51..a4d63b29 100644 --- a/src/xdist/scheduler/loadscope.py +++ b/src/xdist/scheduler/loadscope.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from collections import OrderedDict +from typing import NoReturn +from typing import Sequence import pytest from xdist.remote import Producer from xdist.report import report_collection_diff from xdist.workermanage import parse_spec_config +from xdist.workermanage import WorkerController class LoadScopeScheduling: @@ -85,13 +90,13 @@ class LoadScopeScheduling: :config: Config object, used for handling hooks. """ - def __init__(self, config, log=None): + def __init__(self, config: pytest.Config, log: Producer | None = None) -> None: self.numnodes = len(parse_spec_config(config)) - self.collection = None + self.collection: list[str] | None = None - self.workqueue = OrderedDict() - self.assigned_work = {} - self.registered_collections = {} + self.workqueue: OrderedDict[str, dict[str, bool]] = OrderedDict() + self.assigned_work: dict[WorkerController, dict[str, dict[str, bool]]] = {} + self.registered_collections: dict[WorkerController, list[str]] = {} if log is None: self.log = Producer("loadscopesched") @@ -101,12 +106,12 @@ def __init__(self, config, log=None): self.config = config @property - def nodes(self): + def nodes(self) -> list[WorkerController]: """A list of all active nodes in the scheduler.""" return list(self.assigned_work.keys()) @property - def collection_is_completed(self): + def collection_is_completed(self) -> bool: """Boolean indication initial test collection is complete. This is a boolean indicating all initial participating nodes have @@ -116,7 +121,7 @@ def collection_is_completed(self): return len(self.registered_collections) >= self.numnodes @property - def tests_finished(self): + def tests_finished(self) -> bool: """Return True if all tests have been executed by the nodes.""" if not self.collection_is_completed: return False @@ -131,7 +136,7 @@ def tests_finished(self): return True @property - def has_pending(self): + def has_pending(self) -> bool: """Return True if there are pending test items. This indicates that collection has finished and nodes are still @@ -147,7 +152,7 @@ def has_pending(self): return False - def add_node(self, node): + def add_node(self, node: WorkerController) -> None: """Add a new node to the scheduler. From now on the node will be assigned work units to be executed. @@ -158,7 +163,7 @@ def add_node(self, node): assert node not in self.assigned_work self.assigned_work[node] = {} - def remove_node(self, node): + def remove_node(self, node: WorkerController) -> str | None: """Remove a node from the scheduler. This should be called either when the node crashed or at shutdown time. @@ -199,7 +204,9 @@ def remove_node(self, node): return crashitem - def add_node_collection(self, node, collection): + def add_node_collection( + self, node: WorkerController, collection: Sequence[str] + ) -> None: """Add the collected test items from a node. The collection is stored in the ``.registered_collections`` dictionary. @@ -228,7 +235,9 @@ def add_node_collection(self, node, collection): self.registered_collections[node] = list(collection) - def mark_test_complete(self, node, item_index, duration=0): + def mark_test_complete( + self, node: WorkerController, item_index: int, duration: float = 0 + ) -> None: """Mark test item as completed by node. Called by the hook: @@ -241,13 +250,17 @@ def mark_test_complete(self, node, item_index, duration=0): self.assigned_work[node][scope][nodeid] = True self._reschedule(node) - def mark_test_pending(self, item): + def mark_test_pending(self, item: str) -> NoReturn: raise NotImplementedError() - def remove_pending_tests_from_node(self, node, indices): + def remove_pending_tests_from_node( + self, + node: WorkerController, + indices: Sequence[int], + ) -> None: raise NotImplementedError() - def _assign_work_unit(self, node): + def _assign_work_unit(self, node: WorkerController) -> None: """Assign a work unit to a node.""" assert self.workqueue @@ -268,7 +281,7 @@ def _assign_work_unit(self, node): node.send_runtest_some(nodeids_indexes) - def _split_scope(self, nodeid): + def _split_scope(self, nodeid: str) -> str: """Determine the scope (grouping) of a nodeid. There are usually 3 cases for a nodeid:: @@ -292,12 +305,12 @@ def _split_scope(self, nodeid): """ return nodeid.rsplit("::", 1)[0] - def _pending_of(self, workload): + def _pending_of(self, workload: dict[str, dict[str, bool]]) -> int: """Return the number of pending tests in a workload.""" pending = sum(list(scope.values()).count(False) for scope in workload.values()) return pending - def _reschedule(self, node): + def _reschedule(self, node: WorkerController) -> None: """Maybe schedule new items on the node. If there are any globally pending work units left then this will check @@ -322,7 +335,7 @@ def _reschedule(self, node): # Pop one unit of work and assign it self._assign_work_unit(node) - def schedule(self): + def schedule(self) -> None: """Initiate distribution of the test collection. Initiate scheduling of the items across the nodes. If this gets called @@ -352,7 +365,7 @@ def schedule(self): return # Determine chunks of work (scopes) - unsorted_workqueue = {} + unsorted_workqueue: dict[str, dict[str, bool]] = {} for nodeid in self.collection: scope = self._split_scope(nodeid) work_unit = unsorted_workqueue.setdefault(scope, {}) @@ -389,7 +402,7 @@ def schedule(self): for node in self.nodes: node.shutdown() - def _check_nodes_have_same_collection(self): + def _check_nodes_have_same_collection(self) -> bool: """Return True if all nodes have collected the same items. If collections differ, this method returns False while logging diff --git a/src/xdist/scheduler/worksteal.py b/src/xdist/scheduler/worksteal.py index 5253c1a4..fd208486 100644 --- a/src/xdist/scheduler/worksteal.py +++ b/src/xdist/scheduler/worksteal.py @@ -1,17 +1,18 @@ from __future__ import annotations -from typing import Any from typing import NamedTuple +from typing import Sequence import pytest from xdist.remote import Producer from xdist.report import report_collection_diff from xdist.workermanage import parse_spec_config +from xdist.workermanage import WorkerController class NodePending(NamedTuple): - node: Any + node: WorkerController pending: list[int] @@ -63,26 +64,26 @@ class WorkStealingScheduling: simultaneous requests. """ - def __init__(self, config, log=None): + def __init__(self, config: pytest.Config, log: Producer | None = None) -> None: self.numnodes = len(parse_spec_config(config)) - self.node2collection = {} - self.node2pending = {} - self.pending = [] - self.collection = None + self.node2collection: dict[WorkerController, list[str]] = {} + self.node2pending: dict[WorkerController, list[int]] = {} + self.pending: list[int] = [] + self.collection: list[str] | None = None if log is None: self.log = Producer("workstealsched") else: self.log = log.workstealsched self.config = config - self.steal_requested_from_node = None + self.steal_requested_from_node: WorkerController | None = None @property - def nodes(self): + def nodes(self) -> list[WorkerController]: """A list of all nodes in the scheduler.""" return list(self.node2pending.keys()) @property - def collection_is_completed(self): + def collection_is_completed(self) -> bool: """Boolean indication initial test collection is complete. This is a boolean indicating all initial participating nodes @@ -92,7 +93,7 @@ def collection_is_completed(self): return len(self.node2collection) >= self.numnodes @property - def tests_finished(self): + def tests_finished(self) -> bool: """Return True if all tests have been executed by the nodes.""" if not self.collection_is_completed: return False @@ -106,7 +107,7 @@ def tests_finished(self): return True @property - def has_pending(self): + def has_pending(self) -> bool: """Return True if there are pending test items. This indicates that collection has finished and nodes are @@ -120,7 +121,7 @@ def has_pending(self): return True return False - def add_node(self, node): + def add_node(self, node: WorkerController) -> None: """Add a new node to the scheduler. From now on the node will be allocated chunks of tests to @@ -132,7 +133,9 @@ def add_node(self, node): assert node not in self.node2pending self.node2pending[node] = [] - def add_node_collection(self, node, collection): + def add_node_collection( + self, node: WorkerController, collection: Sequence[str] + ) -> None: """Add the collected test items from a node. The collection is stored in the ``.node2collection`` map. @@ -153,7 +156,9 @@ def add_node_collection(self, node, collection): return self.node2collection[node] = list(collection) - def mark_test_complete(self, node, item_index, duration=None): + def mark_test_complete( + self, node: WorkerController, item_index: int, duration: float | None = None + ) -> None: """Mark test item as completed by node. This is called by the ``DSession.worker_testreport`` hook. @@ -161,14 +166,19 @@ def mark_test_complete(self, node, item_index, duration=None): self.node2pending[node].remove(item_index) self.check_schedule() - def mark_test_pending(self, item): + def mark_test_pending(self, item: str) -> None: + assert self.collection is not None self.pending.insert( 0, self.collection.index(item), ) self.check_schedule() - def remove_pending_tests_from_node(self, node, indices): + def remove_pending_tests_from_node( + self, + node: WorkerController, + indices: Sequence[int], + ) -> None: """Node returned some test indices back in response to 'steal' command. This is called by ``DSession.worker_unscheduled``. @@ -183,7 +193,7 @@ def remove_pending_tests_from_node(self, node, indices): self.pending.extend(indices) self.check_schedule() - def check_schedule(self): + def check_schedule(self) -> None: """Reschedule tests/perform load balancing.""" nodes_up = [ NodePending(node, pending) @@ -191,7 +201,7 @@ def check_schedule(self): if not node.shutting_down ] - def get_idle_nodes(): + def get_idle_nodes() -> list[WorkerController]: return [node for node, pending in nodes_up if len(pending) < MIN_PENDING] idle_nodes = get_idle_nodes() @@ -235,10 +245,11 @@ def get_idle_nodes(): node.shutdown() return + assert steal_from is not None steal_from.node.send_steal(steal_from.pending[-num_steal:]) self.steal_requested_from_node = steal_from.node - def remove_node(self, node): + def remove_node(self, node: WorkerController) -> str | None: """Remove a node from the scheduler. This should be called either when the node crashed or at @@ -249,12 +260,12 @@ def remove_node(self, node): Return the item which was being executing while the node crashed or None if the node has no more pending items. - """ pending = self.node2pending.pop(node) # If node was removed without completing its assigned tests - it crashed if pending: + assert self.collection is not None crashitem = self.collection[pending.pop(0)] else: crashitem = None @@ -268,7 +279,7 @@ def remove_node(self, node): self.check_schedule() return crashitem - def schedule(self): + def schedule(self) -> None: """Initiate distribution of the test collection. Initiate scheduling of the items across the nodes. If this @@ -298,14 +309,14 @@ def schedule(self): self.check_schedule() - def _send_tests(self, node, num): + def _send_tests(self, node: WorkerController, num: int) -> None: tests_per_node = self.pending[:num] if tests_per_node: del self.pending[:num] self.node2pending[node].extend(tests_per_node) node.send_runtest_some(tests_per_node) - def _check_nodes_have_same_collection(self): + def _check_nodes_have_same_collection(self) -> bool: """Return True if all nodes have collected the same items. If collections differ, this method returns False while logging diff --git a/src/xdist/workermanage.py b/src/xdist/workermanage.py index c3793efb..70d95971 100644 --- a/src/xdist/workermanage.py +++ b/src/xdist/workermanage.py @@ -7,9 +7,12 @@ import re import sys from typing import Any +from typing import Callable +from typing import Literal from typing import Sequence from typing import Union import uuid +import warnings import execnet import pytest @@ -17,11 +20,13 @@ from xdist.plugin import _sys_path import xdist.remote from xdist.remote import Producer +from xdist.remote import WorkerInfo -def parse_spec_config(config): +def parse_spec_config(config: pytest.Config) -> list[str]: xspeclist = [] - for xspec in config.getvalue("tx"): + tx: list[str] = config.getvalue("tx") + for xspec in tx: i = xspec.find("*") try: num = int(xspec[:i]) @@ -40,7 +45,12 @@ class NodeManager: EXIT_TIMEOUT = 10 DEFAULT_IGNORES = [".*", "*.pyc", "*.pyo", "*~"] - def __init__(self, config, specs=None, defaultchdir="pyexecnetcache") -> None: + def __init__( + self, + config: pytest.Config, + specs: Sequence[execnet.XSpec | str] | None = None, + defaultchdir: str = "pyexecnetcache", + ) -> None: self.config = config self.trace = self.config.trace.get("nodemanager") self.testrunuid = self.config.getoption("testrunuid") @@ -49,7 +59,7 @@ def __init__(self, config, specs=None, defaultchdir="pyexecnetcache") -> None: self.group = execnet.Group() if specs is None: specs = self._getxspecs() - self.specs = [] + self.specs: list[execnet.XSpec] = [] for spec in specs: if not isinstance(spec, execnet.XSpec): spec = execnet.XSpec(spec) @@ -61,31 +71,39 @@ def __init__(self, config, specs=None, defaultchdir="pyexecnetcache") -> None: self.rsyncoptions = self._getrsyncoptions() self._rsynced_specs: set[tuple[Any, Any]] = set() - def rsync_roots(self, gateway): + def rsync_roots(self, gateway: execnet.Gateway) -> None: """Rsync the set of roots to the node's gateway cwd.""" if self.roots: for root in self.roots: self.rsync(gateway, root, **self.rsyncoptions) - def setup_nodes(self, putevent): + def setup_nodes( + self, + putevent: Callable[[tuple[str, dict[str, Any]]], None], + ) -> list[WorkerController]: self.config.hook.pytest_xdist_setupnodes(config=self.config, specs=self.specs) self.trace("setting up nodes") return [self.setup_node(spec, putevent) for spec in self.specs] - def setup_node(self, spec, putevent): + def setup_node( + self, + spec: execnet.XSpec, + putevent: Callable[[tuple[str, dict[str, Any]]], None], + ) -> WorkerController: gw = self.group.makegateway(spec) self.config.hook.pytest_xdist_newgateway(gateway=gw) self.rsync_roots(gw) node = WorkerController(self, gw, self.config, putevent) - gw.node = node # keep the node alive + # Keep the node alive. + gw.node = node # type: ignore[attr-defined] node.setup() self.trace("started node %r" % node) return node - def teardown_nodes(self): + def teardown_nodes(self) -> None: self.group.terminate(self.EXIT_TIMEOUT) - def _getxspecs(self): + def _getxspecs(self) -> list[execnet.XSpec]: return [execnet.XSpec(x) for x in parse_spec_config(self.config)] def _getrsyncdirs(self) -> list[Path]: @@ -97,7 +115,7 @@ def _getrsyncdirs(self) -> list[Path]: import _pytest import pytest - def get_dir(p): + def get_dir(p: str) -> str: """Return the directory path if p is a package or the path to the .py file otherwise.""" stripped = p.rstrip("co") if os.path.basename(stripped) == "__init__.py": @@ -115,14 +133,14 @@ def get_dir(p): candidates.extend(rsyncroots) roots = [] for root in candidates: - root = Path(root).resolve() - if not root.exists(): + root_path = Path(root).resolve() + if not root_path.exists(): raise pytest.UsageError(f"rsyncdir doesn't exist: {root!r}") - if root not in roots: - roots.append(root) + if root_path not in roots: + roots.append(root_path) return roots - def _getrsyncoptions(self): + def _getrsyncoptions(self) -> dict[str, Any]: """Get options to be passed for rsync.""" ignores = list(self.DEFAULT_IGNORES) ignores += [str(path) for path in self.config.option.rsyncignore] @@ -133,7 +151,16 @@ def _getrsyncoptions(self): "verbose": getattr(self.config.option, "verbose", 0), } - def rsync(self, gateway, source, notify=None, verbose=False, ignores=None): + def rsync( + self, + gateway: execnet.Gateway, + source: str | os.PathLike[str], + notify: ( + Callable[[str, execnet.XSpec, str | os.PathLike[str]], Any] | None + ) = None, + verbose: int = False, + ignores: Sequence[str] | None = None, + ) -> None: """Perform rsync to remote hosts for node.""" # XXX This changes the calling behaviour of # pytest_xdist_rsyncstart and pytest_xdist_rsyncfinish to @@ -153,7 +180,7 @@ def rsync(self, gateway, source, notify=None, verbose=False, ignores=None): if (spec, source) in self._rsynced_specs: return - def finished(): + def finished() -> None: if notify: notify("rsyncrootready", spec, source) @@ -189,11 +216,19 @@ def filter(self, path: PathLike) -> bool: else: return True - def add_target_host(self, gateway, finished=None): + def add_target_host( + self, + gateway: execnet.Gateway, + finished: Callable[[], None] | None = None, + ) -> None: remotepath = os.path.basename(self._sourcedir) super().add_target(gateway, remotepath, finishedcallback=finished, delete=True) - def _report_send_file(self, gateway, modified_rel_path): + def _report_send_file( + self, + gateway: execnet.Gateway, # type: ignore[override] + modified_rel_path: str, + ) -> None: if self._verbose > 0: path = os.path.basename(self._sourcedir) + "/" + modified_rel_path remotepath = gateway.spec.chdir @@ -234,12 +269,21 @@ class Marker(enum.Enum): class WorkerController: + # Set when the worker is ready. + workerinfo: WorkerInfo + class RemoteHook: @pytest.hookimpl(trylast=True) - def pytest_xdist_getremotemodule(self): + def pytest_xdist_getremotemodule(self) -> Any: return xdist.remote - def __init__(self, nodemanager, gateway, config, putevent): + def __init__( + self, + nodemanager: NodeManager, + gateway: execnet.Gateway, + config: pytest.Config, + putevent: Callable[[tuple[str, dict[str, Any]]], None], + ) -> None: config.pluginmanager.register(self.RemoteHook()) self.nodemanager = nodemanager self.putevent = putevent @@ -255,14 +299,14 @@ def __init__(self, nodemanager, gateway, config, putevent): self._shutdown_sent = False self.log = Producer(f"workerctl-{gateway.id}", enabled=config.option.debug) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.gateway.id}>" @property - def shutting_down(self): + def shutting_down(self) -> bool: return self._down or self._shutdown_sent - def setup(self): + def setup(self) -> None: self.log("setting up worker session") spec = self.gateway.spec args = [str(x) for x in self.config.invocation_params.args or ()] @@ -283,10 +327,11 @@ def setup(self): change_sys_path = _sys_path if self.gateway.spec.popen else None self.channel.send((self.workerinput, args, option_dict, change_sys_path)) - if self.putevent: + # putevent is only None in a test. + if self.putevent: # type: ignore[truthy-function] self.channel.setcallback(self.process_from_remote, endmarker=Marker.END) - def ensure_teardown(self): + def ensure_teardown(self) -> None: if hasattr(self, "channel"): if not self.channel.isclosed(): self.log("closing", self.channel) @@ -297,16 +342,16 @@ def ensure_teardown(self): self.gateway.exit() # del self.gateway - def send_runtest_some(self, indices): + def send_runtest_some(self, indices: Sequence[int]) -> None: self.sendcommand("runtests", indices=indices) - def send_runtest_all(self): + def send_runtest_all(self) -> None: self.sendcommand("runtests_all") - def send_steal(self, indices): + def send_steal(self, indices: Sequence[int]) -> None: self.sendcommand("steal", indices=indices) - def shutdown(self): + def shutdown(self) -> None: if not self._down: try: self.sendcommand("shutdown") @@ -314,16 +359,18 @@ def shutdown(self): pass self._shutdown_sent = True - def sendcommand(self, name, **kwargs): + def sendcommand(self, name: str, **kwargs: object) -> None: """Send a named parametrized command to the other side.""" self.log(f"sending command {name}(**{kwargs})") self.channel.send((name, kwargs)) - def notify_inproc(self, eventname, **kwargs): + def notify_inproc(self, eventname: str, **kwargs: object) -> None: self.log(f"queuing {eventname}(**{kwargs})") self.putevent((eventname, kwargs)) - def process_from_remote(self, eventcall): + def process_from_remote( + self, eventcall: tuple[str, dict[str, Any]] | Literal[Marker.END] + ) -> None: """This gets called for each object we receive from the other side and if the channel closes. @@ -333,7 +380,7 @@ def process_from_remote(self, eventcall): """ try: if eventcall is Marker.END: - err = self.channel._getremoteerror() + err: object | None = self.channel._getremoteerror() # type: ignore[no-untyped-call] if not self._down: if not err or isinstance(err, EOFError): err = "Not properly terminated" # lost connection? @@ -399,9 +446,8 @@ def process_from_remote(self, eventcall): self.notify_inproc("errordown", node=self, error=excinfo) -def unserialize_warning_message(data): +def unserialize_warning_message(data: dict[str, Any]) -> warnings.WarningMessage: import importlib - import warnings if data["message_module"]: mod = importlib.import_module(data["message_module"]) @@ -438,4 +484,4 @@ def unserialize_warning_message(data): continue kwargs[attr_name] = data[attr_name] - return warnings.WarningMessage(**kwargs) # type: ignore[arg-type] + return warnings.WarningMessage(**kwargs) diff --git a/testing/acceptance_test.py b/testing/acceptance_test.py index d17ddf09..3ef10cc9 100644 --- a/testing/acceptance_test.py +++ b/testing/acceptance_test.py @@ -3,6 +3,7 @@ import os import re import shutil +from typing import cast import pytest @@ -223,7 +224,7 @@ def test_crash(): assert result.ret == 1 def test_distribution_rsyncdirs_example( - self, pytester: pytest.Pytester, monkeypatch + self, pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch ) -> None: # use a custom plugin that has a custom command-line option to ensure # this is propagated to workers (see #491) @@ -415,7 +416,7 @@ def test_hello(): class TestTerminalReporting: @pytest.mark.parametrize("verbosity", ["", "-q", "-v"]) - def test_output_verbosity(self, pytester, verbosity: str) -> None: + def test_output_verbosity(self, pytester: pytest.Pytester, verbosity: str) -> None: pytester.makepyfile( """ def test_ok(): @@ -610,7 +611,7 @@ def test_hello(myarg): def test_config_initialization( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, pytestconfig + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch ) -> None: """Ensure workers and controller are initialized consistently. Integration test for #445.""" pytester.makepyfile( @@ -635,7 +636,7 @@ def test_1(request): @pytest.mark.parametrize("when", ["setup", "call", "teardown"]) -def test_crashing_item(pytester, when) -> None: +def test_crashing_item(pytester: pytest.Pytester, when: str) -> None: """Ensure crashing item is correctly reported during all testing stages.""" code = dict(setup="", call="", teardown="") code[when] = "os._exit(1)" @@ -766,7 +767,7 @@ def test_ok(): @pytest.mark.parametrize("plugin", ["xdist.looponfail"]) -def test_sub_plugins_disabled(pytester, plugin) -> None: +def test_sub_plugins_disabled(pytester: pytest.Pytester, plugin: str) -> None: """Test that xdist doesn't break if we disable any of its sub-plugins (#32).""" p1 = pytester.makepyfile( """ @@ -781,7 +782,7 @@ def test_ok(): class TestWarnings: @pytest.mark.parametrize("n", ["-n0", "-n1"]) - def test_warnings(self, pytester, n) -> None: + def test_warnings(self, pytester: pytest.Pytester, n: str) -> None: pytester.makepyfile( """ import warnings, py, pytest @@ -827,7 +828,7 @@ def test(): result.stdout.no_fnmatch_line("*this hook should not be called in this version") @pytest.mark.parametrize("n", ["-n0", "-n1"]) - def test_custom_subclass(self, pytester, n) -> None: + def test_custom_subclass(self, pytester: pytest.Pytester, n: str) -> None: """Check that warning subclasses that don't honor the args attribute don't break pytest-xdist (#344). """ @@ -851,7 +852,7 @@ def test_func(request): result.stdout.fnmatch_lines(["*MyWarning*", "*1 passed, 1 warning*"]) @pytest.mark.parametrize("n", ["-n0", "-n1"]) - def test_unserializable_arguments(self, pytester, n) -> None: + def test_unserializable_arguments(self, pytester: pytest.Pytester, n: str) -> None: """Check that warnings with unserializable arguments are handled correctly (#349).""" pytester.makepyfile( """ @@ -869,7 +870,9 @@ def test_func(tmp_path): result.stdout.fnmatch_lines(["*UserWarning*foo.txt*", "*1 passed, 1 warning*"]) @pytest.mark.parametrize("n", ["-n0", "-n1"]) - def test_unserializable_warning_details(self, pytester, n) -> None: + def test_unserializable_warning_details( + self, pytester: pytest.Pytester, n: str + ) -> None: """Check that warnings with unserializable _WARNING_DETAILS are handled correctly (#379). """ @@ -1049,7 +1052,7 @@ def test_c(): pass @pytest.mark.parametrize("n", [0, 2]) -def test_worker_id_fixture(pytester, n) -> None: +def test_worker_id_fixture(pytester: pytest.Pytester, n: int) -> None: import glob f = pytester.makepyfile( @@ -1065,8 +1068,8 @@ def test_worker_id1(worker_id, run_num): result.stdout.fnmatch_lines("* 2 passed in *") worker_ids = set() for fname in glob.glob(str(pytester.path / "*.txt")): - with open(fname) as f: - worker_ids.add(f.read().strip()) + with open(fname) as fp: + worker_ids.add(fp.read().strip()) if n == 0: assert worker_ids == {"master"} else: @@ -1074,7 +1077,7 @@ def test_worker_id1(worker_id, run_num): @pytest.mark.parametrize("n", [0, 2]) -def test_testrun_uid_fixture(pytester, n) -> None: +def test_testrun_uid_fixture(pytester: pytest.Pytester, n: int) -> None: import glob f = pytester.makepyfile( @@ -1090,14 +1093,14 @@ def test_testrun_uid1(testrun_uid, run_num): result.stdout.fnmatch_lines("* 2 passed in *") testrun_uids = set() for fname in glob.glob(str(pytester.path / "*.txt")): - with open(fname) as f: - testrun_uids.add(f.read().strip()) + with open(fname) as fp: + testrun_uids.add(fp.read().strip()) assert len(testrun_uids) == 1 assert len(testrun_uids.pop()) == 32 @pytest.mark.parametrize("tb", ["auto", "long", "short", "no", "line", "native"]) -def test_error_report_styles(pytester, tb) -> None: +def test_error_report_styles(pytester: pytest.Pytester, tb: str) -> None: pytester.makepyfile( """ import pytest @@ -1111,7 +1114,7 @@ def test_error_report_styles(): result.assert_outcomes(failed=1) -def test_color_yes_collection_on_non_atty(pytester) -> None: +def test_color_yes_collection_on_non_atty(pytester: pytest.Pytester) -> None: """Skip collect progress report when working on non-terminals. Similar to pytest-dev/pytest#1397 @@ -1133,7 +1136,7 @@ def test_this(i): assert "collecting:" not in result.stdout.str() -def test_without_terminal_plugin(pytester, request) -> None: +def test_without_terminal_plugin(pytester: pytest.Pytester) -> None: """No output when terminal plugin is disabled.""" pytester.makepyfile( """ @@ -1368,7 +1371,7 @@ def test_2(): class TestGroupScope: - def test_by_module(self, pytester: pytest.Pytester): + def test_by_module(self, pytester: pytest.Pytester) -> None: test_file = """ import pytest class TestA: @@ -1399,7 +1402,7 @@ def test(self, i): == test_b_workers_and_test_count.items() ) - def test_by_class(self, pytester: pytest.Pytester): + def test_by_class(self, pytester: pytest.Pytester) -> None: pytester.makepyfile( test_a=""" import pytest @@ -1436,7 +1439,7 @@ def test(self, i): == test_b_workers_and_test_count.items() ) - def test_module_single_start(self, pytester: pytest.Pytester): + def test_module_single_start(self, pytester: pytest.Pytester) -> None: test_file1 = """ import pytest @pytest.mark.xdist_group(name="xdist_group") @@ -1459,7 +1462,7 @@ def test_2(): assert a.keys() == b.keys() and b.keys() == c.keys() - def test_with_two_group_names(self, pytester: pytest.Pytester): + def test_with_two_group_names(self, pytester: pytest.Pytester) -> None: test_file = """ import pytest @pytest.mark.xdist_group(name="group1") @@ -1512,7 +1515,7 @@ def test_c(self): @pytest.mark.parametrize( "scope", ["each", "load", "loadscope", "loadfile", "worksteal", "no"] ) - def test_single_file(self, pytester, scope) -> None: + def test_single_file(self, pytester: pytest.Pytester, scope: str) -> None: pytester.makepyfile(test_a=self.test_file1) result = pytester.runpytest("-n2", "--dist=%s" % scope, "-v") result.assert_outcomes(passed=(12 if scope != "each" else 12 * 2)) @@ -1520,7 +1523,7 @@ def test_single_file(self, pytester, scope) -> None: @pytest.mark.parametrize( "scope", ["each", "load", "loadscope", "loadfile", "worksteal", "no"] ) - def test_multi_file(self, pytester, scope) -> None: + def test_multi_file(self, pytester: pytest.Pytester, scope: str) -> None: pytester.makepyfile( test_a=self.test_file1, test_b=self.test_file1, @@ -1564,32 +1567,32 @@ def get_workers_and_test_count_by_prefix( class TestAPI: @pytest.fixture - def fake_request(self): + def fake_request(self) -> pytest.FixtureRequest: class FakeOption: - def __init__(self): + def __init__(self) -> None: self.dist = "load" class FakeConfig: - def __init__(self): + def __init__(self) -> None: self.workerinput = {"workerid": "gw5"} self.option = FakeOption() class FakeRequest: - def __init__(self): + def __init__(self) -> None: self.config = FakeConfig() - return FakeRequest() + return cast(pytest.FixtureRequest, FakeRequest()) - def test_is_xdist_worker(self, fake_request) -> None: + def test_is_xdist_worker(self, fake_request: pytest.FixtureRequest) -> None: assert xdist.is_xdist_worker(fake_request) - del fake_request.config.workerinput + del fake_request.config.workerinput # type: ignore[attr-defined] assert not xdist.is_xdist_worker(fake_request) - def test_is_xdist_controller(self, fake_request) -> None: + def test_is_xdist_controller(self, fake_request: pytest.FixtureRequest) -> None: assert not xdist.is_xdist_master(fake_request) assert not xdist.is_xdist_controller(fake_request) - del fake_request.config.workerinput + del fake_request.config.workerinput # type: ignore[attr-defined] assert xdist.is_xdist_master(fake_request) assert xdist.is_xdist_controller(fake_request) @@ -1597,13 +1600,13 @@ def test_is_xdist_controller(self, fake_request) -> None: assert not xdist.is_xdist_master(fake_request) assert not xdist.is_xdist_controller(fake_request) - def test_get_xdist_worker_id(self, fake_request) -> None: + def test_get_xdist_worker_id(self, fake_request: pytest.FixtureRequest) -> None: assert xdist.get_xdist_worker_id(fake_request) == "gw5" - del fake_request.config.workerinput + del fake_request.config.workerinput # type: ignore[attr-defined] assert xdist.get_xdist_worker_id(fake_request) == "master" -def test_collection_crash(pytester: pytest.Pytester): +def test_collection_crash(pytester: pytest.Pytester) -> None: p1 = pytester.makepyfile( """ assert 0 @@ -1622,7 +1625,7 @@ def test_collection_crash(pytester: pytest.Pytester): ) -def test_dist_in_addopts(pytester: pytest.Pytester): +def test_dist_in_addopts(pytester: pytest.Pytester) -> None: """Users can set a default distribution in the configuration file (#789).""" pytester.makepyfile( """ diff --git a/testing/conftest.py b/testing/conftest.py index 70bdfdff..5186b8b6 100644 --- a/testing/conftest.py +++ b/testing/conftest.py @@ -1,6 +1,8 @@ from __future__ import annotations import shutil +from typing import Callable +from typing import Generator import execnet import pytest @@ -10,12 +12,14 @@ @pytest.fixture(autouse=True) -def _divert_atexit(request, monkeypatch: pytest.MonkeyPatch): +def _divert_atexit(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]: import atexit finalizers = [] - def fake_register(func, *args, **kwargs): + def fake_register( + func: Callable[..., object], *args: object, **kwargs: object + ) -> None: finalizers.append((func, args, kwargs)) monkeypatch.setattr(atexit, "register", fake_register) @@ -27,7 +31,7 @@ def fake_register(func, *args, **kwargs): func(*args, **kwargs) -def pytest_addoption(parser) -> None: +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( "--gx", action="append", @@ -37,16 +41,16 @@ def pytest_addoption(parser) -> None: @pytest.fixture -def specssh(request) -> str: +def specssh(request: pytest.FixtureRequest) -> str: return getspecssh(request.config) # configuration information for tests -def getgspecs(config) -> list[execnet.XSpec]: +def getgspecs(config: pytest.Config) -> list[execnet.XSpec]: return [execnet.XSpec(spec) for spec in config.getvalueorskip("gspecs")] -def getspecssh(config) -> str: # type: ignore[return] +def getspecssh(config: pytest.Config) -> str: xspecs = getgspecs(config) for spec in xspecs: if spec.ssh: @@ -56,7 +60,7 @@ def getspecssh(config) -> str: # type: ignore[return] pytest.skip("need '--gx ssh=...'") -def getsocketspec(config) -> execnet.XSpec: +def getsocketspec(config: pytest.Config) -> execnet.XSpec: xspecs = getgspecs(config) for spec in xspecs: if spec.socket: diff --git a/testing/test_dsession.py b/testing/test_dsession.py index 2a32a46e..e57caeb2 100644 --- a/testing/test_dsession.py +++ b/testing/test_dsession.py @@ -1,6 +1,9 @@ from __future__ import annotations +from typing import Any +from typing import cast from typing import Sequence +from typing import TYPE_CHECKING import execnet import pytest @@ -13,29 +16,38 @@ from xdist.scheduler import EachScheduling from xdist.scheduler import LoadScheduling from xdist.scheduler import WorkStealingScheduling +from xdist.workermanage import WorkerController -class MockGateway: +if TYPE_CHECKING: + BaseOfMockGateway = execnet.Gateway + BaseOfMockNode = WorkerController +else: + BaseOfMockGateway = object + BaseOfMockNode = object + + +class MockGateway(BaseOfMockGateway): def __init__(self) -> None: self._count = 0 self.id = str(self._count) self._count += 1 -class MockNode: +class MockNode(BaseOfMockNode): def __init__(self) -> None: - self.sent = [] # type: ignore[var-annotated] - self.stolen = [] # type: ignore[var-annotated] + self.sent: list[int | str] = [] + self.stolen: list[int] = [] self.gateway = MockGateway() self._shutdown = False - def send_runtest_some(self, indices) -> None: + def send_runtest_some(self, indices: Sequence[int]) -> None: self.sent.extend(indices) def send_runtest_all(self) -> None: self.sent.append("ALL") - def send_steal(self, indices) -> None: + def send_steal(self, indices: Sequence[int]) -> None: self.stolen.extend(indices) def shutdown(self) -> None: @@ -48,10 +60,9 @@ def shutting_down(self) -> bool: class TestEachScheduling: def test_schedule_load_simple(self, pytester: pytest.Pytester) -> None: - node1 = MockNode() - node2 = MockNode() config = pytester.parseconfig("--tx=2*popen") sched = EachScheduling(config) + node1, node2 = MockNode(), MockNode() sched.add_node(node1) sched.add_node(node2) collection = ["a.py::test_1"] @@ -59,7 +70,7 @@ def test_schedule_load_simple(self, pytester: pytest.Pytester) -> None: sched.add_node_collection(node1, collection) assert not sched.collection_is_completed sched.add_node_collection(node2, collection) - assert sched.collection_is_completed + assert bool(sched.collection_is_completed) assert sched.node2collection[node1] == collection assert sched.node2collection[node2] == collection sched.schedule() @@ -72,14 +83,14 @@ def test_schedule_load_simple(self, pytester: pytest.Pytester) -> None: assert sched.tests_finished def test_schedule_remove_node(self, pytester: pytest.Pytester) -> None: - node1 = MockNode() config = pytester.parseconfig("--tx=popen") sched = EachScheduling(config) + node1 = MockNode() sched.add_node(node1) collection = ["a.py::test_1"] assert not sched.collection_is_completed sched.add_node_collection(node1, collection) - assert sched.collection_is_completed + assert bool(sched.collection_is_completed) assert sched.node2collection[node1] == collection sched.schedule() assert sched.tests_finished @@ -93,15 +104,15 @@ class TestLoadScheduling: def test_schedule_load_simple(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=2*popen") sched = LoadScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2 = sched.nodes + node1, node2 = MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) collection = ["a.py::test_1", "a.py::test_2"] assert not sched.collection_is_completed sched.add_node_collection(node1, collection) assert not sched.collection_is_completed sched.add_node_collection(node2, collection) - assert sched.collection_is_completed + assert bool(sched.collection_is_completed) assert sched.node2collection[node1] == collection assert sched.node2collection[node2] == collection sched.schedule() @@ -111,15 +122,17 @@ def test_schedule_load_simple(self, pytester: pytest.Pytester) -> None: assert len(node2.sent) == 1 assert node1.sent == [0] assert node2.sent == [1] - sched.mark_test_complete(node1, node1.sent[0]) + sent10 = node1.sent[0] + assert isinstance(sent10, int) + sched.mark_test_complete(node1, sent10) assert sched.tests_finished def test_schedule_batch_size(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=2*popen") sched = LoadScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2 = sched.nodes + node1, node2 = MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) col = ["xyz"] * 6 sched.add_node_collection(node1, col) sched.add_node_collection(node2, col) @@ -144,9 +157,9 @@ def test_schedule_batch_size(self, pytester: pytest.Pytester) -> None: def test_schedule_maxchunk_none(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=2*popen") sched = LoadScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2 = sched.nodes + node1, node2 = MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) col = [f"test{i}" for i in range(16)] sched.add_node_collection(node1, col) sched.add_node_collection(node2, col) @@ -172,9 +185,9 @@ def test_schedule_maxchunk_none(self, pytester: pytest.Pytester) -> None: def test_schedule_maxchunk_1(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=2*popen", "--maxschedchunk=1") sched = LoadScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2 = sched.nodes + node1, node2 = MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) col = [f"test{i}" for i in range(16)] sched.add_node_collection(node1, col) sched.add_node_collection(node2, col) @@ -186,7 +199,9 @@ def test_schedule_maxchunk_1(self, pytester: pytest.Pytester) -> None: assert sched.node2pending[node2] == node2.sent for complete_index, first_pending in enumerate(range(5, 16)): - sched.mark_test_complete(node1, node1.sent[complete_index]) + sent_index = node1.sent[complete_index] + assert isinstance(sent_index, int) + sched.mark_test_complete(node1, sent_index) assert node1.sent == [0, 1, *range(4, first_pending)] assert node2.sent == [2, 3] assert sched.pending == list(range(first_pending, 16)) @@ -194,10 +209,10 @@ def test_schedule_maxchunk_1(self, pytester: pytest.Pytester) -> None: def test_schedule_fewer_tests_than_nodes(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=3*popen") sched = LoadScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2, node3 = sched.nodes + node1, node2, node3 = MockNode(), MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) + sched.add_node(node3) col = ["xyz"] * 2 sched.add_node_collection(node1, col) sched.add_node_collection(node2, col) @@ -215,10 +230,10 @@ def test_schedule_fewer_than_two_tests_per_node( ) -> None: config = pytester.parseconfig("--tx=3*popen") sched = LoadScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2, node3 = sched.nodes + node1, node2, node3 = MockNode(), MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) + sched.add_node(node3) col = ["xyz"] * 5 sched.add_node_collection(node1, col) sched.add_node_collection(node2, col) @@ -232,9 +247,9 @@ def test_schedule_fewer_than_two_tests_per_node( assert not sched.pending def test_add_remove_node(self, pytester: pytest.Pytester) -> None: - node = MockNode() config = pytester.parseconfig("--tx=popen") sched = LoadScheduling(config) + node = MockNode() sched.add_node(node) collection = ["test_file.py::test_func"] sched.add_node_collection(node, collection) @@ -253,18 +268,17 @@ def test_different_tests_collected(self, pytester: pytest.Pytester) -> None: class CollectHook: """Dummy hook that stores collection reports.""" - def __init__(self): - self.reports = [] + def __init__(self) -> None: + self.reports: list[pytest.CollectReport] = [] - def pytest_collectreport(self, report): + def pytest_collectreport(self, report: pytest.CollectReport) -> None: self.reports.append(report) collect_hook = CollectHook() config = pytester.parseconfig("--tx=2*popen") config.pluginmanager.register(collect_hook, "collect_hook") - node1 = MockNode() - node2 = MockNode() sched = LoadScheduling(config) + node1, node2 = MockNode(), MockNode() sched.add_node(node1) sched.add_node(node2) sched.add_node_collection(node1, ["a.py::test_1"]) @@ -272,6 +286,7 @@ def pytest_collectreport(self, report): sched.schedule() assert len(collect_hook.reports) == 1 rep = collect_hook.reports[0] + assert isinstance(rep.longrepr, str) assert "Different tests were collected between" in rep.longrepr @@ -279,15 +294,15 @@ class TestWorkStealingScheduling: def test_ideal_case(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=2*popen") sched = WorkStealingScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2 = sched.nodes + node1, node2 = MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) collection = [f"test_workstealing.py::test_{i}" for i in range(16)] assert not sched.collection_is_completed sched.add_node_collection(node1, collection) assert not sched.collection_is_completed sched.add_node_collection(node2, collection) - assert sched.collection_is_completed + assert bool(sched.collection_is_completed) assert sched.node2collection[node1] == collection assert sched.node2collection[node2] == collection sched.schedule() @@ -296,18 +311,20 @@ def test_ideal_case(self, pytester: pytest.Pytester) -> None: assert node1.sent == list(range(8)) assert node2.sent == list(range(8, 16)) for i in range(8): - sched.mark_test_complete(node1, node1.sent[i]) - sched.mark_test_complete(node2, node2.sent[i]) - assert sched.tests_finished + sent1, sent2 = node1.sent[i], node2.sent[i] + assert isinstance(sent1, int) and isinstance(sent2, int) + sched.mark_test_complete(node1, sent1) + sched.mark_test_complete(node2, sent2) + assert bool(sched.tests_finished) assert node1.stolen == [] assert node2.stolen == [] def test_stealing(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=2*popen") sched = WorkStealingScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2 = sched.nodes + node1, node2 = MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) collection = [f"test_workstealing.py::test_{i}" for i in range(16)] sched.add_node_collection(node1, collection) sched.add_node_collection(node2, collection) @@ -316,11 +333,15 @@ def test_stealing(self, pytester: pytest.Pytester) -> None: assert node1.sent == list(range(8)) assert node2.sent == list(range(8, 16)) for i in range(8): - sched.mark_test_complete(node1, node1.sent[i]) + sent = node1.sent[i] + assert isinstance(sent, int) + sched.mark_test_complete(node1, sent) assert node2.stolen == list(range(12, 16)) sched.remove_pending_tests_from_node(node2, node2.stolen) for i in range(4): - sched.mark_test_complete(node2, node2.sent[i]) + sent = node2.sent[i] + assert isinstance(sent, int) + sched.mark_test_complete(node2, sent) assert node1.stolen == [14, 15] sched.remove_pending_tests_from_node(node1, node1.stolen) sched.mark_test_complete(node1, 12) @@ -355,10 +376,10 @@ def test_steal_on_add_node(self, pytester: pytest.Pytester) -> None: def test_schedule_fewer_tests_than_nodes(self, pytester: pytest.Pytester) -> None: config = pytester.parseconfig("--tx=3*popen") sched = WorkStealingScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2, node3 = sched.nodes + node1, node2, node3 = MockNode(), MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) + sched.add_node(node3) col = ["xyz"] * 2 sched.add_node_collection(node1, col) sched.add_node_collection(node2, col) @@ -378,10 +399,10 @@ def test_schedule_fewer_than_two_tests_per_node( ) -> None: config = pytester.parseconfig("--tx=3*popen") sched = WorkStealingScheduling(config) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - sched.add_node(MockNode()) - node1, node2, node3 = sched.nodes + node1, node2, node3 = MockNode(), MockNode(), MockNode() + sched.add_node(node1) + sched.add_node(node2) + sched.add_node(node3) col = ["xyz"] * 5 sched.add_node_collection(node1, col) sched.add_node_collection(node2, col) @@ -392,11 +413,19 @@ def test_schedule_fewer_than_two_tests_per_node( assert node3.sent == [3, 4] assert not sched.pending assert not sched.tests_finished - sched.mark_test_complete(node1, node1.sent[0]) - sched.mark_test_complete(node2, node2.sent[0]) - sched.mark_test_complete(node3, node3.sent[0]) - sched.mark_test_complete(node3, node3.sent[1]) - assert sched.tests_finished + sent10 = node1.sent[0] + assert isinstance(sent10, int) + sent20 = node2.sent[0] + assert isinstance(sent20, int) + sent30 = node3.sent[0] + assert isinstance(sent30, int) + sent31 = node3.sent[1] + assert isinstance(sent31, int) + sched.mark_test_complete(node1, sent10) + sched.mark_test_complete(node2, sent20) + sched.mark_test_complete(node3, sent30) + sched.mark_test_complete(node3, sent31) + assert bool(sched.tests_finished) assert node1.stolen == [] assert node2.stolen == [] assert node3.stolen == [] @@ -416,18 +445,17 @@ def test_add_remove_node(self, pytester: pytest.Pytester) -> None: def test_different_tests_collected(self, pytester: pytest.Pytester) -> None: class CollectHook: - def __init__(self): - self.reports = [] + def __init__(self) -> None: + self.reports: list[pytest.CollectReport] = [] - def pytest_collectreport(self, report): + def pytest_collectreport(self, report: pytest.CollectReport) -> None: self.reports.append(report) collect_hook = CollectHook() config = pytester.parseconfig("--tx=2*popen") config.pluginmanager.register(collect_hook, "collect_hook") - node1 = MockNode() - node2 = MockNode() sched = WorkStealingScheduling(config) + node1, node2 = MockNode(), MockNode() sched.add_node(node1) sched.add_node(node2) sched.add_node_collection(node1, ["a.py::test_1"]) @@ -435,12 +463,13 @@ def pytest_collectreport(self, report): sched.schedule() assert len(collect_hook.reports) == 1 rep = collect_hook.reports[0] + assert isinstance(rep.longrepr, str) assert "Different tests were collected between" in rep.longrepr class TestDistReporter: @pytest.mark.xfail - def test_rsync_printing(self, pytester: pytest.Pytester, linecomp) -> None: + def test_rsync_printing(self, pytester: pytest.Pytester, linecomp: Any) -> None: config = pytester.parseconfig() from _pytest.terminal import TerminalReporter @@ -473,15 +502,17 @@ class gw2: def test_report_collection_diff_equal() -> None: """Test reporting of equal collections.""" from_collection = to_collection = ["aaa", "bbb", "ccc"] - assert report_collection_diff(from_collection, to_collection, 1, 2) is None + assert report_collection_diff(from_collection, to_collection, "1", "2") is None def test_default_max_worker_restart() -> None: - class config: + class MockConfig: class option: maxworkerrestart: str | None = None numprocesses: int = 0 + config = cast(pytest.Config, MockConfig) + assert get_default_max_worker_restart(config) is None config.option.numprocesses = 2 diff --git a/testing/test_looponfail.py b/testing/test_looponfail.py index e2fa02d7..eda0ad12 100644 --- a/testing/test_looponfail.py +++ b/testing/test_looponfail.py @@ -143,7 +143,7 @@ def test_func(): control = RemoteControl(modcol.config) control.loop_once() assert control.failures - modcol_path = modcol.path # type:ignore[attr-defined] + modcol_path = modcol.path modcol_path.write_text( textwrap.dedent( @@ -173,7 +173,7 @@ def test_func(): """ ) ) - parent = modcol.path.parent.parent # type: ignore[attr-defined] + parent = modcol.path.parent.parent monkeypatch.chdir(parent) modcol.config.args = [ str(Path(x).relative_to(parent)) for x in modcol.config.args @@ -332,7 +332,7 @@ def test_one(): remotecontrol = RemoteControl(modcol.config) orig_runsession = remotecontrol.runsession - def runsession_dups(): + def runsession_dups() -> tuple[list[str], list[str], bool]: # twisted.trial test cases may report multiple errors. failures, reports, collection_failed = orig_runsession() print(failures) diff --git a/testing/test_plugin.py b/testing/test_plugin.py index 687a3d7f..f3526670 100644 --- a/testing/test_plugin.py +++ b/testing/test_plugin.py @@ -10,7 +10,7 @@ @pytest.fixture -def monkeypatch_3_cpus(monkeypatch: pytest.MonkeyPatch): +def monkeypatch_3_cpus(monkeypatch: pytest.MonkeyPatch) -> None: """Make pytest-xdist believe the system has 3 CPUs.""" # block import monkeypatch.setitem(sys.modules, "psutil", None) @@ -128,7 +128,7 @@ def test_auto_detect_cpus_psutil( def test_auto_detect_cpus_os( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None ) -> None: from xdist.plugin import pytest_cmdline_main as check_options @@ -189,7 +189,7 @@ def pytest_xdist_auto_num_workers(config): def test_hook_auto_num_workers_none( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None ) -> None: # Returning None from a hook to skip it is pytest behavior, # but we document it so let's test it. @@ -231,7 +231,7 @@ def test_envvar_auto_num_workers( def test_envvar_auto_num_workers_warn( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None ) -> None: from xdist.plugin import pytest_cmdline_main as check_options @@ -244,7 +244,7 @@ def test_envvar_auto_num_workers_warn( def test_auto_num_workers_hook_overrides_envvar( - pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus + pytester: pytest.Pytester, monkeypatch: pytest.MonkeyPatch, monkeypatch_3_cpus: None ) -> None: from xdist.plugin import pytest_cmdline_main as check_options diff --git a/testing/test_remote.py b/testing/test_remote.py index 245f27d0..cbbf758b 100644 --- a/testing/test_remote.py +++ b/testing/test_remote.py @@ -1,36 +1,46 @@ +from __future__ import annotations + import marshal import pprint from queue import Queue import sys +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Union import uuid import execnet import pytest +from xdist.workermanage import NodeManager from xdist.workermanage import WorkerController WAIT_TIMEOUT = 10.0 -def check_marshallable(d): +def check_marshallable(d: object) -> None: try: - marshal.dumps(d) + marshal.dumps(d) # type: ignore[arg-type] except ValueError as e: pprint.pprint(d) raise ValueError("not marshallable") from e class EventCall: - def __init__(self, eventcall): + def __init__(self, eventcall: tuple[str, dict[str, Any]]) -> None: self.name, self.kwargs = eventcall - def __str__(self): + def __str__(self) -> str: return f"" class WorkerSetup: - def __init__(self, request, pytester: pytest.Pytester) -> None: + def __init__( + self, request: pytest.FixtureRequest, pytester: pytest.Pytester + ) -> None: self.request = request self.pytester = pytester self.use_callback = False @@ -47,11 +57,18 @@ class DummyMananger: testrunuid = uuid.uuid4().hex specs = [0, 1] - self.slp = WorkerController(DummyMananger, self.gateway, config, putevent) + nodemanager = cast(NodeManager, DummyMananger) + + self.slp = WorkerController( + nodemanager=nodemanager, + gateway=self.gateway, + config=config, + putevent=putevent, # type: ignore[arg-type] + ) self.request.addfinalizer(self.slp.ensure_teardown) self.slp.setup() - def popevent(self, name=None): + def popevent(self, name: str | None = None) -> EventCall: while 1: if self.use_callback: data = self.events.get(timeout=WAIT_TIMEOUT) @@ -62,27 +79,33 @@ def popevent(self, name=None): return ev print(f"skipping {ev}") - def sendcommand(self, name, **kwargs): + def sendcommand(self, name: str, **kwargs: Any) -> None: self.slp.sendcommand(name, **kwargs) @pytest.fixture -def worker(request, pytester: pytest.Pytester) -> WorkerSetup: +def worker(request: pytest.FixtureRequest, pytester: pytest.Pytester) -> WorkerSetup: return WorkerSetup(request, pytester) class TestWorkerInteractor: + UnserializerReport = Callable[ + [Dict[str, Any]], Union[pytest.CollectReport, pytest.TestReport] + ] + @pytest.fixture - def unserialize_report(self, pytestconfig): - def unserialize(data): - return pytestconfig.hook.pytest_report_from_serializable( + def unserialize_report(self, pytestconfig: pytest.Config) -> UnserializerReport: + def unserialize( + data: dict[str, Any], + ) -> pytest.CollectReport | pytest.TestReport: + return pytestconfig.hook.pytest_report_from_serializable( # type: ignore[no-any-return] config=pytestconfig, data=data ) return unserialize def test_basic_collect_and_runtests( - self, worker: WorkerSetup, unserialize_report + self, worker: WorkerSetup, unserialize_report: UnserializerReport ) -> None: worker.pytester.makepyfile( """ @@ -115,7 +138,9 @@ def test_func(): ev = worker.popevent("workerfinished") assert "workeroutput" in ev.kwargs - def test_remote_collect_skip(self, worker: WorkerSetup, unserialize_report) -> None: + def test_remote_collect_skip( + self, worker: WorkerSetup, unserialize_report: UnserializerReport + ) -> None: worker.pytester.makepyfile( """ import pytest @@ -129,11 +154,14 @@ def test_remote_collect_skip(self, worker: WorkerSetup, unserialize_report) -> N assert ev.name == "collectreport" rep = unserialize_report(ev.kwargs["data"]) assert rep.skipped + assert isinstance(rep.longrepr, tuple) assert rep.longrepr[2] == "Skipped: hello" ev = worker.popevent("collectionfinish") assert not ev.kwargs["ids"] - def test_remote_collect_fail(self, worker: WorkerSetup, unserialize_report) -> None: + def test_remote_collect_fail( + self, worker: WorkerSetup, unserialize_report: UnserializerReport + ) -> None: worker.pytester.makepyfile("""aasd qwe""") worker.setup() ev = worker.popevent("collectionstart") @@ -145,7 +173,9 @@ def test_remote_collect_fail(self, worker: WorkerSetup, unserialize_report) -> N ev = worker.popevent("collectionfinish") assert not ev.kwargs["ids"] - def test_runtests_all(self, worker: WorkerSetup, unserialize_report) -> None: + def test_runtests_all( + self, worker: WorkerSetup, unserialize_report: UnserializerReport + ) -> None: worker.pytester.makepyfile( """ def test_func(): pass @@ -205,13 +235,15 @@ def test_process_from_remote_error_handling( ) -> None: worker.use_callback = True worker.setup() - worker.slp.process_from_remote(("", ())) + worker.slp.process_from_remote(("", {})) out, err = capsys.readouterr() assert "INTERNALERROR> ValueError: unknown event: " in out ev = worker.popevent() assert ev.name == "errordown" - def test_steal_work(self, worker: WorkerSetup, unserialize_report) -> None: + def test_steal_work( + self, worker: WorkerSetup, unserialize_report: UnserializerReport + ) -> None: worker.pytester.makepyfile( """ import time @@ -262,7 +294,9 @@ def test_func4(): pass ev = worker.popevent("workerfinished") assert "workeroutput" in ev.kwargs - def test_steal_empty_queue(self, worker: WorkerSetup, unserialize_report) -> None: + def test_steal_empty_queue( + self, worker: WorkerSetup, unserialize_report: UnserializerReport + ) -> None: worker.pytester.makepyfile( """ def test_func(): pass diff --git a/testing/test_workermanage.py b/testing/test_workermanage.py index 1246911f..08b38851 100644 --- a/testing/test_workermanage.py +++ b/testing/test_workermanage.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import shutil import textwrap @@ -19,13 +21,15 @@ @pytest.fixture -def hookrecorder(request, config, pytester: pytest.Pytester): +def hookrecorder( + config: pytest.Config, pytester: pytest.Pytester +) -> pytest.HookRecorder: hookrecorder = pytester.make_hook_recorder(config.pluginmanager) return hookrecorder @pytest.fixture -def config(pytester: pytest.Pytester): +def config(pytester: pytest.Pytester) -> pytest.Config: return pytester.parseconfig() @@ -44,24 +48,23 @@ def dest(tmp_path: Path) -> Path: @pytest.fixture -def workercontroller(monkeypatch: pytest.MonkeyPatch): +def workercontroller(monkeypatch: pytest.MonkeyPatch) -> None: class MockController: - def __init__(self, *args): + def __init__(self, *args: object) -> None: pass - def setup(self): + def setup(self) -> None: pass monkeypatch.setattr(workermanage, "WorkerController", MockController) - return MockController class TestNodeManagerPopen: - def test_popen_no_default_chdir(self, config) -> None: + def test_popen_no_default_chdir(self, config: pytest.Config) -> None: gm = NodeManager(config, ["popen"]) assert gm.specs[0].chdir is None - def test_default_chdir(self, config) -> None: + def test_default_chdir(self, config: pytest.Config) -> None: specs = ["ssh=noco", "socket=xyz"] for spec in NodeManager(config, specs).specs: assert spec.chdir == "pyexecnetcache" @@ -69,10 +72,13 @@ def test_default_chdir(self, config) -> None: assert spec.chdir == "abc" def test_popen_makegateway_events( - self, config, hookrecorder, workercontroller + self, + config: pytest.Config, + hookrecorder: pytest.HookRecorder, + workercontroller: None, ) -> None: hm = NodeManager(config, ["popen"] * 2) - hm.setup_nodes(None) + hm.setup_nodes(None) # type: ignore[arg-type] call = hookrecorder.popcall("pytest_xdist_setupnodes") assert len(call.specs) == 2 @@ -86,20 +92,24 @@ def test_popen_makegateway_events( assert not len(hm.group) def test_popens_rsync( - self, config, source: Path, dest: Path, workercontroller + self, + config: pytest.Config, + source: Path, + dest: Path, + workercontroller: None, ) -> None: hm = NodeManager(config, ["popen"] * 2) - hm.setup_nodes(None) + hm.setup_nodes(None) # type: ignore[arg-type] assert len(hm.group) == 2 for gw in hm.group: class pseudoexec: args = [] # type: ignore[var-annotated] - def __init__(self, *args): + def __init__(self, *args: object) -> None: self.args.extend(args) - def waitclose(self): + def waitclose(self) -> None: pass gw.remote_exec = pseudoexec # type: ignore[assignment] @@ -112,10 +122,10 @@ def waitclose(self): assert "sys.path.insert" in gw.remote_exec.args[0] # type: ignore[attr-defined] def test_rsync_popen_with_path( - self, config, source: Path, dest: Path, workercontroller + self, config: pytest.Config, source: Path, dest: Path, workercontroller: None ) -> None: hm = NodeManager(config, ["popen//chdir=%s" % dest] * 1) - hm.setup_nodes(None) + hm.setup_nodes(None) # type: ignore[arg-type] source.joinpath("dir1", "dir2").mkdir(parents=True) source.joinpath("dir1", "dir2", "hello").touch() notifications = [] @@ -131,15 +141,15 @@ def test_rsync_popen_with_path( def test_rsync_same_popen_twice( self, - config, + config: pytest.Config, source: Path, dest: Path, - hookrecorder, - workercontroller, + hookrecorder: pytest.HookRecorder, + workercontroller: None, ) -> None: hm = NodeManager(config, ["popen//chdir=%s" % dest] * 2) hm.roots = [] - hm.setup_nodes(None) + hm.setup_nodes(None) # type: ignore[arg-type] source.joinpath("dir1", "dir2").mkdir(parents=True) source.joinpath("dir1", "dir2", "hello").touch() gw = hm.group[0] @@ -200,7 +210,11 @@ def test_rsync_roots_no_roots( assert p.joinpath("dir1", "file1").check() def test_popen_rsync_subdir( - self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller + self, + pytester: pytest.Pytester, + source: Path, + dest: Path, + workercontroller: None, ) -> None: dir1 = source / "dir1" dir1.mkdir() @@ -214,7 +228,8 @@ def test_popen_rsync_subdir( "--tx", "popen//chdir=%s" % dest, "--rsyncdir", rsyncroot, source ) ) - nodemanager.setup_nodes(None) # calls .rsync_roots() + # calls .rsync_roots() + nodemanager.setup_nodes(None) # type: ignore[arg-type] if rsyncroot == source: dest = dest.joinpath("source") assert dest.joinpath("dir1").exists() @@ -223,14 +238,19 @@ def test_popen_rsync_subdir( nodemanager.teardown_nodes() @pytest.mark.parametrize( - "flag, expects_report", [("-q", False), ("", False), ("-v", True)] + ["flag", "expects_report"], + [ + ("-q", False), + ("", False), + ("-v", True), + ], ) def test_rsync_report( self, pytester: pytest.Pytester, source: Path, dest: Path, - workercontroller, + workercontroller: None, capsys: pytest.CaptureFixture[str], flag: str, expects_report: bool, @@ -241,7 +261,8 @@ def test_rsync_report( if flag: args.append(flag) nodemanager = NodeManager(pytester.parseconfig(*args)) - nodemanager.setup_nodes(None) # calls .rsync_roots() + # calls .rsync_roots() + nodemanager.setup_nodes(None) # type: ignore[arg-type] out, _ = capsys.readouterr() if expects_report: assert "<= pytest/__init__.py" in out @@ -249,7 +270,11 @@ def test_rsync_report( assert "<= pytest/__init__.py" not in out def test_init_rsync_roots( - self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller + self, + pytester: pytest.Pytester, + source: Path, + dest: Path, + workercontroller: None, ) -> None: dir2 = source.joinpath("dir1", "dir2") dir2.mkdir(parents=True) @@ -267,13 +292,18 @@ def test_init_rsync_roots( ) config = pytester.parseconfig(source) nodemanager = NodeManager(config, ["popen//chdir=%s" % dest]) - nodemanager.setup_nodes(None) # calls .rsync_roots() + # calls .rsync_roots() + nodemanager.setup_nodes(None) # type: ignore[arg-type] assert dest.joinpath("dir2").exists() assert not dest.joinpath("dir1").exists() assert not dest.joinpath("bogus").exists() def test_rsyncignore( - self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller + self, + pytester: pytest.Pytester, + source: Path, + dest: Path, + workercontroller: None, ) -> None: dir2 = source.joinpath("dir1", "dir2") dir2.mkdir(parents=True) @@ -297,7 +327,8 @@ def test_rsyncignore( config = pytester.parseconfig(source) config.option.rsyncignore = ["bar"] nodemanager = NodeManager(config, ["popen//chdir=%s" % dest]) - nodemanager.setup_nodes(None) # calls .rsync_roots() + # calls .rsync_roots() + nodemanager.setup_nodes(None) # type: ignore[arg-type] assert dest.joinpath("dir1").exists() assert not dest.joinpath("dir1", "dir2").exists() assert dest.joinpath("dir5", "file").exists() @@ -306,14 +337,19 @@ def test_rsyncignore( assert not dest.joinpath("bar").exists() def test_optimise_popen( - self, pytester: pytest.Pytester, source: Path, dest: Path, workercontroller + self, + pytester: pytest.Pytester, + source: Path, + dest: Path, + workercontroller: None, ) -> None: specs = ["popen"] * 3 source.joinpath("conftest.py").write_text("rsyncdirs = ['a']") source.joinpath("a").mkdir() config = pytester.parseconfig(source) nodemanager = NodeManager(config, specs) - nodemanager.setup_nodes(None) # calls .rysnc_roots() + # calls .rysnc_roots() + nodemanager.setup_nodes(None) # type: ignore[arg-type] for gwspec in nodemanager.specs: assert gwspec._samefilesystem() assert not gwspec.chdir @@ -349,7 +385,7 @@ class MyWarning(UserWarning): ), ], ) -def test_unserialize_warning_msg(w_cls): +def test_unserialize_warning_msg(w_cls: type[Warning] | str) -> None: """Test that warning serialization process works well.""" # Create a test warning message with pytest.warns(UserWarning) as w: @@ -390,7 +426,7 @@ class MyWarningUnknown(UserWarning): __module__ = "unknown" -def test_warning_serialization_tweaked_module(): +def test_warning_serialization_tweaked_module() -> None: """Test for GH#404.""" # Create a test warning message with pytest.warns(UserWarning) as w: diff --git a/testing/util.py b/testing/util.py index c7bcc552..649bcce0 100644 --- a/testing/util.py +++ b/testing/util.py @@ -5,5 +5,5 @@ class MyWarning2(UserWarning): pass -def generate_warning(): +def generate_warning() -> None: warnings.warn(MyWarning2("hello"))