diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 06ed0003..75ec29ad 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -1,7 +1,9 @@ +from collections import defaultdict from pathlib import Path from typing import ( TYPE_CHECKING, Any, + DefaultDict, Dict, Iterable, List, @@ -34,6 +36,10 @@ class SnapshotSession: _assertions: List["SnapshotAssertion"] = attr.ib(factory=list) _extensions: Dict[str, "AbstractSyrupyExtension"] = attr.ib(factory=dict) + _locations_discovered: DefaultDict[str, Set[Any]] = attr.ib( + factory=lambda: defaultdict(set) + ) + @property def update_snapshots(self) -> bool: return bool(self._pytest_session.config.option.update_snapshots) @@ -55,6 +61,7 @@ def start(self) -> None: self._selected_items = {} self._assertions = [] self._extensions = {} + self._locations_discovered = defaultdict(set) def ran_item(self, nodeid: str) -> None: self._selected_items[nodeid] = True @@ -80,12 +87,17 @@ def finish(self) -> int: def register_request(self, assertion: "SnapshotAssertion") -> None: self._assertions.append(assertion) - discovered_extensions = { - discovered.location: assertion.extension - for discovered in assertion.extension.discover_snapshots() - if discovered.has_snapshots - } - self._extensions.update(discovered_extensions) + + test_location = assertion.extension.test_location.filepath + extension_class = assertion.extension.__class__ + if extension_class not in self._locations_discovered[test_location]: + self._locations_discovered[test_location].add(extension_class) + discovered_extensions = { + discovered.location: assertion.extension + for discovered in assertion.extension.discover_snapshots() + if discovered.has_snapshots + } + self._extensions.update(discovered_extensions) def remove_unused_snapshots( self,