From 0d3a9a65803045ba53800baadd8c91d1784ff5b1 Mon Sep 17 00:00:00 2001 From: Noah Date: Thu, 1 Dec 2022 17:49:41 -0500 Subject: [PATCH] refactor: write performance improvements, api clarity (#645) * fix: group snapshot writes by extension class * refactor: rename PyTestLocation.filename to .basename BREAKING CHANGE: PyTestLocation.filename has been renamed to .basename * refactor: add test_location kwarg to get_snapshot_name * refactor: get_snapshot_name is now static as a classmethod * refactor: remove pre and post read/write hooks BREAKING CHANGE: Pre and post read/write hooks have been removed without replacement to make internal refactor simpler. Please open a GitHub issue if you have a use case for these hooks. * refactor: rename Fossil to Collection BREAKING CHANGE: The term 'fossil' has been replaced by the clearer term 'collection'. * refactor: pass test_location to read_snapshot * refactor: remove singular write_snapshot method * refactor: dirname property to method * refactor: pass test_location to discover_snapshots * refactor: remove usage of self.test_location * refactor: make write_snapshot a classmethod * refactor: do not instantiate extension with test_location BREAKING CHANGE: Numerous instance methods have been refactored as classmethods. --- CONTRIBUTING.md | 24 ++ src/syrupy/assertion.py | 15 +- src/syrupy/constants.py | 4 +- src/syrupy/data.py | 52 ++-- src/syrupy/extensions/amber/__init__.py | 27 +- src/syrupy/extensions/amber/serializer.py | 22 +- src/syrupy/extensions/base.py | 244 ++++++++---------- src/syrupy/extensions/image.py | 8 +- src/syrupy/extensions/json/__init__.py | 5 +- src/syrupy/extensions/single_file.py | 87 ++++--- src/syrupy/location.py | 4 +- src/syrupy/report.py | 133 +++++----- src/syrupy/session.py | 52 ++-- tests/examples/test_custom_image_extension.py | 4 +- .../test_custom_snapshot_directory.py | 9 +- .../test_custom_snapshot_directory_2.py | 7 +- tests/examples/test_custom_snapshot_name.py | 10 +- .../test_snapshot_option_update.py | 9 +- .../test_snapshot_outside_directory.py | 4 +- .../test_snapshot_use_extension.py | 14 +- tests/syrupy/extensions/test_single_file.py | 14 +- tests/syrupy/test_location.py | 2 +- 22 files changed, 399 insertions(+), 351 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dc9136bb..eae37ba1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -91,6 +91,30 @@ Fill in the relevant sections, clearly linking the issue the change is attemping `debugpy` is installed in local development. A VSCode launch config is provided. Run `inv test -v -d` to enable the debugger (`-d` for debug). It'll then wait for you to attach your VSCode debugging client. +#### Debugging Performance Issues + +You can run `inv benchmark` to run the full benchmark suite. Alternatively, write a test file, e.g.: + +```py +# test_performance.py +import pytest +import os + +SIZE = int(os.environ.get("SIZE", 1000)) + +@pytest.mark.parametrize("x", range(SIZE)) +def test_performance(x, snapshot): + assert x == snapshot +``` + +and then run: + +```sh +SIZE=1000 python -m cProfile -s cumtime -m pytest test_performance.py --snapshot-update -s > profile.log +``` + +See the cProfile docs for metric sorting options. + ## Styleguides ### Commit Messages diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index de37678e..c32fa695 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -94,7 +94,7 @@ def __post_init__(self) -> None: def __init_extension( self, extension_class: Type["AbstractSyrupyExtension"] ) -> "AbstractSyrupyExtension": - return extension_class(test_location=self.test_location) + return extension_class() @property def extension(self) -> "AbstractSyrupyExtension": @@ -238,8 +238,12 @@ def __eq__(self, other: "SerializableData") -> bool: return self._assert(other) def _assert(self, data: "SerializableData") -> bool: - snapshot_location = self.extension.get_location(index=self.index) - snapshot_name = self.extension.get_snapshot_name(index=self.index) + snapshot_location = self.extension.get_location( + test_location=self.test_location, index=self.index + ) + snapshot_name = self.extension.get_snapshot_name( + test_location=self.test_location, index=self.index + ) snapshot_data: Optional["SerializedData"] = None serialized_data: Optional["SerializedData"] = None matches = False @@ -264,6 +268,7 @@ def _assert(self, data: "SerializableData") -> bool: if not matches and self.update_snapshots: self.session.queue_snapshot_write( extension=self.extension, + test_location=self.test_location, data=serialized_data, index=self.index, ) @@ -299,7 +304,9 @@ def _post_assert(self) -> None: def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]: try: return self.extension.read_snapshot( - index=index, session_id=str(id(self.session)) + test_location=self.test_location, + index=index, + session_id=str(id(self.session)), ) except SnapshotDoesNotExist: return None diff --git a/src/syrupy/constants.py b/src/syrupy/constants.py index 7c7db338..0503ff16 100644 --- a/src/syrupy/constants.py +++ b/src/syrupy/constants.py @@ -1,6 +1,6 @@ SNAPSHOT_DIRNAME = "__snapshots__" -SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot fossil" -SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot fossil" +SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot collection" +SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot collection" EXIT_STATUS_FAIL_UNUSED = 1 diff --git a/src/syrupy/data.py b/src/syrupy/data.py index 18fb41b8..b74f2141 100644 --- a/src/syrupy/data.py +++ b/src/syrupy/data.py @@ -36,7 +36,7 @@ class SnapshotUnknown(Snapshot): @dataclass -class SnapshotFossil: +class SnapshotCollection: """A collection of snapshots at a save location""" location: str @@ -54,8 +54,8 @@ def add(self, snapshot: "Snapshot") -> None: if snapshot.name != SNAPSHOT_EMPTY_FOSSIL_KEY: self.remove(SNAPSHOT_EMPTY_FOSSIL_KEY) - def merge(self, snapshot_fossil: "SnapshotFossil") -> None: - for snapshot in snapshot_fossil: + def merge(self, snapshot_collection: "SnapshotCollection") -> None: + for snapshot in snapshot_collection: self.add(snapshot) def remove(self, snapshot_name: str) -> None: @@ -69,8 +69,8 @@ def __iter__(self) -> Iterator["Snapshot"]: @dataclass -class SnapshotEmptyFossil(SnapshotFossil): - """This is a saved fossil that is known to be empty and thus can be removed""" +class SnapshotEmptyCollection(SnapshotCollection): + """This is a saved collection that is known to be empty and thus can be removed""" _snapshots: Dict[str, "Snapshot"] = field( default_factory=lambda: {SnapshotEmpty().name: SnapshotEmpty()} @@ -82,8 +82,8 @@ def has_snapshots(self) -> bool: @dataclass -class SnapshotUnknownFossil(SnapshotFossil): - """This is a saved fossil that is unclaimed by any extension currently in use""" +class SnapshotUnknownCollection(SnapshotCollection): + """This is a saved collection that is unclaimed by any extension currently in use""" _snapshots: Dict[str, "Snapshot"] = field( default_factory=lambda: {SnapshotUnknown().name: SnapshotUnknown()} @@ -91,33 +91,33 @@ class SnapshotUnknownFossil(SnapshotFossil): @dataclass -class SnapshotFossils: - _snapshot_fossils: Dict[str, "SnapshotFossil"] = field(default_factory=dict) +class SnapshotCollections: + _snapshot_collections: Dict[str, "SnapshotCollection"] = field(default_factory=dict) - def get(self, location: str) -> Optional["SnapshotFossil"]: - return self._snapshot_fossils.get(location) + def get(self, location: str) -> Optional["SnapshotCollection"]: + return self._snapshot_collections.get(location) - def add(self, snapshot_fossil: "SnapshotFossil") -> None: - self._snapshot_fossils[snapshot_fossil.location] = snapshot_fossil + def add(self, snapshot_collection: "SnapshotCollection") -> None: + self._snapshot_collections[snapshot_collection.location] = snapshot_collection - def update(self, snapshot_fossil: "SnapshotFossil") -> None: - snapshot_fossil_to_update = self.get(snapshot_fossil.location) - if snapshot_fossil_to_update is None: - snapshot_fossil_to_update = SnapshotFossil( - location=snapshot_fossil.location + def update(self, snapshot_collection: "SnapshotCollection") -> None: + snapshot_collection_to_update = self.get(snapshot_collection.location) + if snapshot_collection_to_update is None: + snapshot_collection_to_update = SnapshotCollection( + location=snapshot_collection.location ) - self.add(snapshot_fossil_to_update) - snapshot_fossil_to_update.merge(snapshot_fossil) + self.add(snapshot_collection_to_update) + snapshot_collection_to_update.merge(snapshot_collection) - def merge(self, snapshot_fossils: "SnapshotFossils") -> None: - for snapshot_fossil in snapshot_fossils: - self.update(snapshot_fossil) + def merge(self, snapshot_collections: "SnapshotCollections") -> None: + for snapshot_collection in snapshot_collections: + self.update(snapshot_collection) - def __iter__(self) -> Iterator["SnapshotFossil"]: - return iter(self._snapshot_fossils.values()) + def __iter__(self) -> Iterator["SnapshotCollection"]: + return iter(self._snapshot_collections.values()) def __contains__(self, key: str) -> bool: - return key in self._snapshot_fossils + return key in self._snapshot_collections @dataclass diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py index 91efdc11..f6ca5773 100644 --- a/src/syrupy/extensions/amber/__init__.py +++ b/src/syrupy/extensions/amber/__init__.py @@ -7,7 +7,7 @@ Set, ) -from syrupy.data import SnapshotFossil +from syrupy.data import SnapshotCollection from syrupy.extensions.base import AbstractSyrupyExtension from .serializer import DataSerializer @@ -21,6 +21,8 @@ class AmberSnapshotExtension(AbstractSyrupyExtension): An amber snapshot file stores data in the following format: """ + _file_extension = "ambr" + def serialize(self, data: "SerializableData", **kwargs: Any) -> str: """ Returns the serialized form of 'data' to be compared @@ -31,27 +33,23 @@ def serialize(self, data: "SerializableData", **kwargs: Any) -> str: def delete_snapshots( self, snapshot_location: str, snapshot_names: Set[str] ) -> None: - snapshot_fossil_to_update = DataSerializer.read_file(snapshot_location) + snapshot_collection_to_update = DataSerializer.read_file(snapshot_location) for snapshot_name in snapshot_names: - snapshot_fossil_to_update.remove(snapshot_name) + snapshot_collection_to_update.remove(snapshot_name) - if snapshot_fossil_to_update.has_snapshots: - DataSerializer.write_file(snapshot_fossil_to_update) + if snapshot_collection_to_update.has_snapshots: + DataSerializer.write_file(snapshot_collection_to_update) else: Path(snapshot_location).unlink() - @property - def _file_extension(self) -> str: - return "ambr" - - def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil": + def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection": return DataSerializer.read_file(snapshot_location) @staticmethod @lru_cache() def __cacheable_read_snapshot( snapshot_location: str, cache_key: str - ) -> "SnapshotFossil": + ) -> "SnapshotCollection": return DataSerializer.read_file(snapshot_location) def _read_snapshot_data_from_location( @@ -63,8 +61,11 @@ def _read_snapshot_data_from_location( snapshot = snapshots.get(snapshot_name) return snapshot.data if snapshot else None - def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None: - DataSerializer.write_file(snapshot_fossil, merge=True) + @classmethod + def _write_snapshot_collection( + cls, *, snapshot_collection: "SnapshotCollection" + ) -> None: + DataSerializer.write_file(snapshot_collection, merge=True) __all__ = ["AmberSnapshotExtension", "DataSerializer"] diff --git a/src/syrupy/extensions/amber/serializer.py b/src/syrupy/extensions/amber/serializer.py index 31f00907..fd996ed4 100644 --- a/src/syrupy/extensions/amber/serializer.py +++ b/src/syrupy/extensions/amber/serializer.py @@ -22,7 +22,7 @@ ) from syrupy.data import ( Snapshot, - SnapshotFossil, + SnapshotCollection, ) if TYPE_CHECKING: @@ -70,18 +70,20 @@ class DataSerializer: _marker_crn: str = "\r\n" @classmethod - def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None: + def write_file( + cls, snapshot_collection: "SnapshotCollection", merge: bool = False + ) -> None: """ Writes the snapshot data into the snapshot file that can be read later. """ - filepath = snapshot_fossil.location + filepath = snapshot_collection.location if merge: base_snapshot = cls.read_file(filepath) - base_snapshot.merge(snapshot_fossil) - snapshot_fossil = base_snapshot + base_snapshot.merge(snapshot_collection) + snapshot_collection = base_snapshot with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f: - for snapshot in sorted(snapshot_fossil, key=lambda s: s.name): + for snapshot in sorted(snapshot_collection, key=lambda s: s.name): snapshot_data = str(snapshot.data) if snapshot_data is not None: f.write(f"{cls._marker_name} {snapshot.name}\n") @@ -90,7 +92,7 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> N f.write(f"\n{cls._marker_divider}\n") @classmethod - def read_file(cls, filepath: str) -> "SnapshotFossil": + def read_file(cls, filepath: str) -> "SnapshotCollection": """ Read the raw snapshot data (str) from the snapshot file into a dict of snapshot name to raw data. This does not attempt any deserialization @@ -98,7 +100,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil": """ name_marker_len = len(cls._marker_name) indent_len = len(cls._indent) - snapshot_fossil = SnapshotFossil(location=filepath) + snapshot_collection = SnapshotCollection(location=filepath) try: with open(filepath, "r", encoding=TEXT_ENCODING, newline=None) as f: test_name = None @@ -112,7 +114,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil": if line.startswith(cls._indent): snapshot_data += line[indent_len:] elif line.startswith(cls._marker_divider) and snapshot_data: - snapshot_fossil.add( + snapshot_collection.add( Snapshot( name=test_name, data=snapshot_data.rstrip(os.linesep), @@ -121,7 +123,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil": except FileNotFoundError: pass - return snapshot_fossil + return snapshot_collection @classmethod def serialize( diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index e6bc0ca7..306f4579 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -3,7 +3,6 @@ ABC, abstractmethod, ) -from collections import defaultdict from difflib import ndiff from gettext import gettext from itertools import zip_longest @@ -11,7 +10,6 @@ from typing import ( TYPE_CHECKING, Callable, - DefaultDict, Dict, Iterator, List, @@ -30,9 +28,9 @@ from syrupy.data import ( DiffedLine, Snapshot, - SnapshotEmptyFossil, - SnapshotFossil, - SnapshotFossils, + SnapshotCollection, + SnapshotCollections, + SnapshotEmptyCollection, ) from syrupy.exceptions import SnapshotDoesNotExist from syrupy.terminal import ( @@ -76,142 +74,137 @@ def serialize( raise NotImplementedError -class SnapshotFossilizer(ABC): - @property - @abstractmethod - def test_location(self) -> "PyTestLocation": - raise NotImplementedError +class SnapshotCollectionStorage(ABC): + _file_extension = "" - def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str: + @classmethod + def get_snapshot_name( + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0 + ) -> str: """Get the snapshot name for the assertion index in a test location""" index_suffix = "" if isinstance(index, (str,)): index_suffix = f"[{index}]" elif index: index_suffix = f".{index}" - return f"{self.test_location.snapshot_name}{index_suffix}" + return f"{test_location.snapshot_name}{index_suffix}" - def get_location(self, *, index: "SnapshotIndex") -> str: - """Returns full location where snapshot data is stored.""" - basename = self._get_file_basename(index=index) - fileext = f".{self._file_extension}" if self._file_extension else "" - return str(Path(self._dirname).joinpath(f"{basename}{fileext}")) + @classmethod + def get_location( + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" + ) -> str: + """Returns full filepath where snapshot data is stored.""" + basename = cls._get_file_basename(test_location=test_location, index=index) + fileext = f".{cls._file_extension}" if cls._file_extension else "" + return str( + Path(cls.dirname(test_location=test_location)).joinpath( + f"{basename}{fileext}" + ) + ) def is_snapshot_location(self, *, location: str) -> bool: """Checks if supplied location is valid for this snapshot extension""" return location.endswith(self._file_extension) - def discover_snapshots(self) -> "SnapshotFossils": + def discover_snapshots( + self, *, test_location: "PyTestLocation" + ) -> "SnapshotCollections": """ - Returns all snapshot fossils in test site + Returns all snapshot collections in test site """ - discovered: "SnapshotFossils" = SnapshotFossils() - for filepath in walk_snapshot_dir(self._dirname): + discovered: "SnapshotCollections" = SnapshotCollections() + for filepath in walk_snapshot_dir(self.dirname(test_location=test_location)): if self.is_snapshot_location(location=filepath): - snapshot_fossil = self._read_snapshot_fossil(snapshot_location=filepath) - if not snapshot_fossil.has_snapshots: - snapshot_fossil = SnapshotEmptyFossil(location=filepath) + snapshot_collection = self._read_snapshot_collection( + snapshot_location=filepath + ) + if not snapshot_collection.has_snapshots: + snapshot_collection = SnapshotEmptyCollection(location=filepath) else: - snapshot_fossil = SnapshotFossil(location=filepath) + snapshot_collection = SnapshotCollection(location=filepath) - discovered.add(snapshot_fossil) + discovered.add(snapshot_collection) return discovered def read_snapshot( - self, *, index: "SnapshotIndex", session_id: str + self, + *, + test_location: "PyTestLocation", + index: "SnapshotIndex", + session_id: str, ) -> "SerializedData": """ - Utility method for reading the contents of a snapshot assertion. - Will call `_pre_read`, then perform `read` and finally `post_read`, - returning the contents parsed from the `read` method. - This method is _final_, do not override. You can override `_read_snapshot_data_from_location` in a subclass to change behaviour. """ - try: - self._pre_read(index=index) - snapshot_location = self.get_location(index=index) - snapshot_name = self.get_snapshot_name(index=index) - snapshot_data = self._read_snapshot_data_from_location( - snapshot_location=snapshot_location, - snapshot_name=snapshot_name, - session_id=session_id, - ) - if snapshot_data is None: - raise SnapshotDoesNotExist() - return snapshot_data - finally: - self._post_read(index=index) - - def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> None: - """ - Utility method for writing the contents of a snapshot assertion. - Will call `_pre_write`, then perform `write` and finally `_post_write`. - - This method is _final_, do not override. You can override - `_write_snapshot_fossil` in a subclass to change behaviour. - """ - self.write_snapshot_batch(snapshots=[(data, index)]) + snapshot_location = self.get_location(test_location=test_location, index=index) + snapshot_name = self.get_snapshot_name(test_location=test_location, index=index) + snapshot_data = self._read_snapshot_data_from_location( + snapshot_location=snapshot_location, + snapshot_name=snapshot_name, + session_id=session_id, + ) + if snapshot_data is None: + raise SnapshotDoesNotExist() + return snapshot_data - def write_snapshot_batch( - self, *, snapshots: List[Tuple["SerializedData", "SnapshotIndex"]] + @classmethod + def write_snapshot( + cls, + *, + snapshot_location: str, + snapshots: List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], ) -> None: """ - Utility method for writing the contents of multiple snapshot assertions. - Will call `_pre_write` per snapshot, then perform `write` per snapshot - and finally `_post_write`. - This method is _final_, do not override. You can override - `_write_snapshot_fossil` in a subclass to change behaviour. + `_write_snapshot_collection` in a subclass to change behaviour. """ + if not snapshots: + return + # First we group by location since it'll let us batch by file on disk. # Not as useful for single file snapshots, but useful for the standard # Amber extension. - locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list) - for data, index in snapshots: - location = self.get_location(index=index) - snapshot_name = self.get_snapshot_name(index=index) - locations[location].append(Snapshot(name=snapshot_name, data=data)) - - # Is there a better place to do the pre-writes? - # Or can we remove the pre-write concept altogether? - self._pre_write(data=data, index=index) - - for location, location_snapshots in locations.items(): - snapshot_fossil = SnapshotFossil(location=location) + snapshot_collection = SnapshotCollection(location=snapshot_location) + for data, test_location, index in snapshots: + snapshot_name = cls.get_snapshot_name( + test_location=test_location, index=index + ) + snapshot = Snapshot(name=snapshot_name, data=data) + snapshot_collection.add(snapshot) - if not self.test_location.matches_snapshot_location(location): + if not test_location.matches_snapshot_location(snapshot_location): warning_msg = gettext( "{line_end}Can not relate snapshot location '{}' " "to the test location.{line_end}" "Consider adding '{}' to the generated location." ).format( - location, - self.test_location.filename, + snapshot_location, + test_location.basename, line_end="\n", ) warnings.warn(warning_msg) - for snapshot in location_snapshots: - snapshot_fossil.add(snapshot) - - if not self.test_location.matches_snapshot_name(snapshot.name): - warning_msg = gettext( - "{line_end}Can not relate snapshot name '{}' " - "to the test location.{line_end}" - "Consider adding '{}' to the generated name." - ).format( - snapshot.name, - self.test_location.testname, - line_end="\n", - ) - warnings.warn(warning_msg) + if not test_location.matches_snapshot_name(snapshot.name): + warning_msg = gettext( + "{line_end}Can not relate snapshot name '{}' " + "to the test location.{line_end}" + "Consider adding '{}' to the generated name." + ).format( + snapshot.name, + test_location.testname, + line_end="\n", + ) + warnings.warn(warning_msg) - self._write_snapshot_fossil(snapshot_fossil=snapshot_fossil) + # Ensures the folder path for the snapshot file exists. + try: + Path(snapshot_location).parent.mkdir(parents=True) + except FileExistsError: + pass - for data, index in snapshots: - self._post_write(data=data, index=index) + cls._write_snapshot_collection(snapshot_collection=snapshot_collection) @abstractmethod def delete_snapshots( @@ -223,24 +216,12 @@ def delete_snapshots( """ raise NotImplementedError - def _pre_read(self, *, index: "SnapshotIndex" = 0) -> None: - pass - - def _post_read(self, *, index: "SnapshotIndex" = 0) -> None: - pass - - def _pre_write(self, *, data: "SerializedData", index: "SnapshotIndex" = 0) -> None: - self.__ensure_snapshot_dir(index=index) - - def _post_write( - self, *, data: "SerializedData", index: "SnapshotIndex" = 0 - ) -> None: - pass - @abstractmethod - def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil": + def _read_snapshot_collection( + self, *, snapshot_location: str + ) -> "SnapshotCollection": """ - Read the snapshot location and construct a snapshot fossil object + Read the snapshot location and construct a snapshot collection object """ raise NotImplementedError @@ -253,35 +234,27 @@ def _read_snapshot_data_from_location( """ raise NotImplementedError + @classmethod @abstractmethod - def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None: + def _write_snapshot_collection( + cls, *, snapshot_collection: "SnapshotCollection" + ) -> None: """ - Adds the snapshot data to the snapshots in fossil location + Adds the snapshot data to the snapshots in collection location """ raise NotImplementedError - @property - def _dirname(self) -> str: - test_dir = Path(self.test_location.filepath).parent + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: + test_dir = Path(test_location.filepath).parent return str(test_dir.joinpath(SNAPSHOT_DIRNAME)) - @property - @abstractmethod - def _file_extension(self) -> str: - raise NotImplementedError - - def _get_file_basename(self, *, index: "SnapshotIndex") -> str: + @classmethod + def _get_file_basename( + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" + ) -> str: """Returns file basename without extension. Used to create full filepath.""" - return self.test_location.filename - - def __ensure_snapshot_dir(self, *, index: "SnapshotIndex") -> None: - """ - Ensures the folder path for the snapshot file exists. - """ - try: - Path(self.get_location(index=index)).parent.mkdir(parents=True) - except FileExistsError: - pass + return test_location.basename class SnapshotReporter(ABC): @@ -445,11 +418,6 @@ def matches( class AbstractSyrupyExtension( - SnapshotSerializer, SnapshotFossilizer, SnapshotReporter, SnapshotComparator + SnapshotSerializer, SnapshotCollectionStorage, SnapshotReporter, SnapshotComparator ): - def __init__(self, test_location: "PyTestLocation"): - self._test_location = test_location - - @property - def test_location(self) -> "PyTestLocation": - return self._test_location + pass diff --git a/src/syrupy/extensions/image.py b/src/syrupy/extensions/image.py index d900f333..6faf3d17 100644 --- a/src/syrupy/extensions/image.py +++ b/src/syrupy/extensions/image.py @@ -12,15 +12,11 @@ class PNGImageSnapshotExtension(SingleFileSnapshotExtension): - @property - def _file_extension(self) -> str: - return "png" + _file_extension = "png" class SVGImageSnapshotExtension(SingleFileSnapshotExtension): - @property - def _file_extension(self) -> str: - return "svg" + _file_extension = "svg" def serialize(self, data: "SerializableData", **kwargs: Any) -> bytes: return str(data).encode(TEXT_ENCODING) diff --git a/src/syrupy/extensions/json/__init__.py b/src/syrupy/extensions/json/__init__.py index 6d4a49a9..0b9a9540 100644 --- a/src/syrupy/extensions/json/__init__.py +++ b/src/syrupy/extensions/json/__init__.py @@ -31,10 +31,7 @@ class JSONSnapshotExtension(SingleFileSnapshotExtension): _max_depth: int = 99 _write_mode = WriteMode.TEXT - - @property - def _file_extension(self) -> str: - return "json" + _file_extension = "json" @classmethod def sort(cls, iterable: Iterable[Any]) -> Iterable[Any]: diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index e80a444f..af53ea4d 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -13,8 +13,9 @@ from syrupy.constants import TEXT_ENCODING from syrupy.data import ( Snapshot, - SnapshotFossil, + SnapshotCollection, ) +from syrupy.location import PyTestLocation from .base import AbstractSyrupyExtension @@ -39,6 +40,7 @@ def __str__(self) -> str: class SingleFileSnapshotExtension(AbstractSyrupyExtension): _text_encoding = TEXT_ENCODING _write_mode = WriteMode.BINARY + _file_extension = "raw" def serialize( self, @@ -47,11 +49,16 @@ def serialize( exclude: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, ) -> "SerializedData": - return self._supported_dataclass(data) - - def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str: - return self.__clean_filename( - super(SingleFileSnapshotExtension, self).get_snapshot_name(index=index) + return self.get_supported_dataclass()(data) + + @classmethod + def get_snapshot_name( + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0 + ) -> str: + return cls.__clean_filename( + AbstractSyrupyExtension.get_snapshot_name( + test_location=test_location, index=index + ) ) def delete_snapshots( @@ -59,62 +66,74 @@ def delete_snapshots( ) -> None: Path(snapshot_location).unlink() - @property - def _file_extension(self) -> str: - return "raw" - - def _get_file_basename(self, *, index: "SnapshotIndex") -> str: - return self.get_snapshot_name(index=index) + @classmethod + def _get_file_basename( + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" + ) -> str: + return cls.get_snapshot_name(test_location=test_location, index=index) - @property - def _dirname(self) -> str: - original_dirname = super(SingleFileSnapshotExtension, self)._dirname - return str(Path(original_dirname).joinpath(self.test_location.filename)) + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: + original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location) + return str(Path(original_dirname).joinpath(test_location.basename)) - def _read_snapshot_fossil(self, *, snapshot_location: str) -> "SnapshotFossil": - snapshot_fossil = SnapshotFossil(location=snapshot_location) - snapshot_fossil.add(Snapshot(name=Path(snapshot_location).stem)) - return snapshot_fossil + def _read_snapshot_collection( + self, *, snapshot_location: str + ) -> "SnapshotCollection": + snapshot_collection = SnapshotCollection(location=snapshot_location) + snapshot_collection.add(Snapshot(name=Path(snapshot_location).stem)) + return snapshot_collection def _read_snapshot_data_from_location( self, *, snapshot_location: str, snapshot_name: str, session_id: str ) -> Optional["SerializableData"]: try: with open( - snapshot_location, f"r{self._write_mode}", encoding=self._write_encoding + snapshot_location, + f"r{self._write_mode}", + encoding=self.get_write_encoding(), ) as f: return f.read() except FileNotFoundError: return None - @property - def _supported_dataclass(self) -> Union[Type[str], Type[bytes]]: - if self._write_mode == WriteMode.TEXT: + @classmethod + def get_supported_dataclass(cls) -> Union[Type[str], Type[bytes]]: + if cls._write_mode == WriteMode.TEXT: return str return bytes - @property - def _write_encoding(self) -> Optional[str]: - if self._write_mode == WriteMode.TEXT: + @classmethod + def get_write_encoding(cls) -> Optional[str]: + if cls._write_mode == WriteMode.TEXT: return TEXT_ENCODING return None - def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None: - filepath, data = snapshot_fossil.location, next(iter(snapshot_fossil)).data - if not isinstance(data, self._supported_dataclass): + @classmethod + def _write_snapshot_collection( + cls, *, snapshot_collection: "SnapshotCollection" + ) -> None: + filepath, data = ( + snapshot_collection.location, + next(iter(snapshot_collection)).data, + ) + if not isinstance(data, cls.get_supported_dataclass()): error_text = gettext( "Can't write non supported data. Expected '{}', got '{}'" ) raise TypeError( error_text.format( - self._supported_dataclass.__name__, type(data).__name__ + cls.get_supported_dataclass().__name__, type(data).__name__ ) ) - with open(filepath, f"w{self._write_mode}", encoding=self._write_encoding) as f: + with open( + filepath, f"w{cls._write_mode}", encoding=cls.get_write_encoding() + ) as f: f.write(data) - def __clean_filename(self, filename: str) -> str: - max_filename_length = 255 - len(self._file_extension or "") + @classmethod + def __clean_filename(cls, filename: str) -> str: + max_filename_length = 255 - len(cls._file_extension or "") exclude_chars = '\\/?*:|"<>' exclude_categ = ("C",) cleaned_filename = "".join( diff --git a/src/syrupy/location.py b/src/syrupy/location.py index ca4aa264..931e462b 100644 --- a/src/syrupy/location.py +++ b/src/syrupy/location.py @@ -66,7 +66,7 @@ def nodeid(self) -> str: return str(getattr(self._node, "nodeid")) # noqa: B009 @property - def filename(self) -> str: + def basename(self) -> str: return Path(self.filepath).stem @property @@ -117,4 +117,4 @@ def matches_snapshot_location(self, snapshot_location: str) -> bool: loc = Path(snapshot_location) # "test_file" should match_"test_file.ext" or "test_file/whatever.ext", but not # "test_file_suffix.ext" - return self.filename == loc.stem or self.filename == loc.parent.name + return self.basename == loc.stem or self.basename == loc.parent.name diff --git a/src/syrupy/report.py b/src/syrupy/report.py index b41911a2..825cf9fa 100644 --- a/src/syrupy/report.py +++ b/src/syrupy/report.py @@ -24,9 +24,9 @@ from .constants import PYTEST_NODE_SEP from .data import ( Snapshot, - SnapshotFossil, - SnapshotFossils, - SnapshotUnknownFossil, + SnapshotCollection, + SnapshotCollections, + SnapshotUnknownCollection, ) from .location import PyTestLocation from .terminal import ( @@ -50,7 +50,7 @@ class SnapshotReport: """ This class is responsible for determining the test summary and post execution results. It will provide the lines of the report to be printed as well as the - information used for removal of unused or orphaned snapshots and fossils. + information used for removal of unused or orphaned snapshots and collections. """ # Initial arguments to the report @@ -61,12 +61,12 @@ class SnapshotReport: assertions: List["SnapshotAssertion"] # All of these are derived from the initial arguments and via walking the filesystem - discovered: "SnapshotFossils" = field(default_factory=SnapshotFossils) - created: "SnapshotFossils" = field(default_factory=SnapshotFossils) - failed: "SnapshotFossils" = field(default_factory=SnapshotFossils) - matched: "SnapshotFossils" = field(default_factory=SnapshotFossils) - updated: "SnapshotFossils" = field(default_factory=SnapshotFossils) - used: "SnapshotFossils" = field(default_factory=SnapshotFossils) + discovered: "SnapshotCollections" = field(default_factory=SnapshotCollections) + created: "SnapshotCollections" = field(default_factory=SnapshotCollections) + failed: "SnapshotCollections" = field(default_factory=SnapshotCollections) + matched: "SnapshotCollections" = field(default_factory=SnapshotCollections) + updated: "SnapshotCollections" = field(default_factory=SnapshotCollections) + used: "SnapshotCollections" = field(default_factory=SnapshotCollections) _provided_test_paths: Dict[str, List[str]] = field(default_factory=dict) _keyword_expressions: Set["Expression"] = field(default_factory=set) _collected_items_by_nodeid: Dict[str, "pytest.Item"] = field( @@ -94,26 +94,32 @@ def __post_init__(self) -> None: # We only need to discover snapshots once per test file, not once per assertion. locations_discovered: DefaultDict[str, Set[Any]] = defaultdict(set) for assertion in self.assertions: - test_location = assertion.extension.test_location.filepath + test_location = assertion.test_location.filepath extension_class = assertion.extension.__class__ if extension_class not in locations_discovered[test_location]: locations_discovered[test_location].add(extension_class) - self.discovered.merge(assertion.extension.discover_snapshots()) + self.discovered.merge( + assertion.extension.discover_snapshots( + test_location=assertion.test_location + ) + ) for result in assertion.executions.values(): - snapshot_fossil = SnapshotFossil(location=result.snapshot_location) - snapshot_fossil.add( + snapshot_collection = SnapshotCollection( + location=result.snapshot_location + ) + snapshot_collection.add( Snapshot(name=result.snapshot_name, data=result.final_data) ) - self.used.update(snapshot_fossil) + self.used.update(snapshot_collection) if result.created: - self.created.update(snapshot_fossil) + self.created.update(snapshot_collection) elif result.updated: - self.updated.update(snapshot_fossil) + self.updated.update(snapshot_collection) elif result.success: - self.matched.update(snapshot_fossil) + self.matched.update(snapshot_collection) else: - self.failed.update(snapshot_fossil) + self.failed.update(snapshot_collection) def __parse_invocation_args(self) -> None: """ @@ -183,7 +189,7 @@ def ran_items(self) -> Iterator["pytest.Item"]: ) @property - def unused(self) -> "SnapshotFossils": + def unused(self) -> "SnapshotCollections": """ Iterate over each snapshot that was discovered but never used and compute if the snapshot was unused because the test attached to it was never run, @@ -192,11 +198,11 @@ def unused(self) -> "SnapshotFossils": Summary, if a snapshot was supposed to be run based on the invocation args and it was not, then it should be marked as unused otherwise ignored. """ - unused_fossils = SnapshotFossils() - for unused_snapshot_fossil in self._diff_snapshot_fossils( + unused_collections = SnapshotCollections() + for unused_snapshot_collection in self._diff_snapshot_collections( self.discovered, self.used ): - snapshot_location = unused_snapshot_fossil.location + snapshot_location = unused_snapshot_collection.location if self._provided_test_paths and not self._ran_items_match_location( snapshot_location ): @@ -207,13 +213,13 @@ def unused(self) -> "SnapshotFossils": provided_nodes = self._get_matching_path_nodes(snapshot_location) if self.selected_all_collected_items and not any(provided_nodes): # All collected tests were run and files were not filtered by ::node - # therefore the snapshot fossil file at this location can be deleted - unused_snapshots = {*unused_snapshot_fossil} + # therefore the snapshot collection file at this location can be deleted + unused_snapshots = {*unused_snapshot_collection} mark_for_removal = snapshot_location not in self.used else: unused_snapshots = { snapshot - for snapshot in unused_snapshot_fossil + for snapshot in unused_snapshot_collection if self._selected_items_match_name( snapshot_location=snapshot_location, snapshot_name=snapshot.name ) @@ -226,15 +232,17 @@ def unused(self) -> "SnapshotFossils": mark_for_removal = False if unused_snapshots: - marked_unused_snapshot_fossil = SnapshotFossil( + marked_unused_snapshot_collection = SnapshotCollection( location=snapshot_location ) for snapshot in unused_snapshots: - marked_unused_snapshot_fossil.add(snapshot) - unused_fossils.add(marked_unused_snapshot_fossil) + marked_unused_snapshot_collection.add(snapshot) + unused_collections.add(marked_unused_snapshot_collection) elif mark_for_removal: - unused_fossils.add(SnapshotUnknownFossil(location=snapshot_location)) - return unused_fossils + unused_collections.add( + SnapshotUnknownCollection(location=snapshot_location) + ) + return unused_collections @property def lines(self) -> Iterator[str]: @@ -299,9 +307,9 @@ def lines(self) -> Iterator[str]: yield "" if self.update_snapshots or self.include_snapshot_details: base_message = "Deleted" if self.update_snapshots else "Unused" - for snapshot_fossil in self.unused: - filepath = snapshot_fossil.location - snapshots = (snapshot.name for snapshot in snapshot_fossil) + for snapshot_collection in self.unused: + filepath = snapshot_collection.location + snapshots = (snapshot.name for snapshot in snapshot_collection) try: path_to_file = str(Path(filepath).relative_to(self.base_dir)) @@ -323,33 +331,40 @@ def lines(self) -> Iterator[str]: else: yield error_style(message) - def _diff_snapshot_fossils( - self, snapshot_fossils1: "SnapshotFossils", snapshot_fossils2: "SnapshotFossils" - ) -> "SnapshotFossils": + def _diff_snapshot_collections( + self, + snapshot_collections1: "SnapshotCollections", + snapshot_collections2: "SnapshotCollections", + ) -> "SnapshotCollections": """ - Find the difference between two collections of snapshot fossils. While - preserving the location site to all fossils in the first collections. That is - a collection with fossil sites {A{1,2}, B{3,4}, C{5,6}} with snapshot fossils - when diffed with another collection with snapshots {A{1,2}, B{3,4}, D{7,8}} - will result in a collection with the contents {A{}, B{}, C{5,6}}. + Find the difference between two collections of snapshot collections. While + preserving the location site to all collections in the first collections. + That is a collection with collection sites {A{1,2}, B{3,4}, C{5,6}} with + snapshot collections when diffed with another collection with snapshots + {A{1,2}, B{3,4}, D{7,8}} will result in a collection with the contents + {A{}, B{}, C{5,6}}. """ - diffed_snapshot_fossils: "SnapshotFossils" = SnapshotFossils() - for snapshot_fossil1 in snapshot_fossils1: - snapshot_fossil2 = snapshot_fossils2.get( - snapshot_fossil1.location - ) or SnapshotFossil(location=snapshot_fossil1.location) - diffed_snapshot_fossil = SnapshotFossil(location=snapshot_fossil1.location) - for snapshot in snapshot_fossil1: - if not snapshot_fossil2.get(snapshot.name): - diffed_snapshot_fossil.add(snapshot) - diffed_snapshot_fossils.add(diffed_snapshot_fossil) - return diffed_snapshot_fossils - - def _count_snapshots(self, snapshot_fossils: "SnapshotFossils") -> int: + diffed_snapshot_collections: "SnapshotCollections" = SnapshotCollections() + for snapshot_collection1 in snapshot_collections1: + snapshot_collection2 = snapshot_collections2.get( + snapshot_collection1.location + ) or SnapshotCollection(location=snapshot_collection1.location) + diffed_snapshot_collection = SnapshotCollection( + location=snapshot_collection1.location + ) + for snapshot in snapshot_collection1: + if not snapshot_collection2.get(snapshot.name): + diffed_snapshot_collection.add(snapshot) + diffed_snapshot_collections.add(diffed_snapshot_collection) + return diffed_snapshot_collections + + def _count_snapshots(self, snapshot_collections: "SnapshotCollections") -> int: """ - Count all the snapshots at all the locations in the snapshot fossil collection + Count all the snapshots at all the locations in the snapshot collections """ - return sum(len(snapshot_fossil) for snapshot_fossil in snapshot_fossils) + return sum( + len(snapshot_collection) for snapshot_collection in snapshot_collections + ) def _is_matching_path(self, snapshot_location: str, provided_path: str) -> bool: """ @@ -428,8 +443,8 @@ def _selected_items_match_name( def _ran_items_match_location(self, snapshot_location: str) -> bool: """ Check if any test run in the current session should match the snapshot location - This being true means that if no snapshot in the fossil was used then it should - be discarded as obsolete + This being true means that if no snapshot in the collection was used then it + should be discarded as obsolete """ return any( PyTestLocation(item).matches_snapshot_location(snapshot_location) diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 95d11f2b..f24ceb1f 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -14,12 +14,15 @@ Optional, Set, Tuple, + Type, ) import pytest +from syrupy.location import PyTestLocation + from .constants import EXIT_STATUS_FAIL_UNUSED -from .data import SnapshotFossils +from .data import SnapshotCollections from .report import SnapshotReport from .utils import ( is_xdist_controller, @@ -53,23 +56,34 @@ class SnapshotSession: ) _queued_snapshot_writes: Dict[ - "AbstractSyrupyExtension", List[Tuple["SerializedData", "SnapshotIndex"]] + Tuple[Type["AbstractSyrupyExtension"], str], + List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], ] = field(default_factory=dict) def queue_snapshot_write( self, extension: "AbstractSyrupyExtension", + test_location: "PyTestLocation", data: "SerializedData", index: "SnapshotIndex", ) -> None: - queue = self._queued_snapshot_writes.get(extension, []) - queue.append((data, index)) - self._queued_snapshot_writes[extension] = queue + snapshot_location = extension.get_location( + test_location=test_location, index=index + ) + key = (extension.__class__, snapshot_location) + queue = self._queued_snapshot_writes.get(key, []) + queue.append((data, test_location, index)) + self._queued_snapshot_writes[key] = queue def flush_snapshot_write_queue(self) -> None: - for extension, queued_write in self._queued_snapshot_writes.items(): + for ( + extension_class, + snapshot_location, + ), queued_write in self._queued_snapshot_writes.items(): if queued_write: - extension.write_snapshot_batch(snapshots=queued_write) + extension_class.write_snapshot( + snapshot_location=snapshot_location, snapshots=queued_write + ) self._queued_snapshot_writes = {} @property @@ -124,8 +138,8 @@ def finish(self) -> int: if self.report.num_unused: if self.update_snapshots: self.remove_unused_snapshots( - unused_snapshot_fossils=self.report.unused, - used_snapshot_fossils=self.report.used, + unused_snapshot_collections=self.report.unused, + used_snapshot_collections=self.report.used, ) elif not self.warn_unused_snapshots: exitstatus |= EXIT_STATUS_FAIL_UNUSED @@ -134,38 +148,40 @@ def finish(self) -> int: def register_request(self, assertion: "SnapshotAssertion") -> None: self._assertions.append(assertion) - test_location = assertion.extension.test_location.filepath + test_location = assertion.test_location.filepath extension_class = assertion.extension.__class__ if extension_class not in self._locations_discovered[test_location]: self._locations_discovered[test_location].add(extension_class) discovered_extensions = { discovered.location: assertion.extension - for discovered in assertion.extension.discover_snapshots() + for discovered in assertion.extension.discover_snapshots( + test_location=assertion.test_location + ) if discovered.has_snapshots } self._extensions.update(discovered_extensions) def remove_unused_snapshots( self, - unused_snapshot_fossils: "SnapshotFossils", - used_snapshot_fossils: "SnapshotFossils", + unused_snapshot_collections: "SnapshotCollections", + used_snapshot_collections: "SnapshotCollections", ) -> None: """ - Remove all unused snapshots using the registed extension for the fossil file + Remove all unused snapshots using the registed extension for the collection file If there is not registered extension and the location is unused delete the file """ - for unused_snapshot_fossil in unused_snapshot_fossils: - snapshot_location = unused_snapshot_fossil.location + for unused_snapshot_collection in unused_snapshot_collections: + snapshot_location = unused_snapshot_collection.location extension = self._extensions.get(snapshot_location) if extension: extension.delete_snapshots( snapshot_location=snapshot_location, snapshot_names={ - snapshot.name for snapshot in unused_snapshot_fossil + snapshot.name for snapshot in unused_snapshot_collection }, ) - elif snapshot_location not in used_snapshot_fossils: + elif snapshot_location not in used_snapshot_collections: Path(snapshot_location).unlink() @staticmethod diff --git a/tests/examples/test_custom_image_extension.py b/tests/examples/test_custom_image_extension.py index 5991b742..9cb858c5 100644 --- a/tests/examples/test_custom_image_extension.py +++ b/tests/examples/test_custom_image_extension.py @@ -14,9 +14,7 @@ class JPEGImageExtension(SingleFileSnapshotExtension): - @property - def _file_extension(self) -> str: - return "jpg" + _file_extension = "jpg" @pytest.fixture diff --git a/tests/examples/test_custom_snapshot_directory.py b/tests/examples/test_custom_snapshot_directory.py index 305c27e7..766e5e96 100644 --- a/tests/examples/test_custom_snapshot_directory.py +++ b/tests/examples/test_custom_snapshot_directory.py @@ -15,16 +15,15 @@ import pytest from syrupy.extensions.amber import AmberSnapshotExtension +from syrupy.location import PyTestLocation DIFFERENT_DIRECTORY = "__snaps_example__" class DifferentDirectoryExtension(AmberSnapshotExtension): - @property - def _dirname(self) -> str: - return str( - Path(self.test_location.filepath).parent.joinpath(DIFFERENT_DIRECTORY) - ) + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: + return str(Path(test_location.filepath).parent.joinpath(DIFFERENT_DIRECTORY)) @pytest.fixture diff --git a/tests/examples/test_custom_snapshot_directory_2.py b/tests/examples/test_custom_snapshot_directory_2.py index 496b8a7e..f4ac65af 100644 --- a/tests/examples/test_custom_snapshot_directory_2.py +++ b/tests/examples/test_custom_snapshot_directory_2.py @@ -15,14 +15,15 @@ import pytest from syrupy.extensions.json import JSONSnapshotExtension +from syrupy.location import PyTestLocation def create_versioned_fixture(version: int): class VersionedJSONExtension(JSONSnapshotExtension): - @property - def _dirname(self) -> str: + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: return str( - Path(self.test_location.filepath).parent.joinpath( + Path(test_location.filepath).parent.joinpath( "__snapshots__", f"v{version}" ) ) diff --git a/tests/examples/test_custom_snapshot_name.py b/tests/examples/test_custom_snapshot_name.py index 64627742..c627f8c7 100644 --- a/tests/examples/test_custom_snapshot_name.py +++ b/tests/examples/test_custom_snapshot_name.py @@ -4,13 +4,17 @@ import pytest from syrupy.extensions.amber import AmberSnapshotExtension +from syrupy.location import PyTestLocation from syrupy.types import SnapshotIndex class CanadianNameExtension(AmberSnapshotExtension): - def get_snapshot_name(self, *, index: "SnapshotIndex") -> str: - original_name = super(CanadianNameExtension, self).get_snapshot_name( - index=index + @classmethod + def get_snapshot_name( + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" + ) -> str: + original_name = AmberSnapshotExtension.get_snapshot_name( + test_location=test_location, index=index ) return f"{original_name}🇨🇦" diff --git a/tests/integration/test_snapshot_option_update.py b/tests/integration/test_snapshot_option_update.py index 12f4533d..b898c9a8 100644 --- a/tests/integration/test_snapshot_option_update.py +++ b/tests/integration/test_snapshot_option_update.py @@ -394,7 +394,7 @@ def test_used(snapshot): assert not Path("__snapshots__", "test_used.ambr").exists() -def test_update_removes_empty_snapshot_fossil_only(run_testcases): +def test_update_removes_empty_snapshot_collection_only(run_testcases): testdir = run_testcases[1] snapfile_empty = Path("__snapshots__", "empty_snapfile.ambr") testdir.makefile(".ambr", **{str(snapfile_empty): ""}) @@ -403,7 +403,8 @@ def test_update_removes_empty_snapshot_fossil_only(run_testcases): result.stdout.re_match_lines( ( r"10 snapshots passed\. 1 unused snapshot deleted\.", - r"Deleted empty snapshot fossil \(__snapshots__[\\/]empty_snapfile\.ambr\)", + r"Deleted empty snapshot collection " + r"\(__snapshots__[\\/]empty_snapfile\.ambr\)", ) ) assert result.ret == 0 @@ -411,7 +412,7 @@ def test_update_removes_empty_snapshot_fossil_only(run_testcases): assert Path("__snapshots__", "test_used.ambr").exists() -def test_update_removes_hanging_snapshot_fossil_file(run_testcases): +def test_update_removes_hanging_snapshot_collection_file(run_testcases): testdir = run_testcases[1] snapfile_used = Path("__snapshots__", "test_used.ambr") snapfile_hanging = Path("__snapshots__", "hanging_snapfile.abc") @@ -421,7 +422,7 @@ def test_update_removes_hanging_snapshot_fossil_file(run_testcases): result.stdout.re_match_lines( ( r"10 snapshots passed\. 1 unused snapshot deleted\.", - r"Deleted unknown snapshot fossil " + r"Deleted unknown snapshot collection " r"\(__snapshots__[\\/]hanging_snapfile\.abc\)", ) ) diff --git a/tests/integration/test_snapshot_outside_directory.py b/tests/integration/test_snapshot_outside_directory.py index b241c6f0..350109b3 100644 --- a/tests/integration/test_snapshot_outside_directory.py +++ b/tests/integration/test_snapshot_outside_directory.py @@ -11,8 +11,8 @@ def testcases(testdir, tmp_path): from syrupy.extensions.amber import AmberSnapshotExtension class CustomSnapshotExtension(AmberSnapshotExtension): - @property - def _dirname(self): + @classmethod + def dirname(cls, *, test_location): return {str(dirname)!r} @pytest.fixture diff --git a/tests/integration/test_snapshot_use_extension.py b/tests/integration/test_snapshot_use_extension.py index df77f253..375d3ad5 100644 --- a/tests/integration/test_snapshot_use_extension.py +++ b/tests/integration/test_snapshot_use_extension.py @@ -16,19 +16,19 @@ def testcases_initial(testdir): class CustomSnapshotExtension(AmberSnapshotExtension): - @property - def _file_extension(self): - return "" + _file_extension = "" def serialize(self, data, **kwargs): return str(data) - def get_snapshot_name(self, *, index): - testname = self._test_location.testname[::-1] + @classmethod + def get_snapshot_name(cls, *, test_location, index): + testname = test_location.testname[::-1] return f"{testname}.{index}" - def _get_file_basename(self, *, index): - return self.test_location.filename[::-1] + @classmethod + def _get_file_basename(cls, *, test_location, index): + return test_location.basename[::-1] @pytest.fixture def snapshot_custom(snapshot): diff --git a/tests/syrupy/extensions/test_single_file.py b/tests/syrupy/extensions/test_single_file.py index da7c2280..6adcda72 100644 --- a/tests/syrupy/extensions/test_single_file.py +++ b/tests/syrupy/extensions/test_single_file.py @@ -5,7 +5,7 @@ from syrupy.data import ( Snapshot, - SnapshotFossil, + SnapshotCollection, ) from syrupy.extensions.single_file import ( SingleFileSnapshotExtension, @@ -31,15 +31,15 @@ def snapshot_utf8(snapshot): def test_does_not_write_non_binary(testdir, snapshot_single: "SnapshotAssertion"): - snapshot_fossil = SnapshotFossil( - location=str(Path(testdir.tmpdir).joinpath("snapshot_fossil.raw")), + snapshot_collection = SnapshotCollection( + location=str(Path(testdir.tmpdir).joinpath("snapshot_collection.raw")), ) - snapshot_fossil.add(Snapshot(name="snapshot_name", data="non binary data")) + snapshot_collection.add(Snapshot(name="snapshot_name", data="non binary data")) with pytest.raises(TypeError, match="Expected 'bytes', got 'str'"): - snapshot_single.extension._write_snapshot_fossil( - snapshot_fossil=snapshot_fossil + snapshot_single.extension._write_snapshot_collection( + snapshot_collection=snapshot_collection ) - assert not Path(snapshot_fossil.location).exists() + assert not Path(snapshot_collection.location).exists() class TestClass: diff --git a/tests/syrupy/test_location.py b/tests/syrupy/test_location.py index 6da7f9ab..4bf79920 100644 --- a/tests/syrupy/test_location.py +++ b/tests/syrupy/test_location.py @@ -54,7 +54,7 @@ def test_location_properties( ): location = PyTestLocation(mock_pytest_item(node_id, method_name)) assert location.classname == expected_classname - assert location.filename == expected_filename + assert location.basename == expected_filename assert location.snapshot_name == expected_snapshotname