Skip to content

Commit

Permalink
Attempt at type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-2001 committed Sep 25, 2023
1 parent f5b3c4c commit 26d3347
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Optional,
Set,
Tuple,
Any
)

from syrupy.constants import (
Expand Down Expand Up @@ -67,7 +68,7 @@ def serialize(
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
**kwargs
**kwargs: Dict[Any, Any],
) -> "SerializedData":
"""
Serializes a python object / data structure into a string
Expand Down Expand Up @@ -109,7 +110,7 @@ def is_snapshot_location(self, *, location: str) -> bool:
return location.endswith(self._file_extension)

def discover_snapshots(
self, *, test_location: "PyTestLocation", **kwargs
self, *, test_location: "PyTestLocation", **kwargs: Dict[Any, Any]
) -> "SnapshotCollections":
"""
Returns all snapshot collections in test site
Expand Down Expand Up @@ -217,7 +218,7 @@ def delete_snapshots(

@abstractmethod
def _read_snapshot_collection(
self, *, snapshot_location: str, **kwargs
self, *, snapshot_location: str, **kwargs: Dict[Any, Any]
) -> "SnapshotCollection":
"""
Read the snapshot location and construct a snapshot collection object
Expand All @@ -236,15 +237,15 @@ def _read_snapshot_data_from_location(
@classmethod
@abstractmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection", **kwargs
cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Dict[Any, Any]
) -> None:
"""
Adds the snapshot data to the snapshots in collection location
"""
raise NotImplementedError

@classmethod
def dirname(cls, *, test_location: "PyTestLocation", **kwargs) -> str:
def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Dict[Any, Any]) -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

Expand All @@ -260,15 +261,21 @@ class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData", **kwargs
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Dict[Any, Any],
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData", **kwargs
self,
serialized_data: "SerializedData",
snapshot_data: "SerializedData",
**kwargs: Dict[Any, Any],
) -> Iterator[str]:
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
yield reset(line)
Expand Down Expand Up @@ -408,7 +415,7 @@ def matches(
*,
serialized_data: "SerializableData",
snapshot_data: "SerializableData",
**kwargs
**kwargs: Dict[Any, Any],
) -> bool:
"""
Compares serialized data and snapshot data and returns
Expand Down

0 comments on commit 26d3347

Please sign in to comment.