From d8a4ff725548df5c4fe65698eb7e1e5216cf0be5 Mon Sep 17 00:00:00 2001 From: noahnu Date: Sat, 28 May 2022 18:55:59 -0400 Subject: [PATCH] fix: defer snapshot writes until end of session --- CONTRIBUTING.md | 5 ++ poetry.lock | 30 ++++++- pyproject.toml | 3 +- src/syrupy/assertion.py | 7 +- src/syrupy/extensions/amber/__init__.py | 18 ++-- src/syrupy/extensions/amber/serializer.py | 11 ++- src/syrupy/extensions/base.py | 100 +++++++++++++++------- src/syrupy/extensions/single_file.py | 2 +- src/syrupy/session.py | 26 ++++++ tasks/test.py | 14 ++- 10 files changed, 170 insertions(+), 46 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4c290028..dc9136bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -22,6 +22,7 @@ These are mostly guidelines, not rules. Use your best judgment, and feel free to - [Suggesting Enhancements](#suggesting-enhancements) - [Your First Code Contribution](#your-first-code-contribution) - [Pull Requests](#pull-requests) +- [Debugging](#debugging) [Styleguides](#styleguides) @@ -86,6 +87,10 @@ Creating a pull request uses our template using the GitHub web interface. Fill in the relevant sections, clearly linking the issue the change is attemping to resolve. +### Debugging + +`debugpy` is installed in local development. A VSCode launch config is provided. Run `inv test -v -d` to enable the debugger (`-d` for debug). It'll then wait for you to attach your VSCode debugging client. + ## Styleguides ### Commit Messages diff --git a/poetry.lock b/poetry.lock index fd9e4087..da3775d6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -184,6 +184,14 @@ sdist = ["setuptools_rust (>=0.11.4)"] ssh = ["bcrypt (>=3.1.5)"] test = ["pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-subtests", "pytest-xdist", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,!=3.79.2)"] +[[package]] +name = "debugpy" +version = "1.6.0" +description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" +optional = false +python-versions = ">=3.7" + [[package]] name = "deprecated" version = "1.2.13" @@ -812,7 +820,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = '>=3.7,<4' -content-hash = "38d485ec9a43e10e37db053e77697c267f582700d779cc7645da7ad51d5f9a34" +content-hash = "fe6a86c8cbcded7db49acd59aefbf8c42acbbea4f8eb7976864982fa8dd7cafa" [metadata.files] atomicwrites = [ @@ -1003,6 +1011,26 @@ cryptography = [ {file = "cryptography-37.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3b8398b3d0efc420e777c40c16764d6870bcef2eb383df9c6dbb9ffe12c64452"}, {file = "cryptography-37.0.2.tar.gz", hash = "sha256:f224ad253cc9cea7568f49077007d2263efa57396a2f2f78114066fd54b5c68e"}, ] +debugpy = [ + {file = "debugpy-1.6.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:eb1946efac0c0c3d411cea0b5ac772fbde744109fd9520fb0c5a51979faf05ad"}, + {file = "debugpy-1.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e3513399177dd37af4c1332df52da5da1d0c387e5927dc4c0709e26ee7302e8f"}, + {file = "debugpy-1.6.0-cp310-cp310-win32.whl", hash = "sha256:5c492235d6b68f879df3bdbdb01f25c15be15682665517c2c7d0420e5658d71f"}, + {file = "debugpy-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:40de9ba137d355538432209d05e0f5fe5d0498dce761c39119ad4b950b51db31"}, + {file = "debugpy-1.6.0-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:0d383b91efee57dbb923ba20801130cf60450a0eda60bce25bccd937de8e323a"}, + {file = "debugpy-1.6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1ff853e60e77e1c16f85a31adb8360bb2d98ca588d7ed645b7f0985b240bdb5e"}, + {file = "debugpy-1.6.0-cp37-cp37m-win32.whl", hash = "sha256:8e972c717d95f56b6a3a7a29a5ede1ee8f2c3802f6f0e678203b0778eb322bf1"}, + {file = "debugpy-1.6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a8aaeb53e87225141fda7b9081bd87155c1debc13e2f5a532d341112d1983b65"}, + {file = "debugpy-1.6.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:132defb585b518955358321d0f42f6aa815aa15b432be27db654807707c70b2f"}, + {file = "debugpy-1.6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8ee75844242b4537beb5899f3e60a578454d1f136b99e8d57ac424573797b94a"}, + {file = "debugpy-1.6.0-cp38-cp38-win32.whl", hash = "sha256:a65a2499761d47df3e9ea9567109be6e73d412e00ac3ffcf74839f3ddfcdf028"}, + {file = "debugpy-1.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:bd980d533d0ddfc451e03a3bb32acb2900049fec39afc3425b944ebf0889be62"}, + {file = "debugpy-1.6.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:245c7789a012f86210847ec7ee9f38c30a30d4c2223c3e111829a76c9006a5d0"}, + {file = "debugpy-1.6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0e3aa2368883e83e7b689ddff3cafb595f7b711f6a065886b46a96a7fef874e7"}, + {file = "debugpy-1.6.0-cp39-cp39-win32.whl", hash = "sha256:72bcfa97f3afa0064afc77ab811f48ad4a06ac330f290b675082c24437730366"}, + {file = "debugpy-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:30abefefd2ff5a5481162d613cb70e60e2fa80a5eb4c994717c0f008ed25d2e1"}, + {file = "debugpy-1.6.0-py2.py3-none-any.whl", hash = "sha256:4de7777842da7e08652f2776c552070bbdd758557fdec73a15d7be0e4aab95ce"}, + {file = "debugpy-1.6.0.zip", hash = "sha256:7b79c40852991f7b6c3ea65845ed0f5f6b731c37f4f9ad9c61e2ab4bd48a9275"}, +] deprecated = [ {file = "Deprecated-1.2.13-py2.py3-none-any.whl", hash = "sha256:64756e3e14c8c5eea9795d93c524551432a0be75629f8f29e67ab8caf076c76d"}, {file = "Deprecated-1.2.13.tar.gz", hash = "sha256:43ac5335da90c31c24ba028af536a91d41d53f9e6901ddb021bcc572ce44e38d"}, diff --git a/pyproject.toml b/pyproject.toml index 8d780151..81890c67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ twine = '^4.0.0' semver = "^2.13.0" setuptools-scm = "^6.4.2" PyGithub = "^1.55" +debugpy = "^1.6.0" [tool.black] line-length = 88 @@ -102,5 +103,5 @@ source = ['./src'] exclude_lines = ['pragma: no-cover', 'if TYPE_CHECKING:', '@abstractmethod'] [build-system] -requires = ["poetry-core>=1.0.0"] +requires = ["poetry-core>=1.1.0b2"] build-backend = "poetry.core.masonry.api" diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 48e57399..de37678e 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -262,7 +262,8 @@ def _assert(self, data: "SerializableData") -> bool: ) assertion_success = matches if not matches and self.update_snapshots: - self.extension.write_snapshot( + self.session.queue_snapshot_write( + extension=self.extension, data=serialized_data, index=self.index, ) @@ -297,6 +298,8 @@ def _post_assert(self) -> None: def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]: try: - return self.extension.read_snapshot(index=index) + return self.extension.read_snapshot( + index=index, session_id=str(id(self.session)) + ) except SnapshotDoesNotExist: return None diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py index ffbe42d2..ed8fa4af 100644 --- a/src/syrupy/extensions/amber/__init__.py +++ b/src/syrupy/extensions/amber/__init__.py @@ -1,3 +1,4 @@ +from functools import lru_cache from pathlib import Path from typing import ( TYPE_CHECKING, @@ -46,16 +47,23 @@ def _file_extension(self) -> str: def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil": return DataSerializer.read_file(snapshot_location) + @lru_cache() + def __cacheable_read_snapshot( + self, snapshot_location: str, cache_key: str + ) -> "SnapshotFossil": + return DataSerializer.read_file(snapshot_location) + def _read_snapshot_data_from_location( - self, snapshot_location: str, snapshot_name: str + self, snapshot_location: str, snapshot_name: str, session_id: str ) -> Optional["SerializableData"]: - snapshot = self._read_snapshot_fossil(snapshot_location).get(snapshot_name) + snapshots = self.__cacheable_read_snapshot( + snapshot_location=snapshot_location, cache_key=session_id + ) + snapshot = snapshots.get(snapshot_name) return snapshot.data if snapshot else None def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None: - snapshot_fossil_to_update = DataSerializer.read_file(snapshot_fossil.location) - snapshot_fossil_to_update.merge(snapshot_fossil) - DataSerializer.write_file(snapshot_fossil_to_update) + DataSerializer.write_file(snapshot_fossil, merge=True) __all__ = ["AmberSnapshotExtension", "DataSerializer"] diff --git a/src/syrupy/extensions/amber/serializer.py b/src/syrupy/extensions/amber/serializer.py index 6dc293a5..31f00907 100644 --- a/src/syrupy/extensions/amber/serializer.py +++ b/src/syrupy/extensions/amber/serializer.py @@ -1,4 +1,3 @@ -import functools import os from types import ( GeneratorType, @@ -71,11 +70,16 @@ class DataSerializer: _marker_crn: str = "\r\n" @classmethod - def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None: + def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None: """ - Writes the snapshot data into the snapshot file that be read later. + Writes the snapshot data into the snapshot file that can be read later. """ filepath = snapshot_fossil.location + if merge: + base_snapshot = cls.read_file(filepath) + base_snapshot.merge(snapshot_fossil) + snapshot_fossil = base_snapshot + with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f: for snapshot in sorted(snapshot_fossil, key=lambda s: s.name): snapshot_data = str(snapshot.data) @@ -86,7 +90,6 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil") -> None: f.write(f"\n{cls._marker_divider}\n") @classmethod - @functools.lru_cache() def read_file(cls, filepath: str) -> "SnapshotFossil": """ Read the raw snapshot data (str) from the snapshot file into a dict diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 57298b73..e6bc0ca7 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -3,6 +3,7 @@ ABC, abstractmethod, ) +from collections import defaultdict from difflib import ndiff from gettext import gettext from itertools import zip_longest @@ -10,11 +11,13 @@ from typing import ( TYPE_CHECKING, Callable, + DefaultDict, Dict, Iterator, List, Optional, Set, + Tuple, ) from syrupy.constants import ( @@ -115,7 +118,9 @@ def discover_snapshots(self) -> "SnapshotFossils": return discovered - def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData": + def read_snapshot( + self, *, index: "SnapshotIndex", session_id: str + ) -> "SerializedData": """ Utility method for reading the contents of a snapshot assertion. Will call `_pre_read`, then perform `read` and finally `post_read`, @@ -129,7 +134,9 @@ def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData": snapshot_location = self.get_location(index=index) snapshot_name = self.get_snapshot_name(index=index) snapshot_data = self._read_snapshot_data_from_location( - snapshot_location=snapshot_location, snapshot_name=snapshot_name + snapshot_location=snapshot_location, + snapshot_name=snapshot_name, + session_id=session_id, ) if snapshot_data is None: raise SnapshotDoesNotExist() @@ -145,33 +152,66 @@ def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> N This method is _final_, do not override. You can override `_write_snapshot_fossil` in a subclass to change behaviour. """ - self._pre_write(data=data, index=index) - snapshot_location = self.get_location(index=index) - if not self.test_location.matches_snapshot_location(snapshot_location): - warning_msg = gettext( - "{line_end}Can not relate snapshot location '{}' to the test location." - "{line_end}Consider adding '{}' to the generated location." - ).format( - snapshot_location, - self.test_location.filename, - line_end="\n", - ) - warnings.warn(warning_msg) - snapshot_name = self.get_snapshot_name(index=index) - if not self.test_location.matches_snapshot_name(snapshot_name): - warning_msg = gettext( - "{line_end}Can not relate snapshot name '{}' to the test location." - "{line_end}Consider adding '{}' to the generated name." - ).format( - snapshot_name, - self.test_location.testname, - line_end="\n", - ) - warnings.warn(warning_msg) - snapshot_fossil = SnapshotFossil(location=snapshot_location) - snapshot_fossil.add(Snapshot(name=snapshot_name, data=data)) - self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil) - self._post_write(data=data, index=index) + self.write_snapshot_batch(snapshots=[(data, index)]) + + def write_snapshot_batch( + self, *, snapshots: List[Tuple["SerializedData", "SnapshotIndex"]] + ) -> None: + """ + Utility method for writing the contents of multiple snapshot assertions. + Will call `_pre_write` per snapshot, then perform `write` per snapshot + and finally `_post_write`. + + This method is _final_, do not override. You can override + `_write_snapshot_fossil` in a subclass to change behaviour. + """ + # First we group by location since it'll let us batch by file on disk. + # Not as useful for single file snapshots, but useful for the standard + # Amber extension. + locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list) + for data, index in snapshots: + location = self.get_location(index=index) + snapshot_name = self.get_snapshot_name(index=index) + locations[location].append(Snapshot(name=snapshot_name, data=data)) + + # Is there a better place to do the pre-writes? + # Or can we remove the pre-write concept altogether? + self._pre_write(data=data, index=index) + + for location, location_snapshots in locations.items(): + snapshot_fossil = SnapshotFossil(location=location) + + if not self.test_location.matches_snapshot_location(location): + warning_msg = gettext( + "{line_end}Can not relate snapshot location '{}' " + "to the test location.{line_end}" + "Consider adding '{}' to the generated location." + ).format( + location, + self.test_location.filename, + line_end="\n", + ) + warnings.warn(warning_msg) + + for snapshot in location_snapshots: + snapshot_fossil.add(snapshot) + + if not self.test_location.matches_snapshot_name(snapshot.name): + warning_msg = gettext( + "{line_end}Can not relate snapshot name '{}' " + "to the test location.{line_end}" + "Consider adding '{}' to the generated name." + ).format( + snapshot.name, + self.test_location.testname, + line_end="\n", + ) + warnings.warn(warning_msg) + + self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil) + + for data, index in snapshots: + self._post_write(data=data, index=index) @abstractmethod def delete_snapshots( @@ -206,7 +246,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil": @abstractmethod def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str + self, *, snapshot_location: str, snapshot_name: str, session_id: str ) -> Optional["SerializedData"]: """ Get only the snapshot data from location for assertion diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index eee1ebc8..e80a444f 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -77,7 +77,7 @@ def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil": return snapshot_fossil def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str + self, *, snapshot_location: str, snapshot_name: str, session_id: str ) -> Optional["SerializableData"]: try: with open( diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 014d5e72..9f5caf54 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -13,6 +13,7 @@ List, Optional, Set, + Tuple, ) import pytest @@ -20,6 +21,10 @@ from .constants import EXIT_STATUS_FAIL_UNUSED from .data import SnapshotFossils from .report import SnapshotReport +from .types import ( + SerializedData, + SnapshotIndex, +) if TYPE_CHECKING: from .assertion import SnapshotAssertion @@ -43,6 +48,26 @@ class SnapshotSession: default_factory=lambda: defaultdict(set) ) + _queued_snapshot_writes: Dict[ + "AbstractSyrupyExtension", List[Tuple["SerializedData", "SnapshotIndex"]] + ] = field(default_factory=dict) + + def queue_snapshot_write( + self, + extension: "AbstractSyrupyExtension", + data: "SerializedData", + index: "SnapshotIndex", + ) -> None: + queue = self._queued_snapshot_writes.get(extension, []) + queue.append((data, index)) + self._queued_snapshot_writes[extension] = queue + + def flush_snapshot_write_queue(self) -> None: + for extension, queued_write in self._queued_snapshot_writes.items(): + if queued_write: + extension.write_snapshot_batch(snapshots=queued_write) + self._queued_snapshot_writes = {} + @property def update_snapshots(self) -> bool: return bool(self.pytest_session.config.option.update_snapshots) @@ -72,6 +97,7 @@ def ran_item(self, nodeid: str) -> None: def finish(self) -> int: exitstatus = 0 + self.flush_snapshot_write_queue() self.report = SnapshotReport( base_dir=self.pytest_session.config.rootdir, collected_items=self._collected_items, diff --git a/tasks/test.py b/tasks/test.py index b4a68b75..55552ed0 100644 --- a/tasks/test.py +++ b/tasks/test.py @@ -28,11 +28,21 @@ def test( "-s -vv": verbose, f"-k {test_pattern}": test_pattern, "--snapshot-update": update_snapshots, - "--pdb": debug, } coverage_module = "coverage run -m " if coverage else "" test_flags = " ".join(flag for flag, enabled in flags.items() if enabled) - ctx_run(ctx, f"python -m {coverage_module}pytest {test_flags} ./tests") + + if debug and coverage: + raise Exception("The debug and coverage options are mutually exclusive.") + + if debug: + ctx_run( + ctx, + f"python -m debugpy --listen 5678 --wait-for-client -m pytest {test_flags} ./tests", + ) + else: + ctx_run(ctx, f"python -m {coverage_module}pytest {test_flags} ./tests") + if coverage: if not os.environ.get("CI") or not os.environ.get("CODECOV_TOKEN"): ctx_run(ctx, "coverage report")