Skip to content

Commit

Permalink
Add PosArgT typing to trio.run (#3022)
Browse files Browse the repository at this point in the history
* add PosArgT typing to run()

* add type tests
  • Loading branch information
jakkdl authored Jun 26, 2024
1 parent a7db0e4 commit 80eec96
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2184,8 +2184,8 @@ def setup_runner(


def run(
async_fn: Callable[..., Awaitable[RetT]],
*args: object,
async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]],
*args: Unpack[PosArgT],
clock: Clock | None = None,
instruments: Sequence[Instrument] = (),
restrict_keyboard_interrupt_to_checkpoints: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def trivial(x: T) -> T:

with pytest.raises(TypeError):
# Missing an argument
_core.run(trivial)
_core.run(trivial) # type: ignore[arg-type]

with pytest.raises(TypeError):
# Not an async function
Expand Down
46 changes: 46 additions & 0 deletions src/trio/_core/_tests/type_tests/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from typing import Sequence, overload

import trio
from typing_extensions import assert_type


async def sleep_sort(values: Sequence[float]) -> list[float]:
return [1]


async def has_optional(arg: int | None = None) -> int:
return 5


@overload
async def foo_overloaded(arg: int) -> str: ...


@overload
async def foo_overloaded(arg: str) -> int: ...


async def foo_overloaded(arg: int | str) -> int | str:
if isinstance(arg, str):
return 5
return "hello"


v = trio.run(
sleep_sort, (1, 3, 5, 2, 4), clock=trio.testing.MockClock(autojump_threshold=0)
)
assert_type(v, "list[float]")
trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type]
trio.run(sleep_sort) # type: ignore[arg-type]

r = trio.run(has_optional)
assert_type(r, int)
r = trio.run(has_optional, 5)
trio.run(has_optional, 7, 8) # type: ignore[arg-type]
trio.run(has_optional, "hello") # type: ignore[arg-type]


assert_type(trio.run(foo_overloaded, 5), str)
assert_type(trio.run(foo_overloaded, ""), int)

0 comments on commit 80eec96

Please sign in to comment.