diff --git a/.gitignore b/.gitignore index 1093e258..62c025a0 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ wheels/ *.egg MANIFEST version.txt +pip-wheel-metadata # Installer logs pip-log.txt diff --git a/src/syrupy/__init__.py b/src/syrupy/__init__.py index 784f6aaa..2f85926a 100644 --- a/src/syrupy/__init__.py +++ b/src/syrupy/__init__.py @@ -109,12 +109,15 @@ def pytest_sessionstart(session: Any) -> None: config._syrupy.start() -def pytest_collection_modifyitems(session: Any, config: Any, items: List[Any]) -> None: +def pytest_collection_modifyitems( + session: Any, config: Any, items: List["pytest.Item"] +) -> None: """ After tests are collected and before any modification is performed. https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_collection_modifyitems """ - config._syrupy._all_items.update(items) + for item in config._syrupy.filter_valid_items(items): + config._syrupy._all_items[item] = True def pytest_collection_finish(session: Any) -> None: @@ -122,7 +125,8 @@ def pytest_collection_finish(session: Any) -> None: After collection has been performed and modified. https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_collection_finish """ - session.config._syrupy._ran_items.update(session.items) + for item in session.config._syrupy.filter_valid_items(session.items): + session.config._syrupy._ran_items[item] = True def pytest_sessionfinish(session: Any, exitstatus: int) -> None: diff --git a/src/syrupy/report.py b/src/syrupy/report.py index 859d07e5..2e750578 100644 --- a/src/syrupy/report.py +++ b/src/syrupy/report.py @@ -5,13 +5,13 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - Any, + Dict, Iterator, List, - Set, ) import attr +import pytest from .data import ( Snapshot, @@ -36,8 +36,8 @@ @attr.s class SnapshotReport: base_dir: str = attr.ib() - all_items: Set[Any] = attr.ib() - ran_items: Set[Any] = attr.ib() + all_items: Dict["pytest.Item", bool] = attr.ib() + ran_items: Dict["pytest.Item", bool] = attr.ib() update_snapshots: bool = attr.ib() is_providing_paths: bool = attr.ib() is_providing_nodes: bool = attr.ib() diff --git a/src/syrupy/session.py b/src/syrupy/session.py index db868869..6c33b919 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -1,14 +1,14 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - Any, Dict, + Iterable, List, Optional, - Set, ) import attr +import pytest from .constants import EXIT_STATUS_FAIL_UNUSED from .data import SnapshotFossils @@ -28,15 +28,15 @@ class SnapshotSession: is_providing_paths: bool = attr.ib() is_providing_nodes: bool = attr.ib() report: Optional["SnapshotReport"] = attr.ib(default=None) - _all_items: Set[Any] = attr.ib(factory=set) - _ran_items: Set[Any] = attr.ib(factory=set) + _all_items: Dict["pytest.Item", bool] = attr.ib(factory=dict) + _ran_items: Dict["pytest.Item", bool] = attr.ib(factory=dict) _assertions: List["SnapshotAssertion"] = attr.ib(factory=list) _extensions: Dict[str, "AbstractSyrupyExtension"] = attr.ib(factory=dict) def start(self) -> None: self.report = None - self._all_items = set() - self._ran_items = set() + self._all_items = {} + self._ran_items = {} self._assertions = [] self._extensions = {} @@ -89,3 +89,7 @@ def remove_unused_snapshots( ) elif snapshot_location not in used_snapshot_fossils: Path(snapshot_location).unlink() + + @staticmethod + def filter_valid_items(items: List["pytest.Item"]) -> Iterable["pytest.Item"]: + return (item for item in items if isinstance(item, pytest.Function)) diff --git a/stubs/pytest.pyi b/stubs/pytest.pyi index 3dea19db..888b02b3 100644 --- a/stubs/pytest.pyi +++ b/stubs/pytest.pyi @@ -3,3 +3,6 @@ from typing import Any, Callable, TypeVar ReturnType = TypeVar("ReturnType") def fixture(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: ... + +class Function: ... +class Item: ... diff --git a/tests/test_integration_pytest_extension.py b/tests/test_integration_pytest_extension.py new file mode 100644 index 00000000..fced0a75 --- /dev/null +++ b/tests/test_integration_pytest_extension.py @@ -0,0 +1,28 @@ +from .utils import clean_output + + +def test_ignores_non_function_nodes(testdir): + conftest = """ + import pytest + + class CustomItem(pytest.Item, pytest.File): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._nodeid += "::CUSTOM" + + def runtest(self): + pass + + def pytest_collect_file(path, parent): + return CustomItem(path, parent) + """ + testcase = """ + def test_example(snapshot): + assert snapshot == 1 + """ + testdir.makepyfile(conftest=conftest) + testdir.makepyfile(test_file=testcase) + result = testdir.runpytest("test_file.py", "-v", "--snapshot-update") + result_stdout = clean_output(result.stdout.str()) + assert result.ret == 0 + assert "test_file.py::CUSTOM" in result_stdout