Skip to content

Commit

Permalink
support for calling a function/coroutine, e.g. main within examples (
Browse files Browse the repository at this point in the history
…#39)

* call main in examples

* option to call a function
  • Loading branch information
samuelcolvin authored Nov 15, 2024
1 parent 3e916e6 commit be12278
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 39 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ jobs:
with:
enable-cache: true

- name: Install dependencies
run: uv sync --python 3.12 --frozen
- run: uv sync --python 3.12 --frozen --no-dev --group lint

- uses: pre-commit/action@v3.0.0
with:
Expand Down
14 changes: 12 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ sources = pytest_examples tests example

.PHONY: install # Install the package, dependencies, and pre-commit for local development
install: .uv
uv sync --frozen
uv sync --frozen --group lint
uv run pre-commit install --install-hooks

.PHONY: format # Format the code
Expand All @@ -22,7 +22,17 @@ lint:

.PHONY: test
test:
pytest
uv run pytest

.PHONY: test-all-python # Run tests on Python 3.9 to 3.13
test-all-python:
UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 coverage run -p -m pytest
UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 coverage run -p -m pytest
UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 coverage run -p -m pytest
UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 coverage run -p -m pytest
UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 coverage run -p -m pytest
@uv run coverage combine
@uv run coverage report

.PHONY: testcov # Run tests and collect coverage data
testcov:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ repository = 'https://github.com/pydantic/pytest-examples'
[dependency-groups]
dev = [
'coverage[toml]>=7.6.1',
'pre-commit>=3.5.0',
'pytest-pretty>=1.2.0',
]
lint = [
'pre-commit>=3.5.0',
'ruff>=0.7.4',
]

Expand Down
80 changes: 52 additions & 28 deletions pytest_examples/eval_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,16 @@ def set_config(
):
"""Set the config for lints.
:param line_length: The line length to use when wrapping print statements, defaults to 88.
:param quotes: The quote to use, defaults to "either".
:param magic_trailing_comma: If True, add a trailing comma to magic methods, defaults to True.
:param target_version: The target version to use when upgrading code, defaults to "py37".
:param upgrade: If True, upgrade the code to the target version, defaults to False.
:param isort: If True, run ruff's isort extension on the code, defaults to False.
:param ruff_line_length: In general, we disable line-length checks in ruff, to let black take care of them.
:param ruff_select: Ruff rules to select
:param ruff_ignore: Ruff rules to ignore
Args:
line_length: The line length to use when wrapping print statements, defaults to 88.
quotes: The quote to use, defaults to "either".
magic_trailing_comma: If True, add a trailing comma to magic methods, defaults to True.
target_version: The target version to use when upgrading code, defaults to "py37".
upgrade: If True, upgrade the code to the target version, defaults to False.
isort: If True, run ruff's isort extension on the code, defaults to False.
ruff_line_length: In general, we disable line-length checks in ruff, to let black take care of them.
ruff_select: Ruff rules to select
ruff_ignore: Ruff rules to ignore
"""
self.config = ExamplesConfig(
line_length=line_length,
Expand All @@ -77,16 +78,19 @@ def run(
*,
module_globals: dict[str, Any] | None = None,
rewrite_assertions: bool = True,
call: str | None = None,
) -> dict[str, Any]:
"""Run the example, print is not mocked and print statements are not checked.
:param example: The example to run.
:param module_globals: The globals to use when running the example.
:param rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting.
Args:
example: The example to run.
module_globals: The globals to use when running the example.
rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting.
call: If not None, method to check for and call if it exists.
"""
__tracebackhide__ = True
example.test_id = self._test_id
_, module_dict = self._run(example, None, module_globals, rewrite_assertions)
_, module_dict = self._run(example, None, module_globals, rewrite_assertions, call)
return module_dict

def run_print_check(
Expand All @@ -95,16 +99,19 @@ def run_print_check(
*,
module_globals: dict[str, Any] | None = None,
rewrite_assertions: bool = True,
call: str | None = None,
) -> dict[str, Any]:
"""Run the example and check print statements.
:param example: The example to run.
:param module_globals: The globals to use when running the example.
:param rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting.
Args:
example: The example to run.
module_globals: The globals to use when running the example.
rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting.
call: If not None, method to check for and call if it exists.
"""
__tracebackhide__ = True
example.test_id = self._test_id
insert_print, module_dict = self._run(example, 'check', module_globals, rewrite_assertions)
insert_print, module_dict = self._run(example, 'check', module_globals, rewrite_assertions, call)
insert_print.check_print_statements(example)
return module_dict

Expand All @@ -114,16 +121,19 @@ def run_print_update(
*,
module_globals: dict[str, Any] | None = None,
rewrite_assertions: bool = True,
call: str | None = None,
) -> dict[str, Any]:
"""Run the example and update print statements, requires `--update-examples`.
:param example: The example to run.
:param module_globals: The globals to use when running the example.
:param rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting.
Args:
example: The example to run.
module_globals: The globals to use when running the example.
rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting.
call: If not None, method to check for and call if it exists.
"""
__tracebackhide__ = True
self._check_update(example)
insert_print, module_dict = self._run(example, 'update', module_globals, rewrite_assertions)
insert_print, module_dict = self._run(example, 'update', module_globals, rewrite_assertions, call)

new_code = insert_print.updated_print_statements(example)
if new_code:
Expand All @@ -137,6 +147,7 @@ def _run(
insert_print_statements: Literal['check', 'update', None],
module_globals: dict[str, Any] | None,
rewrite_assertions: bool,
call: str | None,
) -> tuple[InsertPrintStatements, dict[str, Any]]:
__tracebackhide__ = True

Expand All @@ -155,21 +166,30 @@ def _run(

python_file = self._write_file(example)
return run_code(
example, python_file, loader, self.config, enable_print_mock, self.print_callback, module_globals
example=example,
python_file=python_file,
loader=loader,
config=self.config,
enable_print_mock=enable_print_mock,
print_callback=self.print_callback,
module_globals=module_globals,
call=call,
)

def lint(self, example: CodeExample) -> None:
"""Lint the example with black and ruff.
:param example: The example to lint.
Args:
example: The example to lint.
"""
self.lint_black(example)
self.lint_ruff(example)

def lint_black(self, example: CodeExample) -> None:
"""Lint the example using black.
:param example: The example to lint.
Args:
example: The example to lint.
"""
example.test_id = self._test_id
try:
Expand All @@ -183,7 +203,8 @@ def lint_ruff(
) -> None:
"""Lint the example using ruff.
:param example: The example to lint.
Args:
example: The example to lint.
"""
example.test_id = self._test_id
try:
Expand All @@ -194,15 +215,17 @@ def lint_ruff(
def format(self, example: CodeExample) -> None:
"""Format the example with black and ruff, requires `--update-examples`.
:param example: The example to format.
Args:
example: The example to format.
"""
self.format_ruff(example)
self.format_black(example)

def format_black(self, example: CodeExample) -> None:
"""Format the example using black, requires `--update-examples`.
:param example: The example to lint.
Args:
example: The example to lint.
"""
self._check_update(example)

Expand All @@ -217,7 +240,8 @@ def format_ruff(
) -> None:
"""Format the example using ruff, requires `--update-examples`.
:param example: The example to lint.
Args:
example: The example to lint.
"""
self._check_update(example)

Expand Down
11 changes: 11 additions & 0 deletions pytest_examples/run_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import ast
import asyncio
import dataclasses
import importlib.util
import inspect
Expand Down Expand Up @@ -29,13 +30,15 @@


def run_code(
*,
example: CodeExample,
python_file: Path,
loader: Loader | None,
config: ExamplesConfig,
enable_print_mock: bool,
print_callback: Callable[[str], str] | None,
module_globals: dict[str, Any] | None,
call: str | None,
) -> tuple[InsertPrintStatements, dict[str, Any]]:
__tracebackhide__ = True

Expand All @@ -47,10 +50,18 @@ def run_code(

if module_globals:
module.__dict__.update(module_globals)

try:
with insert_print:
sys.modules[spec.name] = module
spec.loader.exec_module(module)
if call:
to_call = getattr(module, call, None)
if to_call is not None:
if inspect.iscoroutinefunction(to_call):
asyncio.run(to_call())
else:
to_call()
except KeyboardInterrupt:
print('KeyboardInterrupt in example')
except Exception as exc:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_insert_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,56 @@ def test_insert_print_check_change(tmp_path, eval_example):
' print(1, 2, [3, 4, 5], "hello")\n'
' +#> 1 2 [3, 4, 5] hello\n'
)


def test_run_main(tmp_path, eval_example):
# note this file is no written here as it's not required
md_file = tmp_path / 'test.md'
python_code = """
def main():
1 / 0
"""
example = CodeExample.create(python_code, path=md_file)
eval_example.set_config(line_length=30)
eval_example.run_print_check(example)

with pytest.raises(ZeroDivisionError):
eval_example.run_print_check(example, call='main')


def test_run_main_print(tmp_path, eval_example):
# note this file is no written here as it's not required
md_file = tmp_path / 'test.md'
python_code = """
main_called = False
def main():
global main_called
main_called = True
print(1, 2, 3)
#> 1 2 3
"""
example = CodeExample.create(python_code, path=md_file)
eval_example.set_config(line_length=30)

module_dict = eval_example.run_print_check(example, call='main')
assert module_dict['main_called']


def test_run_main_print_async(tmp_path, eval_example):
# note this file is no written here as it's not required
md_file = tmp_path / 'test.md'
python_code = """
main_called = False
async def main():
global main_called
main_called = True
print(1, 2, 3)
#> 1 2 3
"""
example = CodeExample.create(python_code, path=md_file)
eval_example.set_config(line_length=30)

module_dict = eval_example.run_print_check(example, call='main')
assert module_dict['main_called']
10 changes: 6 additions & 4 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def test_update_files(pytester: pytest.Pytester):
```
```py
print(["first things", "second things", "third things"])
async def main():
print(["first things", "second things", "third things"])
```
""",
)
Expand Down Expand Up @@ -45,7 +46,7 @@ def func_a():
def test_find_examples(example: CodeExample, eval_example: EvalExample):
if eval_example.update_examples:
eval_example.lint(example)
eval_example.run_print_update(example)
eval_example.run_print_update(example, call='main')
else:
eval_example.lint(example)
# insert_print_statements='check' would fail here
Expand All @@ -70,8 +71,9 @@ def test_find_examples(example: CodeExample, eval_example: EvalExample):
```
```py
print(["first things", "second things", "third things"])
#> ['first things', 'second things', 'third things']
async def main():
print(["first things", "second things", "third things"])
#> ['first things', 'second things', 'third things']
```"""
)
assert (
Expand Down
8 changes: 6 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit be12278

Please sign in to comment.