From dbf395a21ad11f1bd5c9d4ef7375c0b7a8ca7d43 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 23 Dec 2024 10:28:29 +1100 Subject: [PATCH 1/8] fix: check current session's pending-write queue when recalling snapshots (e.g. diffing) --- src/syrupy/assertion.py | 6 +-- src/syrupy/session.py | 37 ++++++++++++++-- tests/integration/test_snapshot_diff.py | 56 +++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 9 deletions(-) create mode 100644 tests/integration/test_snapshot_diff.py diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 3c2b89fb..1ff7b353 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -377,11 +377,7 @@ def _recall_data( ) -> Tuple[Optional["SerializableData"], bool]: try: return ( - self.extension.read_snapshot( - test_location=self.test_location, - index=index, - session_id=str(id(self.session)), - ), + self.session.recall_snapshot(self.extension, self.test_location, index), False, ) except SnapshotDoesNotExist: diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 9770948a..918aec68 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -67,17 +67,25 @@ class SnapshotSession: List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], ] = field(default_factory=dict) - def queue_snapshot_write( + def _snapshot_write_queue_key( self, extension: "AbstractSyrupyExtension", test_location: "PyTestLocation", - data: "SerializedData", index: "SnapshotIndex", - ) -> None: + ) -> Tuple[Type["AbstractSyrupyExtension"], str]: snapshot_location = extension.get_location( test_location=test_location, index=index ) - key = (extension.__class__, snapshot_location) + return (extension.__class__, snapshot_location) + + def queue_snapshot_write( + self, + extension: "AbstractSyrupyExtension", + test_location: "PyTestLocation", + data: "SerializedData", + index: "SnapshotIndex", + ) -> None: + key = self._snapshot_write_queue_key(extension, test_location, index) queue = self._queued_snapshot_writes.get(key, []) queue.append((data, test_location, index)) self._queued_snapshot_writes[key] = queue @@ -93,6 +101,27 @@ def flush_snapshot_write_queue(self) -> None: ) self._queued_snapshot_writes = {} + def recall_snapshot( + self, + extension: "AbstractSyrupyExtension", + test_location: "PyTestLocation", + index: "SnapshotIndex", + ) -> Optional["SerializedData"]: + """Find the current value of the snapshot, for this session, either a pending write or the actual snapshot.""" + + key = self._snapshot_write_queue_key(extension, test_location, index) + queue = self._queued_snapshot_writes.get(key) + if queue: + # find the last (i.e. most recent) write to this index/location in the queue: + for queue_data, queue_test_location, queue_index in reversed(queue): + if queue_index == index and queue_test_location == test_location: + return queue_data + + # no queue, or no matching write, so just read the snapshot directly: + return extension.read_snapshot( + test_location=test_location, index=index, session_id=str(id(self)) + ) + @property def update_snapshots(self) -> bool: return bool(self.pytest_session.config.option.update_snapshots) diff --git a/tests/integration/test_snapshot_diff.py b/tests/integration/test_snapshot_diff.py new file mode 100644 index 00000000..302b4d94 --- /dev/null +++ b/tests/integration/test_snapshot_diff.py @@ -0,0 +1,56 @@ +import pytest + +_TEST = """ +def test_foo(snapshot): + assert {**base} == snapshot(name="a") + assert {**base, **extra} == snapshot(name="b", diff="a") +""" + + +def _make_file(testdir, base, extra): + testdir.makepyfile( + test_file="\n\n".join([f"base = {base!r}", f"extra = {extra!r}", _TEST]) + ) + + +def _run_test(testdir, base, extra, expected_update_lines): + _make_file(testdir, base=base, extra=extra) + + # Run with --snapshot-update, to generate/update snapshots: + result = testdir.runpytest( + "-v", + "--snapshot-update", + ) + result.stdout.re_match_lines((expected_update_lines,)) + assert result.ret == 0 + + # Run without --snapshot-update, to validate the snapshots are actually up-to-date + result = testdir.runpytest("-v") + result.stdout.re_match_lines((r"2 snapshots passed\.",)) + assert result.ret == 0 + + +def test_diff_lifecycle(testdir) -> pytest.Testdir: + # first: create both snapshots completely from scratch + _run_test( + testdir, + base={"A": 1}, + extra={"X": 10}, + expected_update_lines=r"2 snapshots generated\.", + ) + + # second: edit the base data, to change the data for both snapshots (only changes the serialized output for the base snapshot `a`). + _run_test( + testdir, + base={"A": 1, "B": 2}, + extra={"X": 10}, + expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.", + ) + + # third: edit just the extra data (only changes the serialized output for the diff snapshot `b`) + _run_test( + testdir, + base={"A": 1, "B": 2}, + extra={"X": 10, "Y": 20}, + expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.", + ) From 0fb10fd34de5b90a86da7747cc9e8c56f1e968b2 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 13 Jan 2025 11:00:13 +1100 Subject: [PATCH 2/8] Make PyTestLocation hashable --- src/syrupy/location.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/syrupy/location.py b/src/syrupy/location.py index 0f955bb8..bfa3c91f 100644 --- a/src/syrupy/location.py +++ b/src/syrupy/location.py @@ -13,7 +13,7 @@ from syrupy.constants import PYTEST_NODE_SEP -@dataclass +@dataclass(frozen=True) class PyTestLocation: item: "pytest.Item" nodename: Optional[str] = field(init=False) @@ -23,27 +23,41 @@ class PyTestLocation: filepath: str = field(init=False) def __post_init__(self) -> None: + # NB. we're in a frozen dataclass, but need to transform the values that the caller + # supplied... we do so by (ab)using object.__setattr__ to forcibly set the attributes. (See + # rejected PEP-0712 for an example of a better way to handle this.) + # + # This is safe because this all happens during initialization: `self` hasn't been hashed + # (or, e.g., stored in a dict), so the mutation won't be noticed. if self.is_doctest: return self.__attrs_post_init_doc__() self.__attrs_post_init_def__() def __attrs_post_init_def__(self) -> None: node_path: Path = getattr(self.item, "path") # noqa: B009 - self.filepath = str(node_path.absolute()) + # See __post_init__ for discussion of object.__setattr__ + object.__setattr__(self, "filepath", str(node_path.absolute())) obj = getattr(self.item, "obj") # noqa: B009 - self.modulename = obj.__module__ - self.methodname = obj.__name__ - self.nodename = getattr(self.item, "name", None) - self.testname = self.nodename or self.methodname + object.__setattr__(self, "modulename", obj.__module__) + object.__setattr__(self, "methodname", obj.__name__) + object.__setattr__(self, "nodename", getattr(self.item, "name", None)) + object.__setattr__(self, "testname", self.nodename or self.methodname) def __attrs_post_init_doc__(self) -> None: doctest = getattr(self.item, "dtest") # noqa: B009 - self.filepath = doctest.filename + # See __post_init__ for discussion of object.__setattr__ + object.__setattr__(self, "filepath", doctest.filename) test_relfile, test_node = self.nodeid.split(PYTEST_NODE_SEP) test_relpath = Path(test_relfile) - self.modulename = ".".join([*test_relpath.parent.parts, test_relpath.stem]) - self.nodename = test_node.replace(f"{self.modulename}.", "") - self.testname = self.nodename or self.methodname + object.__setattr__( + self, + "modulename", + ".".join([*test_relpath.parent.parts, test_relpath.stem]), + ) + object.__setattr__( + self, "nodename", test_node.replace(f"{self.modulename}.", "") + ) + object.__setattr__(self, "testname", self.nodename or self.methodname) @property def classname(self) -> Optional[str]: From 0f694ca9db612d05e773f4cdc01f383f7a95d5ca Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 13 Jan 2025 11:44:54 +1100 Subject: [PATCH 3/8] Explicitly set methodname to None for doctests ----------------------------------------------------------------------------------- benchmark: 3 tests ----------------------------------------------------------------------------------- Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ test_1000x_reads 666.9710 (1.0) 748.6652 (1.0) 705.2418 (1.0) 37.2862 (1.0) 703.0552 (1.0) 70.1912 (1.07) 2;0 1.4180 (1.0) 5 1 test_standard 669.7840 (1.00) 843.3747 (1.13) 733.8905 (1.04) 68.2257 (1.83) 705.8282 (1.00) 85.6269 (1.30) 1;0 1.3626 (0.96) 5 1 test_1000x_writes 793.8229 (1.19) 937.1953 (1.25) 850.9716 (1.21) 54.4067 (1.46) 847.3260 (1.21) 65.9041 (1.0) 2;0 1.1751 (0.83) 5 1 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ --- src/syrupy/location.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/syrupy/location.py b/src/syrupy/location.py index bfa3c91f..8a85d0b9 100644 --- a/src/syrupy/location.py +++ b/src/syrupy/location.py @@ -54,6 +54,7 @@ def __attrs_post_init_doc__(self) -> None: "modulename", ".".join([*test_relpath.parent.parts, test_relpath.stem]), ) + object.__setattr__(self, "methodname", None) object.__setattr__( self, "nodename", test_node.replace(f"{self.modulename}.", "") ) From c577ba6b5a74b4a5c6e99c89af4411cc5c0e117f Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 13 Jan 2025 11:34:17 +1100 Subject: [PATCH 4/8] Queue writes with a dict for O(1) look-ups Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_1000x_reads 625.5781 (1.0) 887.4346 (1.0) 694.6221 (1.0) 109.0048 (1.0) 658.3128 (1.0) 87.7517 (1.0) 1;1 1.4396 (1.0) 5 1 test_1000x_writes 637.3099 (1.02) 1,021.0924 (1.15) 812.9789 (1.17) 150.2342 (1.38) 757.7635 (1.15) 215.9572 (2.46) 2;0 1.2300 (0.85) 5 1 test_standard 694.1814 (1.11) 1,037.9224 (1.17) 845.1463 (1.22) 136.2068 (1.25) 785.6973 (1.19) 194.9636 (2.22) 2;0 1.1832 (0.82) 5 1 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- --- src/syrupy/session.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 918aec68..ebec30d3 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -64,7 +64,7 @@ class SnapshotSession: _queued_snapshot_writes: Dict[ Tuple[Type["AbstractSyrupyExtension"], str], - List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], + Dict[Tuple["PyTestLocation", "SnapshotIndex"], "SerializedData"], ] = field(default_factory=dict) def _snapshot_write_queue_key( @@ -86,8 +86,8 @@ def queue_snapshot_write( index: "SnapshotIndex", ) -> None: key = self._snapshot_write_queue_key(extension, test_location, index) - queue = self._queued_snapshot_writes.get(key, []) - queue.append((data, test_location, index)) + queue = self._queued_snapshot_writes.get(key, {}) + queue[(test_location, index)] = data self._queued_snapshot_writes[key] = queue def flush_snapshot_write_queue(self) -> None: @@ -97,7 +97,11 @@ def flush_snapshot_write_queue(self) -> None: ), queued_write in self._queued_snapshot_writes.items(): if queued_write: extension_class.write_snapshot( - snapshot_location=snapshot_location, snapshots=queued_write + snapshot_location=snapshot_location, + snapshots=[ + (data, loc, index) + for (loc, index), data in queued_write.items() + ], ) self._queued_snapshot_writes = {} @@ -110,12 +114,10 @@ def recall_snapshot( """Find the current value of the snapshot, for this session, either a pending write or the actual snapshot.""" key = self._snapshot_write_queue_key(extension, test_location, index) - queue = self._queued_snapshot_writes.get(key) - if queue: - # find the last (i.e. most recent) write to this index/location in the queue: - for queue_data, queue_test_location, queue_index in reversed(queue): - if queue_index == index and queue_test_location == test_location: - return queue_data + queue = self._queued_snapshot_writes.get(key, {}) + data = queue.get((test_location, index)) + if data is not None: + return data # no queue, or no matching write, so just read the snapshot directly: return extension.read_snapshot( From 55c4995c7f9467e68a32f59efd8131bca80d8296 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 13 Jan 2025 16:13:42 +1100 Subject: [PATCH 5/8] Use type aliases --- src/syrupy/session.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/syrupy/session.py b/src/syrupy/session.py index ebec30d3..f86883aa 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -46,6 +46,10 @@ class ItemStatus(Enum): SKIPPED = "skipped" +_QueuedWriteExtensionKey = Tuple[Type["AbstractSyrupyExtension"], str] +_QueuedWriteTestLocationKey = Tuple["PyTestLocation", "SnapshotIndex"] + + @dataclass class SnapshotSession: pytest_session: "pytest.Session" @@ -63,8 +67,8 @@ class SnapshotSession: ) _queued_snapshot_writes: Dict[ - Tuple[Type["AbstractSyrupyExtension"], str], - Dict[Tuple["PyTestLocation", "SnapshotIndex"], "SerializedData"], + _QueuedWriteExtensionKey, + Dict[_QueuedWriteTestLocationKey, "SerializedData"], ] = field(default_factory=dict) def _snapshot_write_queue_key( @@ -72,7 +76,7 @@ def _snapshot_write_queue_key( extension: "AbstractSyrupyExtension", test_location: "PyTestLocation", index: "SnapshotIndex", - ) -> Tuple[Type["AbstractSyrupyExtension"], str]: + ) -> _QueuedWriteExtensionKey: snapshot_location = extension.get_location( test_location=test_location, index=index ) From 9a3c9f0adf348509a6ac44114428049772735851 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 13 Jan 2025 16:15:07 +1100 Subject: [PATCH 6/8] return both keys from _snapshot_write_queue_key --- src/syrupy/session.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/syrupy/session.py b/src/syrupy/session.py index f86883aa..970718d6 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -71,16 +71,16 @@ class SnapshotSession: Dict[_QueuedWriteTestLocationKey, "SerializedData"], ] = field(default_factory=dict) - def _snapshot_write_queue_key( + def _snapshot_write_queue_keys( self, extension: "AbstractSyrupyExtension", test_location: "PyTestLocation", index: "SnapshotIndex", - ) -> _QueuedWriteExtensionKey: + ) -> Tuple[_QueuedWriteExtensionKey, _QueuedWriteTestLocationKey]: snapshot_location = extension.get_location( test_location=test_location, index=index ) - return (extension.__class__, snapshot_location) + return (extension.__class__, snapshot_location), (test_location, index) def queue_snapshot_write( self, @@ -89,10 +89,12 @@ def queue_snapshot_write( data: "SerializedData", index: "SnapshotIndex", ) -> None: - key = self._snapshot_write_queue_key(extension, test_location, index) - queue = self._queued_snapshot_writes.get(key, {}) - queue[(test_location, index)] = data - self._queued_snapshot_writes[key] = queue + ext_key, loc_key = self._snapshot_write_queue_keys( + extension, test_location, index + ) + queue = self._queued_snapshot_writes.get(ext_key, {}) + queue[loc_key] = data + self._queued_snapshot_writes[ext_key] = queue def flush_snapshot_write_queue(self) -> None: for ( @@ -117,9 +119,11 @@ def recall_snapshot( ) -> Optional["SerializedData"]: """Find the current value of the snapshot, for this session, either a pending write or the actual snapshot.""" - key = self._snapshot_write_queue_key(extension, test_location, index) - queue = self._queued_snapshot_writes.get(key, {}) - data = queue.get((test_location, index)) + ext_key, loc_key = self._snapshot_write_queue_keys( + extension, test_location, index + ) + queue = self._queued_snapshot_writes.get(ext_key, {}) + data = queue.get(loc_key) if data is not None: return data From 817de5217f64aa7af9e6dba7ba60e9942957ff6a Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 13 Jan 2025 16:19:08 +1100 Subject: [PATCH 7/8] Use a defaultdict --- src/syrupy/session.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 970718d6..90739c19 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -66,10 +66,10 @@ class SnapshotSession: default_factory=lambda: defaultdict(set) ) - _queued_snapshot_writes: Dict[ + _queued_snapshot_writes: DefaultDict[ _QueuedWriteExtensionKey, Dict[_QueuedWriteTestLocationKey, "SerializedData"], - ] = field(default_factory=dict) + ] = field(default_factory=lambda: defaultdict(dict)) def _snapshot_write_queue_keys( self, @@ -92,9 +92,7 @@ def queue_snapshot_write( ext_key, loc_key = self._snapshot_write_queue_keys( extension, test_location, index ) - queue = self._queued_snapshot_writes.get(ext_key, {}) - queue[loc_key] = data - self._queued_snapshot_writes[ext_key] = queue + self._queued_snapshot_writes[ext_key][loc_key] = data def flush_snapshot_write_queue(self) -> None: for ( @@ -109,7 +107,7 @@ def flush_snapshot_write_queue(self) -> None: for (loc, index), data in queued_write.items() ], ) - self._queued_snapshot_writes = {} + self._queued_snapshot_writes.clear() def recall_snapshot( self, @@ -122,8 +120,7 @@ def recall_snapshot( ext_key, loc_key = self._snapshot_write_queue_keys( extension, test_location, index ) - queue = self._queued_snapshot_writes.get(ext_key, {}) - data = queue.get(loc_key) + data = self._queued_snapshot_writes[ext_key].get(loc_key) if data is not None: return data From a86f5d82c2cad6e216ac59b5688044003322363f Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Mon, 13 Jan 2025 16:21:33 +1100 Subject: [PATCH 8/8] Update comments --- src/syrupy/session.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 90739c19..cba70a65 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -66,6 +66,13 @@ class SnapshotSession: default_factory=lambda: defaultdict(set) ) + # For performance, we buffer snapshot writes in memory before flushing them to disk. In + # particular, we want to be able to write to a file on disk only once, rather than having to + # repeatedly rewrite it. + # + # That batching leads to using two layers of dicts here: the outer layer represents the + # extension/file-location pair that will be written, and the inner layer represents the + # snapshots within that, "indexed" to allow efficient recall. _queued_snapshot_writes: DefaultDict[ _QueuedWriteExtensionKey, Dict[_QueuedWriteTestLocationKey, "SerializedData"], @@ -124,7 +131,7 @@ def recall_snapshot( if data is not None: return data - # no queue, or no matching write, so just read the snapshot directly: + # No matching write queued, so just read the snapshot directly: return extension.read_snapshot( test_location=test_location, index=index, session_id=str(id(self)) )