From ae0743596bc7b4586d49b374fd411a2a2a39d533 Mon Sep 17 00:00:00 2001 From: Noah Negin-Ulster Date: Thu, 1 Dec 2022 17:08:22 -0500 Subject: [PATCH] refactor: make write_snapshot a classmethod --- src/syrupy/extensions/amber/__init__.py | 3 +- src/syrupy/extensions/base.py | 50 +++++++++---------- src/syrupy/extensions/single_file.py | 41 ++++++++------- .../test_custom_snapshot_directory.py | 3 +- .../test_custom_snapshot_directory_2.py | 3 +- .../test_snapshot_outside_directory.py | 3 +- .../test_snapshot_use_extension.py | 3 +- 7 files changed, 57 insertions(+), 49 deletions(-) diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py index d8559d84..f6ca5773 100644 --- a/src/syrupy/extensions/amber/__init__.py +++ b/src/syrupy/extensions/amber/__init__.py @@ -61,8 +61,9 @@ def _read_snapshot_data_from_location( snapshot = snapshots.get(snapshot_name) return snapshot.data if snapshot else None + @classmethod def _write_snapshot_collection( - self, *, snapshot_collection: "SnapshotCollection" + cls, *, snapshot_collection: "SnapshotCollection" ) -> None: DataSerializer.write_file(snapshot_collection, merge=True) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 79417765..1067e7fd 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -96,14 +96,15 @@ def get_snapshot_name( index_suffix = f".{index}" return f"{test_location.snapshot_name}{index_suffix}" + @classmethod def get_location( - self, *, test_location: "PyTestLocation", index: "SnapshotIndex" + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" ) -> str: - """Returns full location where snapshot data is stored.""" - basename = self._get_file_basename(test_location=test_location, index=index) - fileext = f".{self._file_extension}" if self._file_extension else "" + """Returns full filepath where snapshot data is stored.""" + basename = cls._get_file_basename(test_location=test_location, index=index) + fileext = f".{cls._file_extension}" if cls._file_extension else "" return str( - Path(self.dirname(test_location=test_location)).joinpath( + Path(cls.dirname(test_location=test_location)).joinpath( f"{basename}{fileext}" ) ) @@ -155,8 +156,9 @@ def read_snapshot( raise SnapshotDoesNotExist() return snapshot_data + @classmethod def write_snapshot( - self, + cls, *, test_location: "PyTestLocation", snapshots: List[Tuple["SerializedData", "SnapshotIndex"]], @@ -170,13 +172,19 @@ def write_snapshot( # Amber extension. locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list) for data, index in snapshots: - location = self.get_location(test_location=test_location, index=index) - snapshot_name = self.get_snapshot_name( + location = cls.get_location(test_location=test_location, index=index) + snapshot_name = cls.get_snapshot_name( test_location=test_location, index=index ) locations[location].append(Snapshot(name=snapshot_name, data=data)) - self.__ensure_snapshot_dir(test_location=test_location, index=index) + # Ensures the folder path for the snapshot file exists. + try: + Path( + cls.get_location(test_location=test_location, index=index) + ).parent.mkdir(parents=True) + except FileExistsError: + pass for location, location_snapshots in locations.items(): snapshot_collection = SnapshotCollection(location=location) @@ -208,7 +216,7 @@ def write_snapshot( ) warnings.warn(warning_msg) - self._write_snapshot_collection(snapshot_collection=snapshot_collection) + cls._write_snapshot_collection(snapshot_collection=snapshot_collection) @abstractmethod def delete_snapshots( @@ -238,38 +246,28 @@ def _read_snapshot_data_from_location( """ raise NotImplementedError + @classmethod @abstractmethod def _write_snapshot_collection( - self, *, snapshot_collection: "SnapshotCollection" + cls, *, snapshot_collection: "SnapshotCollection" ) -> None: """ Adds the snapshot data to the snapshots in collection location """ raise NotImplementedError - def dirname(self, *, test_location: "PyTestLocation") -> str: + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: test_dir = Path(test_location.filepath).parent return str(test_dir.joinpath(SNAPSHOT_DIRNAME)) + @classmethod def _get_file_basename( - self, *, test_location: "PyTestLocation", index: "SnapshotIndex" + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" ) -> str: """Returns file basename without extension. Used to create full filepath.""" return test_location.basename - def __ensure_snapshot_dir( - self, *, test_location: "PyTestLocation", index: "SnapshotIndex" - ) -> None: - """ - Ensures the folder path for the snapshot file exists. - """ - try: - Path( - self.get_location(test_location=test_location, index=index) - ).parent.mkdir(parents=True) - except FileExistsError: - pass - class SnapshotReporter(ABC): _context_line_count = 1 diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index 6291c3e1..af53ea4d 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -49,7 +49,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, ) -> "SerializedData": - return self._supported_dataclass(data) + return self.get_supported_dataclass()(data) @classmethod def get_snapshot_name( @@ -66,15 +66,15 @@ def delete_snapshots( ) -> None: Path(snapshot_location).unlink() + @classmethod def _get_file_basename( - self, *, test_location: "PyTestLocation", index: "SnapshotIndex" + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" ) -> str: - return self.get_snapshot_name(test_location=test_location, index=index) + return cls.get_snapshot_name(test_location=test_location, index=index) - def dirname(self, *, test_location: "PyTestLocation") -> str: - original_dirname = super(SingleFileSnapshotExtension, self).dirname( - test_location=test_location - ) + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: + original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location) return str(Path(original_dirname).joinpath(test_location.basename)) def _read_snapshot_collection( @@ -89,41 +89,46 @@ def _read_snapshot_data_from_location( ) -> Optional["SerializableData"]: try: with open( - snapshot_location, f"r{self._write_mode}", encoding=self._write_encoding + snapshot_location, + f"r{self._write_mode}", + encoding=self.get_write_encoding(), ) as f: return f.read() except FileNotFoundError: return None - @property - def _supported_dataclass(self) -> Union[Type[str], Type[bytes]]: - if self._write_mode == WriteMode.TEXT: + @classmethod + def get_supported_dataclass(cls) -> Union[Type[str], Type[bytes]]: + if cls._write_mode == WriteMode.TEXT: return str return bytes - @property - def _write_encoding(self) -> Optional[str]: - if self._write_mode == WriteMode.TEXT: + @classmethod + def get_write_encoding(cls) -> Optional[str]: + if cls._write_mode == WriteMode.TEXT: return TEXT_ENCODING return None + @classmethod def _write_snapshot_collection( - self, *, snapshot_collection: "SnapshotCollection" + cls, *, snapshot_collection: "SnapshotCollection" ) -> None: filepath, data = ( snapshot_collection.location, next(iter(snapshot_collection)).data, ) - if not isinstance(data, self._supported_dataclass): + if not isinstance(data, cls.get_supported_dataclass()): error_text = gettext( "Can't write non supported data. Expected '{}', got '{}'" ) raise TypeError( error_text.format( - self._supported_dataclass.__name__, type(data).__name__ + cls.get_supported_dataclass().__name__, type(data).__name__ ) ) - with open(filepath, f"w{self._write_mode}", encoding=self._write_encoding) as f: + with open( + filepath, f"w{cls._write_mode}", encoding=cls.get_write_encoding() + ) as f: f.write(data) @classmethod diff --git a/tests/examples/test_custom_snapshot_directory.py b/tests/examples/test_custom_snapshot_directory.py index 345143a5..766e5e96 100644 --- a/tests/examples/test_custom_snapshot_directory.py +++ b/tests/examples/test_custom_snapshot_directory.py @@ -21,7 +21,8 @@ class DifferentDirectoryExtension(AmberSnapshotExtension): - def dirname(self, *, test_location: "PyTestLocation") -> str: + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: return str(Path(test_location.filepath).parent.joinpath(DIFFERENT_DIRECTORY)) diff --git a/tests/examples/test_custom_snapshot_directory_2.py b/tests/examples/test_custom_snapshot_directory_2.py index 5bae0d80..f4ac65af 100644 --- a/tests/examples/test_custom_snapshot_directory_2.py +++ b/tests/examples/test_custom_snapshot_directory_2.py @@ -20,7 +20,8 @@ def create_versioned_fixture(version: int): class VersionedJSONExtension(JSONSnapshotExtension): - def dirname(self, *, test_location: "PyTestLocation") -> str: + @classmethod + def dirname(cls, *, test_location: "PyTestLocation") -> str: return str( Path(test_location.filepath).parent.joinpath( "__snapshots__", f"v{version}" diff --git a/tests/integration/test_snapshot_outside_directory.py b/tests/integration/test_snapshot_outside_directory.py index fb444a54..350109b3 100644 --- a/tests/integration/test_snapshot_outside_directory.py +++ b/tests/integration/test_snapshot_outside_directory.py @@ -11,7 +11,8 @@ def testcases(testdir, tmp_path): from syrupy.extensions.amber import AmberSnapshotExtension class CustomSnapshotExtension(AmberSnapshotExtension): - def dirname(self, *, test_location): + @classmethod + def dirname(cls, *, test_location): return {str(dirname)!r} @pytest.fixture diff --git a/tests/integration/test_snapshot_use_extension.py b/tests/integration/test_snapshot_use_extension.py index 98049368..375d3ad5 100644 --- a/tests/integration/test_snapshot_use_extension.py +++ b/tests/integration/test_snapshot_use_extension.py @@ -26,7 +26,8 @@ def get_snapshot_name(cls, *, test_location, index): testname = test_location.testname[::-1] return f"{testname}.{index}" - def _get_file_basename(self, *, test_location, index): + @classmethod + def _get_file_basename(cls, *, test_location, index): return test_location.basename[::-1] @pytest.fixture