diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cd0c990..31f92db 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,8 +21,7 @@ jobs: with: enable-cache: true - - name: Install dependencies - run: uv sync --python 3.12 --frozen + - run: uv sync --python 3.12 --frozen --no-dev --group lint - uses: pre-commit/action@v3.0.0 with: diff --git a/Makefile b/Makefile index f6ca7f9..81d1d72 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ sources = pytest_examples tests example .PHONY: install # Install the package, dependencies, and pre-commit for local development install: .uv - uv sync --frozen + uv sync --frozen --group lint uv run pre-commit install --install-hooks .PHONY: format # Format the code @@ -22,7 +22,17 @@ lint: .PHONY: test test: - pytest + uv run pytest + +.PHONY: test-all-python # Run tests on Python 3.9 to 3.13 +test-all-python: + UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 coverage run -p -m pytest + @uv run coverage combine + @uv run coverage report .PHONY: testcov # Run tests and collect coverage data testcov: diff --git a/pyproject.toml b/pyproject.toml index c1a9389..256306e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,8 +51,10 @@ repository = 'https://github.com/pydantic/pytest-examples' [dependency-groups] dev = [ 'coverage[toml]>=7.6.1', - 'pre-commit>=3.5.0', 'pytest-pretty>=1.2.0', +] +lint = [ + 'pre-commit>=3.5.0', 'ruff>=0.7.4', ] diff --git a/pytest_examples/eval_example.py b/pytest_examples/eval_example.py index 4d2b475..9a4d8c1 100644 --- a/pytest_examples/eval_example.py +++ b/pytest_examples/eval_example.py @@ -45,15 +45,16 @@ def set_config( ): """Set the config for lints. - :param line_length: The line length to use when wrapping print statements, defaults to 88. - :param quotes: The quote to use, defaults to "either". - :param magic_trailing_comma: If True, add a trailing comma to magic methods, defaults to True. - :param target_version: The target version to use when upgrading code, defaults to "py37". - :param upgrade: If True, upgrade the code to the target version, defaults to False. - :param isort: If True, run ruff's isort extension on the code, defaults to False. - :param ruff_line_length: In general, we disable line-length checks in ruff, to let black take care of them. - :param ruff_select: Ruff rules to select - :param ruff_ignore: Ruff rules to ignore + Args: + line_length: The line length to use when wrapping print statements, defaults to 88. + quotes: The quote to use, defaults to "either". + magic_trailing_comma: If True, add a trailing comma to magic methods, defaults to True. + target_version: The target version to use when upgrading code, defaults to "py37". + upgrade: If True, upgrade the code to the target version, defaults to False. + isort: If True, run ruff's isort extension on the code, defaults to False. + ruff_line_length: In general, we disable line-length checks in ruff, to let black take care of them. + ruff_select: Ruff rules to select + ruff_ignore: Ruff rules to ignore """ self.config = ExamplesConfig( line_length=line_length, @@ -77,16 +78,19 @@ def run( *, module_globals: dict[str, Any] | None = None, rewrite_assertions: bool = True, + call: str | None = None, ) -> dict[str, Any]: """Run the example, print is not mocked and print statements are not checked. - :param example: The example to run. - :param module_globals: The globals to use when running the example. - :param rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting. + Args: + example: The example to run. + module_globals: The globals to use when running the example. + rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting. + call: If not None, method to check for and call if it exists. """ __tracebackhide__ = True example.test_id = self._test_id - _, module_dict = self._run(example, None, module_globals, rewrite_assertions) + _, module_dict = self._run(example, None, module_globals, rewrite_assertions, call) return module_dict def run_print_check( @@ -95,16 +99,19 @@ def run_print_check( *, module_globals: dict[str, Any] | None = None, rewrite_assertions: bool = True, + call: str | None = None, ) -> dict[str, Any]: """Run the example and check print statements. - :param example: The example to run. - :param module_globals: The globals to use when running the example. - :param rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting. + Args: + example: The example to run. + module_globals: The globals to use when running the example. + rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting. + call: If not None, method to check for and call if it exists. """ __tracebackhide__ = True example.test_id = self._test_id - insert_print, module_dict = self._run(example, 'check', module_globals, rewrite_assertions) + insert_print, module_dict = self._run(example, 'check', module_globals, rewrite_assertions, call) insert_print.check_print_statements(example) return module_dict @@ -114,16 +121,19 @@ def run_print_update( *, module_globals: dict[str, Any] | None = None, rewrite_assertions: bool = True, + call: str | None = None, ) -> dict[str, Any]: """Run the example and update print statements, requires `--update-examples`. - :param example: The example to run. - :param module_globals: The globals to use when running the example. - :param rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting. + Args: + example: The example to run. + module_globals: The globals to use when running the example. + rewrite_assertions: If True, rewrite assertions in the example using pytest's assertion rewriting. + call: If not None, method to check for and call if it exists. """ __tracebackhide__ = True self._check_update(example) - insert_print, module_dict = self._run(example, 'update', module_globals, rewrite_assertions) + insert_print, module_dict = self._run(example, 'update', module_globals, rewrite_assertions, call) new_code = insert_print.updated_print_statements(example) if new_code: @@ -137,6 +147,7 @@ def _run( insert_print_statements: Literal['check', 'update', None], module_globals: dict[str, Any] | None, rewrite_assertions: bool, + call: str | None, ) -> tuple[InsertPrintStatements, dict[str, Any]]: __tracebackhide__ = True @@ -155,13 +166,21 @@ def _run( python_file = self._write_file(example) return run_code( - example, python_file, loader, self.config, enable_print_mock, self.print_callback, module_globals + example=example, + python_file=python_file, + loader=loader, + config=self.config, + enable_print_mock=enable_print_mock, + print_callback=self.print_callback, + module_globals=module_globals, + call=call, ) def lint(self, example: CodeExample) -> None: """Lint the example with black and ruff. - :param example: The example to lint. + Args: + example: The example to lint. """ self.lint_black(example) self.lint_ruff(example) @@ -169,7 +188,8 @@ def lint(self, example: CodeExample) -> None: def lint_black(self, example: CodeExample) -> None: """Lint the example using black. - :param example: The example to lint. + Args: + example: The example to lint. """ example.test_id = self._test_id try: @@ -183,7 +203,8 @@ def lint_ruff( ) -> None: """Lint the example using ruff. - :param example: The example to lint. + Args: + example: The example to lint. """ example.test_id = self._test_id try: @@ -194,7 +215,8 @@ def lint_ruff( def format(self, example: CodeExample) -> None: """Format the example with black and ruff, requires `--update-examples`. - :param example: The example to format. + Args: + example: The example to format. """ self.format_ruff(example) self.format_black(example) @@ -202,7 +224,8 @@ def format(self, example: CodeExample) -> None: def format_black(self, example: CodeExample) -> None: """Format the example using black, requires `--update-examples`. - :param example: The example to lint. + Args: + example: The example to lint. """ self._check_update(example) @@ -217,7 +240,8 @@ def format_ruff( ) -> None: """Format the example using ruff, requires `--update-examples`. - :param example: The example to lint. + Args: + example: The example to lint. """ self._check_update(example) diff --git a/pytest_examples/run_code.py b/pytest_examples/run_code.py index 6d0b710..3bcc6d0 100644 --- a/pytest_examples/run_code.py +++ b/pytest_examples/run_code.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import ast +import asyncio import dataclasses import importlib.util import inspect @@ -29,6 +30,7 @@ def run_code( + *, example: CodeExample, python_file: Path, loader: Loader | None, @@ -36,6 +38,7 @@ def run_code( enable_print_mock: bool, print_callback: Callable[[str], str] | None, module_globals: dict[str, Any] | None, + call: str | None, ) -> tuple[InsertPrintStatements, dict[str, Any]]: __tracebackhide__ = True @@ -47,10 +50,18 @@ def run_code( if module_globals: module.__dict__.update(module_globals) + try: with insert_print: sys.modules[spec.name] = module spec.loader.exec_module(module) + if call: + to_call = getattr(module, call, None) + if to_call is not None: + if inspect.iscoroutinefunction(to_call): + asyncio.run(to_call()) + else: + to_call() except KeyboardInterrupt: print('KeyboardInterrupt in example') except Exception as exc: diff --git a/tests/test_insert_print.py b/tests/test_insert_print.py index 0e583be..e1316fc 100644 --- a/tests/test_insert_print.py +++ b/tests/test_insert_print.py @@ -365,3 +365,56 @@ def test_insert_print_check_change(tmp_path, eval_example): ' print(1, 2, [3, 4, 5], "hello")\n' ' +#> 1 2 [3, 4, 5] hello\n' ) + + +def test_run_main(tmp_path, eval_example): + # note this file is no written here as it's not required + md_file = tmp_path / 'test.md' + python_code = """ +def main(): + 1 / 0 +""" + example = CodeExample.create(python_code, path=md_file) + eval_example.set_config(line_length=30) + eval_example.run_print_check(example) + + with pytest.raises(ZeroDivisionError): + eval_example.run_print_check(example, call='main') + + +def test_run_main_print(tmp_path, eval_example): + # note this file is no written here as it's not required + md_file = tmp_path / 'test.md' + python_code = """ +main_called = False + +def main(): + global main_called + main_called = True + print(1, 2, 3) + #> 1 2 3 +""" + example = CodeExample.create(python_code, path=md_file) + eval_example.set_config(line_length=30) + + module_dict = eval_example.run_print_check(example, call='main') + assert module_dict['main_called'] + + +def test_run_main_print_async(tmp_path, eval_example): + # note this file is no written here as it's not required + md_file = tmp_path / 'test.md' + python_code = """ +main_called = False + +async def main(): + global main_called + main_called = True + print(1, 2, 3) + #> 1 2 3 +""" + example = CodeExample.create(python_code, path=md_file) + eval_example.set_config(line_length=30) + + module_dict = eval_example.run_print_check(example, call='main') + assert module_dict['main_called'] diff --git a/tests/test_update.py b/tests/test_update.py index e145d4b..da66b2c 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -13,7 +13,8 @@ def test_update_files(pytester: pytest.Pytester): ``` ```py -print(["first things", "second things", "third things"]) +async def main(): + print(["first things", "second things", "third things"]) ``` """, ) @@ -45,7 +46,7 @@ def func_a(): def test_find_examples(example: CodeExample, eval_example: EvalExample): if eval_example.update_examples: eval_example.lint(example) - eval_example.run_print_update(example) + eval_example.run_print_update(example, call='main') else: eval_example.lint(example) # insert_print_statements='check' would fail here @@ -70,8 +71,9 @@ def test_find_examples(example: CodeExample, eval_example: EvalExample): ``` ```py -print(["first things", "second things", "third things"]) -#> ['first things', 'second things', 'third things'] +async def main(): + print(["first things", "second things", "third things"]) + #> ['first things', 'second things', 'third things'] ```""" ) assert ( diff --git a/uv.lock b/uv.lock index 892fd5e..4b00ce7 100644 --- a/uv.lock +++ b/uv.lock @@ -328,8 +328,10 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "coverage", extra = ["toml"] }, - { name = "pre-commit" }, { name = "pytest-pretty" }, +] +lint = [ + { name = "pre-commit" }, { name = "ruff" }, ] @@ -343,8 +345,10 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "coverage", extras = ["toml"], specifier = ">=7.6.1" }, - { name = "pre-commit", specifier = ">=3.5.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, +] +lint = [ + { name = "pre-commit", specifier = ">=3.5.0" }, { name = "ruff", specifier = ">=0.7.4" }, ]