diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py
index 48e57399..de37678e 100644
--- a/src/syrupy/assertion.py
+++ b/src/syrupy/assertion.py
@@ -262,7 +262,8 @@ def _assert(self, data: "SerializableData") -> bool:
             )
             assertion_success = matches
             if not matches and self.update_snapshots:
-                self.extension.write_snapshot(
+                self.session.queue_snapshot_write(
+                    extension=self.extension,
                     data=serialized_data,
                     index=self.index,
                 )
@@ -297,6 +298,8 @@ def _post_assert(self) -> None:
 
     def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
         try:
-            return self.extension.read_snapshot(index=index)
+            return self.extension.read_snapshot(
+                index=index, session_id=str(id(self.session))
+            )
         except SnapshotDoesNotExist:
             return None
diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py
index ffbe42d2..ed8fa4af 100644
--- a/src/syrupy/extensions/amber/__init__.py
+++ b/src/syrupy/extensions/amber/__init__.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -46,16 +47,23 @@ def _file_extension(self) -> str:
     def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil":
         return DataSerializer.read_file(snapshot_location)
 
+    @lru_cache()
+    def __cacheable_read_snapshot(
+        self, snapshot_location: str, cache_key: str
+    ) -> "SnapshotFossil":
+        return DataSerializer.read_file(snapshot_location)
+
     def _read_snapshot_data_from_location(
-        self, snapshot_location: str, snapshot_name: str
+        self, snapshot_location: str, snapshot_name: str, session_id: str
     ) -> Optional["SerializableData"]:
-        snapshot = self._read_snapshot_fossil(snapshot_location).get(snapshot_name)
+        snapshots = self.__cacheable_read_snapshot(
+            snapshot_location=snapshot_location, cache_key=session_id
+        )
+        snapshot = snapshots.get(snapshot_name)
         return snapshot.data if snapshot else None
 
     def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None:
-        snapshot_fossil_to_update = DataSerializer.read_file(snapshot_fossil.location)
-        snapshot_fossil_to_update.merge(snapshot_fossil)
-        DataSerializer.write_file(snapshot_fossil_to_update)
+        DataSerializer.write_file(snapshot_fossil, merge=True)
 
 
 __all__ = ["AmberSnapshotExtension", "DataSerializer"]
diff --git a/src/syrupy/extensions/amber/serializer.py b/src/syrupy/extensions/amber/serializer.py
index 6dc293a5..31f00907 100644
--- a/src/syrupy/extensions/amber/serializer.py
+++ b/src/syrupy/extensions/amber/serializer.py
@@ -1,4 +1,3 @@
-import functools
 import os
 from types import (
     GeneratorType,
@@ -71,11 +70,16 @@ class DataSerializer:
     _marker_crn: str = "\r\n"
 
     @classmethod
-    def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None:
+    def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None:
         """
-        Writes the snapshot data into the snapshot file that be read later.
+        Writes the snapshot data into the snapshot file that can be read later.
""" filepath = snapshot_fossil.location + if merge: + base_snapshot = cls.read_file(filepath) + base_snapshot.merge(snapshot_fossil) + snapshot_fossil = base_snapshot + with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f: for snapshot in sorted(snapshot_fossil, key=lambda s: s.name): snapshot_data = str(snapshot.data) @@ -86,7 +90,6 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None: f.write(f"\n{cls._marker_divider}\n") @classmethod - @functools.lru_cache() def read_file(cls, filepath: str) -> "SnapshotFossil": """ Read the raw snapshot data (str) from the snapshot file into a dict diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 57298b73..e6bc0ca7 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -3,6 +3,7 @@ ABC, abstractmethod, ) +from collections import defaultdict from difflib import ndiff from gettext import gettext from itertools import zip_longest @@ -10,11 +11,13 @@ from typing import ( TYPE_CHECKING, Callable, + DefaultDict, Dict, Iterator, List, Optional, Set, + Tuple, ) from syrupy.constants import ( @@ -115,7 +118,9 @@ def discover_snapshots(self) -> "SnapshotFossils": return discovered - def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData": + def read_snapshot( + self, *, index: "SnapshotIndex", session_id: str + ) -> "SerializedData": """ Utility method for reading the contents of a snapshot assertion. Will call `_pre_read`, then perform `read` and finally `post_read`, @@ -129,7 +134,9 @@ def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData": snapshot_location = self.get_location(index=index) snapshot_name = self.get_snapshot_name(index=index) snapshot_data = self._read_snapshot_data_from_location( - snapshot_location=snapshot_location, snapshot_name=snapshot_name + snapshot_location=snapshot_location, + snapshot_name=snapshot_name, + session_id=session_id, ) if snapshot_data is None: raise SnapshotDoesNotExist() @@ -145,33 +152,66 @@ def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> N This method is _final_, do not override. You can override `_write_snapshot_fossil` in a subclass to change behaviour. """ - self._pre_write(data=data, index=index) - snapshot_location = self.get_location(index=index) - if not self.test_location.matches_snapshot_location(snapshot_location): - warning_msg = gettext( - "{line_end}Can not relate snapshot location '{}' to the test location." - "{line_end}Consider adding '{}' to the generated location." - ).format( - snapshot_location, - self.test_location.filename, - line_end="\n", - ) - warnings.warn(warning_msg) - snapshot_name = self.get_snapshot_name(index=index) - if not self.test_location.matches_snapshot_name(snapshot_name): - warning_msg = gettext( - "{line_end}Can not relate snapshot name '{}' to the test location." - "{line_end}Consider adding '{}' to the generated name." - ).format( - snapshot_name, - self.test_location.testname, - line_end="\n", - ) - warnings.warn(warning_msg) - snapshot_fossil = SnapshotFossil(location=snapshot_location) - snapshot_fossil.add(Snapshot(name=snapshot_name, data=data)) - self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil) - self._post_write(data=data, index=index) + self.write_snapshot_batch(snapshots=[(data, index)]) + + def write_snapshot_batch( + self, *, snapshots: List[Tuple["SerializedData", "SnapshotIndex"]] + ) -> None: + """ + Utility method for writing the contents of multiple snapshot assertions. 
+        Will call `_pre_write` per snapshot, then perform `write` per snapshot
+        and finally `_post_write`.
+
+        This method is _final_, do not override. You can override
+        `_write_snapshot_fossil` in a subclass to change behaviour.
+        """
+        # First we group by location since it'll let us batch by file on disk.
+        # Not as useful for single file snapshots, but useful for the standard
+        # Amber extension.
+        locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list)
+        for data, index in snapshots:
+            location = self.get_location(index=index)
+            snapshot_name = self.get_snapshot_name(index=index)
+            locations[location].append(Snapshot(name=snapshot_name, data=data))
+
+            # Is there a better place to do the pre-writes?
+            # Or can we remove the pre-write concept altogether?
+            self._pre_write(data=data, index=index)
+
+        for location, location_snapshots in locations.items():
+            snapshot_fossil = SnapshotFossil(location=location)
+
+            if not self.test_location.matches_snapshot_location(location):
+                warning_msg = gettext(
+                    "{line_end}Can not relate snapshot location '{}' "
+                    "to the test location.{line_end}"
+                    "Consider adding '{}' to the generated location."
+                ).format(
+                    location,
+                    self.test_location.filename,
+                    line_end="\n",
+                )
+                warnings.warn(warning_msg)
+
+            for snapshot in location_snapshots:
+                snapshot_fossil.add(snapshot)
+
+                if not self.test_location.matches_snapshot_name(snapshot.name):
+                    warning_msg = gettext(
+                        "{line_end}Can not relate snapshot name '{}' "
+                        "to the test location.{line_end}"
+                        "Consider adding '{}' to the generated name."
+                    ).format(
+                        snapshot.name,
+                        self.test_location.testname,
+                        line_end="\n",
+                    )
+                    warnings.warn(warning_msg)
+
+            self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil)
+
+        for data, index in snapshots:
+            self._post_write(data=data, index=index)
 
     @abstractmethod
     def delete_snapshots(
@@ -206,7 +246,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":
 
     @abstractmethod
     def _read_snapshot_data_from_location(
-        self, *, snapshot_location: str, snapshot_name: str
+        self, *, snapshot_location: str, snapshot_name: str, session_id: str
     ) -> Optional["SerializedData"]:
         """
         Get only the snapshot data from location for assertion
diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py
index eee1ebc8..e80a444f 100644
--- a/src/syrupy/extensions/single_file.py
+++ b/src/syrupy/extensions/single_file.py
@@ -77,7 +77,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil":
         return snapshot_fossil
 
     def _read_snapshot_data_from_location(
-        self, *, snapshot_location: str, snapshot_name: str
+        self, *, snapshot_location: str, snapshot_name: str, session_id: str
     ) -> Optional["SerializableData"]:
         try:
             with open(
diff --git a/src/syrupy/session.py b/src/syrupy/session.py
index 014d5e72..9f5caf54 100644
--- a/src/syrupy/session.py
+++ b/src/syrupy/session.py
@@ -13,6 +13,7 @@
     List,
     Optional,
     Set,
+    Tuple,
 )
 
 import pytest
@@ -20,6 +21,10 @@
 from .constants import EXIT_STATUS_FAIL_UNUSED
 from .data import SnapshotFossils
 from .report import SnapshotReport
+from .types import (
+    SerializedData,
+    SnapshotIndex,
+)
 
 if TYPE_CHECKING:
     from .assertion import SnapshotAssertion
@@ -43,6 +48,26 @@ class SnapshotSession:
         default_factory=lambda: defaultdict(set)
     )
 
+    _queued_snapshot_writes: Dict[
+        "AbstractSyrupyExtension", List[Tuple["SerializedData", "SnapshotIndex"]]
+    ] = field(default_factory=dict)
+
+    def queue_snapshot_write(
+        self,
+        extension: "AbstractSyrupyExtension",
+        data: "SerializedData",
"SerializedData", + index: "SnapshotIndex", + ) -> None: + queue = self._queued_snapshot_writes.get(extension, []) + queue.append((data, index)) + self._queued_snapshot_writes[extension] = queue + + def flush_snapshot_write_queue(self) -> None: + for extension, queued_write in self._queued_snapshot_writes.items(): + if queued_write: + extension.write_snapshot_batch(snapshots=queued_write) + self._queued_snapshot_writes = {} + @property def update_snapshots(self) -> bool: return bool(self.pytest_session.config.option.update_snapshots) @@ -72,6 +97,7 @@ def ran_item(self, nodeid: str) -> None: def finish(self) -> int: exitstatus = 0 + self.flush_snapshot_write_queue() self.report = SnapshotReport( base_dir=self.pytest_session.config.rootdir, collected_items=self._collected_items,