Skip to content

Commit

Permalink
feat: add snapshot diffing support
Browse files Browse the repository at this point in the history
  • Loading branch information
iamogbz committed May 10, 2022
1 parent 3f32f4b commit 1440e9a
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 25 deletions.
21 changes: 17 additions & 4 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand Down Expand Up @@ -169,10 +170,11 @@ def __with_prop(self, prop_name: str, prop_value: Any) -> None:
def __call__(
self,
*,
diff: Optional["SnapshotIndex"] = None,
exclude: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional[str] = None,
name: Optional["SnapshotIndex"] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -185,6 +187,8 @@ def __call__(
self.__with_prop("_matcher", matcher)
if name:
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
return self

def __dir__(self) -> List[str]:
Expand All @@ -202,8 +206,17 @@ def _assert(self, data: "SerializableData") -> bool:
assertion_success = False
assertion_exception = None
try:
snapshot_data = self._recall_data()
snapshot_data = self._recall_data(index=self.index)
serialized_data = self._serialize(data)
snapshot_diff = getattr(self, "_snapshot_diff", None)
if snapshot_diff is not None:
snapshot_data_diff = self._recall_data(index=snapshot_diff)
if snapshot_data_diff is None:
raise SnapshotDoesNotExist()
serialized_data = self.extension.diff_snapshots(
serialized_data=serialized_data,
snapshot_data=snapshot_data_diff,
)
matches = snapshot_data is not None and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
)
Expand Down Expand Up @@ -241,8 +254,8 @@ def _post_assert(self) -> None:
while self._post_assert_actions:
self._post_assert_actions.pop()()

def _recall_data(self) -> Optional["SerializableData"]:
def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(index=self.index)
return self.extension.read_snapshot(index=index)
except SnapshotDoesNotExist:
return None
7 changes: 5 additions & 2 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import functools
import os
from types import GeneratorType
from types import (
GeneratorType,
MappingProxyType,
)
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -163,7 +166,7 @@ def _serialize(
serialize_method = cls.serialize_number
elif isinstance(data, (set, frozenset)):
serialize_method = cls.serialize_set
elif isinstance(data, dict):
elif isinstance(data, (dict, MappingProxyType)):
serialize_method = cls.serialize_dict
elif cls.__is_namedtuple(data):
serialize_method = cls.serialize_namedtuple
Expand Down
43 changes: 27 additions & 16 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
List,
Optional,
Set,
Union,
)

from syrupy.constants import (
DISABLE_COLOR_ENV_VAR,
SNAPSHOT_DIRNAME,
SYMBOL_CARRIAGE,
SYMBOL_ELLIPSIS,
Expand All @@ -40,7 +40,11 @@
snapshot_diff_style,
snapshot_style,
)
from syrupy.utils import walk_snapshot_dir
from syrupy.utils import (
env_context,
obj_attrs,
walk_snapshot_dir,
)

if TYPE_CHECKING:
from syrupy.location import PyTestLocation
Expand All @@ -49,6 +53,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand All @@ -74,7 +79,7 @@ class SnapshotFossilizer(ABC):
def test_location(self) -> "PyTestLocation":
raise NotImplementedError

def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str:
"""Get the snapshot name for the assertion index in a test location"""
index_suffix = ""
if isinstance(index, (str,)):
Expand All @@ -83,7 +88,7 @@ def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
index_suffix = f".{index}"
return f"{self.test_location.snapshot_name}{index_suffix}"

def get_location(self, *, index: Union[str, int]) -> str:
def get_location(self, *, index: "SnapshotIndex") -> str:
"""Returns full location where snapshot data is stored."""
basename = self._get_file_basename(index=index)
fileext = f".{self._file_extension}" if self._file_extension else ""
Expand All @@ -110,7 +115,7 @@ def discover_snapshots(self) -> "SnapshotFossils":

return discovered

def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
"""
Utility method for reading the contents of a snapshot assertion.
Will call `_pre_read`, then perform `read` and finally `post_read`,
Expand All @@ -132,7 +137,7 @@ def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
finally:
self._post_read(index=index)

def write_snapshot(self, *, data: "SerializedData", index: Union[str, int]) -> None:
def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> None:
"""
Utility method for writing the contents of a snapshot assertion.
Will call `_pre_write`, then perform `write` and finally `_post_write`.
Expand Down Expand Up @@ -178,17 +183,17 @@ def delete_snapshots(
"""
raise NotImplementedError

def _pre_read(self, *, index: Union[str, int] = 0) -> None:
def _pre_read(self, *, index: "SnapshotIndex" = 0) -> None:
pass

def _post_read(self, *, index: Union[str, int] = 0) -> None:
def _post_read(self, *, index: "SnapshotIndex" = 0) -> None:
pass

def _pre_write(self, *, data: "SerializedData", index: Union[str, int] = 0) -> None:
def _pre_write(self, *, data: "SerializedData", index: "SnapshotIndex" = 0) -> None:
self.__ensure_snapshot_dir(index=index)

def _post_write(
self, *, data: "SerializedData", index: Union[str, int] = 0
self, *, data: "SerializedData", index: "SnapshotIndex" = 0
) -> None:
pass

Expand Down Expand Up @@ -225,11 +230,11 @@ def _dirname(self) -> str:
def _file_extension(self) -> str:
raise NotImplementedError

def _get_file_basename(self, *, index: Union[str, int]) -> str:
def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
"""Returns file basename without extension. Used to create full filepath."""
return self.test_location.filename

def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:
def __ensure_snapshot_dir(self, *, index: "SnapshotIndex") -> None:
"""
Ensures the folder path for the snapshot file exists.
"""
Expand All @@ -240,6 +245,16 @@ def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:


class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
) -> Iterator[str]:
Expand All @@ -250,10 +265,6 @@ def diff_lines(
def _ends(self) -> Dict[str, str]:
return {"\n": self._marker_new_line, "\r": self._marker_carriage}

@property
def _context_line_count(self) -> int:
return 1

@property
def _context_line_max(self) -> int:
return self._context_line_count * 2
Expand Down
6 changes: 3 additions & 3 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
TYPE_CHECKING,
Optional,
Set,
Union,
)
from unicodedata import category

Expand All @@ -21,6 +20,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand All @@ -34,7 +34,7 @@ def serialize(
) -> "SerializedData":
return bytes(data)

def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str:
return self.__clean_filename(
super(SingleFileSnapshotExtension, self).get_snapshot_name(index=index)
)
Expand All @@ -48,7 +48,7 @@ def delete_snapshots(
def _file_extension(self) -> str:
return "raw"

def _get_file_basename(self, *, index: Union[str, int]) -> str:
def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
return self.get_snapshot_name(index=index)

@property
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Union,
)

SnapshotIndex = Union[int, str]
SerializableData = Any
SerializedData = Union[str, bytes]
PropertyName = Hashable
Expand Down
15 changes: 15 additions & 0 deletions src/syrupy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import (
Any,
Dict,
Iterator,
)

Expand Down Expand Up @@ -66,3 +67,17 @@ def env_context(**kwargs: str) -> Iterator[None]:
finally:
os.environ.clear()
os.environ.update(prev_env)


def set_attrs(obj: Any, attrs: Dict[str, Any]) -> Any:
for k in attrs:
setattr(obj, k, attrs[k])


@contextmanager
def obj_attrs(obj: Any, attrs: Dict[str, Any]) -> Iterator[None]:
prev_attrs = {k: getattr(obj, k, None) for k in attrs}
try:
yield set_attrs(obj, attrs)
finally:
set_attrs(obj, prev_attrs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# name: test_snapshot_diff
dict({
'field_0': True,
'field_1': 'no_value',
'nested': dict({
'field_0': 1,
}),
})
# ---
# name: test_snapshot_diff.1
...
- 'field_1': 'no_value',
+ 'field_1': 'yes_value',
...
# ---
# name: test_snapshot_diff.2
...
- 'field_1': 'no_value',
+ 'field_1': 'yes_value',
...
- 'field_0': 1,
+ 'field_0': 2,
...
# ---
# name: test_snapshot_diff_id.1
...
- 'field_1': 'no_value',
+ 'field_1': 'yes_value',
...
- True,
...
- None,
+ False,
...
- 'no',
+ 'yes',
- False,
+ 0,
...
# ---
# name: test_snapshot_diff_id[case3]
...
- 'field_0': True,
+ 'field_0': 2,
...
- 'field_0': True,
+ 'field_0': 2,
...
# ---
# name: test_snapshot_diff_id[large snapshot]
dict({
'field_0': True,
'field_1': 'no_value',
'field_2': 0,
'field_3': None,
'field_4': 1,
'field_5': False,
'field_6': tuple(
True,
'hey',
2,
None,
),
'field_7': set({
'no',
False,
None,
}),
'nested_0': dict({
'field_0': True,
'field_1': 'no_value',
'field_2': 0,
'field_3': None,
'field_4': 1,
'field_5': False,
'field_6': tuple(
True,
'hey',
2,
None,
),
'field_7': set({
'no',
False,
None,
}),
}),
'nested_1': dict({
'field_0': True,
'field_1': 'no_value',
'field_2': 0,
'field_3': None,
'field_4': 1,
'field_5': False,
'field_6': tuple(
True,
'hey',
2,
None,
),
'field_7': set({
'no',
False,
None,
}),
}),
})
# ---
Loading

0 comments on commit 1440e9a

Please sign in to comment.