diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index b66cec75..cc43fa71 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -26,6 +26,7 @@ PropertyMatcher, SerializableData, SerializedData, + SnapshotIndex, ) @@ -169,10 +170,11 @@ def __with_prop(self, prop_name: str, prop_value: Any) -> None: def __call__( self, *, + diff: Optional["SnapshotIndex"] = None, exclude: Optional["PropertyFilter"] = None, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, matcher: Optional["PropertyMatcher"] = None, - name: Optional[str] = None, + name: Optional["SnapshotIndex"] = None, ) -> "SnapshotAssertion": """ Modifies assertion instance options @@ -185,6 +187,8 @@ def __call__( self.__with_prop("_matcher", matcher) if name: self.__with_prop("_custom_index", name) + if diff is not None: + self.__with_prop("_snapshot_diff", diff) return self def __dir__(self) -> List[str]: @@ -202,8 +206,17 @@ def _assert(self, data: "SerializableData") -> bool: assertion_success = False assertion_exception = None try: - snapshot_data = self._recall_data() + snapshot_data = self._recall_data(index=self.index) serialized_data = self._serialize(data) + snapshot_diff = getattr(self, "_snapshot_diff", None) + if snapshot_diff is not None: + snapshot_data_diff = self._recall_data(index=snapshot_diff) + if snapshot_data_diff is None: + raise SnapshotDoesNotExist() + serialized_data = self.extension.diff_snapshots( + serialized_data=serialized_data, + snapshot_data=snapshot_data_diff, + ) matches = snapshot_data is not None and self.extension.matches( serialized_data=serialized_data, snapshot_data=snapshot_data ) @@ -241,8 +254,8 @@ def _post_assert(self) -> None: while self._post_assert_actions: self._post_assert_actions.pop()() - def _recall_data(self) -> Optional["SerializableData"]: + def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]: try: - return self.extension.read_snapshot(index=self.index) + return self.extension.read_snapshot(index=index) except SnapshotDoesNotExist: return None diff --git a/src/syrupy/extensions/amber/serializer.py b/src/syrupy/extensions/amber/serializer.py index b9ef4879..5469a0dc 100644 --- a/src/syrupy/extensions/amber/serializer.py +++ b/src/syrupy/extensions/amber/serializer.py @@ -1,6 +1,9 @@ import functools import os -from types import GeneratorType +from types import ( + GeneratorType, + MappingProxyType, +) from typing import ( TYPE_CHECKING, Any, @@ -163,7 +166,7 @@ def _serialize( serialize_method = cls.serialize_number elif isinstance(data, (set, frozenset)): serialize_method = cls.serialize_set - elif isinstance(data, dict): + elif isinstance(data, (dict, MappingProxyType)): serialize_method = cls.serialize_dict elif cls.__is_namedtuple(data): serialize_method = cls.serialize_namedtuple diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 419f1c2d..57298b73 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -15,10 +15,10 @@ List, Optional, Set, - Union, ) from syrupy.constants import ( + DISABLE_COLOR_ENV_VAR, SNAPSHOT_DIRNAME, SYMBOL_CARRIAGE, SYMBOL_ELLIPSIS, @@ -40,7 +40,11 @@ snapshot_diff_style, snapshot_style, ) -from syrupy.utils import walk_snapshot_dir +from syrupy.utils import ( + env_context, + obj_attrs, + walk_snapshot_dir, +) if TYPE_CHECKING: from syrupy.location import PyTestLocation @@ -49,6 +53,7 @@ PropertyMatcher, SerializableData, SerializedData, + SnapshotIndex, ) @@ -74,7 +79,7 @@ class SnapshotFossilizer(ABC): def test_location(self) -> "PyTestLocation": raise NotImplementedError - def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str: + def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str: """Get the snapshot name for the assertion index in a test location""" index_suffix = "" if isinstance(index, (str,)): @@ -83,7 +88,7 @@ def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str: index_suffix = f".{index}" return f"{self.test_location.snapshot_name}{index_suffix}" - def get_location(self, *, index: Union[str, int]) -> str: + def get_location(self, *, index: "SnapshotIndex") -> str: """Returns full location where snapshot data is stored.""" basename = self._get_file_basename(index=index) fileext = f".{self._file_extension}" if self._file_extension else "" @@ -110,7 +115,7 @@ def discover_snapshots(self) -> "SnapshotFossils": return discovered - def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData": + def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData": """ Utility method for reading the contents of a snapshot assertion. Will call `_pre_read`, then perform `read` and finally `post_read`, @@ -132,7 +137,7 @@ def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData": finally: self._post_read(index=index) - def write_snapshot(self, *, data: "SerializedData", index: Union[str, int]) -> None: + def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> None: """ Utility method for writing the contents of a snapshot assertion. Will call `_pre_write`, then perform `write` and finally `_post_write`. @@ -178,17 +183,17 @@ def delete_snapshots( """ raise NotImplementedError - def _pre_read(self, *, index: Union[str, int] = 0) -> None: + def _pre_read(self, *, index: "SnapshotIndex" = 0) -> None: pass - def _post_read(self, *, index: Union[str, int] = 0) -> None: + def _post_read(self, *, index: "SnapshotIndex" = 0) -> None: pass - def _pre_write(self, *, data: "SerializedData", index: Union[str, int] = 0) -> None: + def _pre_write(self, *, data: "SerializedData", index: "SnapshotIndex" = 0) -> None: self.__ensure_snapshot_dir(index=index) def _post_write( - self, *, data: "SerializedData", index: Union[str, int] = 0 + self, *, data: "SerializedData", index: "SnapshotIndex" = 0 ) -> None: pass @@ -225,11 +230,11 @@ def _dirname(self) -> str: def _file_extension(self) -> str: raise NotImplementedError - def _get_file_basename(self, *, index: Union[str, int]) -> str: + def _get_file_basename(self, *, index: "SnapshotIndex") -> str: """Returns file basename without extension. Used to create full filepath.""" return self.test_location.filename - def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None: + def __ensure_snapshot_dir(self, *, index: "SnapshotIndex") -> None: """ Ensures the folder path for the snapshot file exists. """ @@ -240,6 +245,16 @@ def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None: class SnapshotReporter(ABC): + _context_line_count = 1 + + def diff_snapshots( + self, serialized_data: "SerializedData", snapshot_data: "SerializedData" + ) -> "SerializedData": + env = {DISABLE_COLOR_ENV_VAR: "true"} + attrs = {"_context_line_count": 0} + with env_context(**env), obj_attrs(self, attrs): + return "\n".join(self.diff_lines(serialized_data, snapshot_data)) + def diff_lines( self, serialized_data: "SerializedData", snapshot_data: "SerializedData" ) -> Iterator[str]: @@ -250,10 +265,6 @@ def diff_lines( def _ends(self) -> Dict[str, str]: return {"\n": self._marker_new_line, "\r": self._marker_carriage} - @property - def _context_line_count(self) -> int: - return 1 - @property def _context_line_max(self) -> int: return self._context_line_count * 2 diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index 48986fa7..52ac234f 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -4,7 +4,6 @@ TYPE_CHECKING, Optional, Set, - Union, ) from unicodedata import category @@ -21,6 +20,7 @@ PropertyMatcher, SerializableData, SerializedData, + SnapshotIndex, ) @@ -34,7 +34,7 @@ def serialize( ) -> "SerializedData": return bytes(data) - def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str: + def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str: return self.__clean_filename( super(SingleFileSnapshotExtension, self).get_snapshot_name(index=index) ) @@ -48,7 +48,7 @@ def delete_snapshots( def _file_extension(self) -> str: return "raw" - def _get_file_basename(self, *, index: Union[str, int]) -> str: + def _get_file_basename(self, *, index: "SnapshotIndex") -> str: return self.get_snapshot_name(index=index) @property diff --git a/src/syrupy/types.py b/src/syrupy/types.py index 7591e706..6366ab41 100644 --- a/src/syrupy/types.py +++ b/src/syrupy/types.py @@ -8,6 +8,7 @@ Union, ) +SnapshotIndex = Union[int, str] SerializableData = Any SerializedData = Union[str, bytes] PropertyName = Hashable diff --git a/src/syrupy/utils.py b/src/syrupy/utils.py index 86370921..6b024c85 100644 --- a/src/syrupy/utils.py +++ b/src/syrupy/utils.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import ( Any, + Dict, Iterator, ) @@ -66,3 +67,17 @@ def env_context(**kwargs: str) -> Iterator[None]: finally: os.environ.clear() os.environ.update(prev_env) + + +def set_attrs(obj: Any, attrs: Dict[str, Any]) -> Any: + for k in attrs: + setattr(obj, k, attrs[k]) + + +@contextmanager +def obj_attrs(obj: Any, attrs: Dict[str, Any]) -> Iterator[None]: + prev_attrs = {k: getattr(obj, k, None) for k in attrs} + try: + yield set_attrs(obj, attrs) + finally: + set_attrs(obj, prev_attrs) diff --git a/tests/syrupy/extensions/amber/__snapshots__/test_amber_snapshot_diff.ambr b/tests/syrupy/extensions/amber/__snapshots__/test_amber_snapshot_diff.ambr new file mode 100644 index 00000000..e9e87972 --- /dev/null +++ b/tests/syrupy/extensions/amber/__snapshots__/test_amber_snapshot_diff.ambr @@ -0,0 +1,112 @@ +# name: test_snapshot_diff + dict({ + 'field_0': True, + 'field_1': 'no_value', + 'nested': dict({ + 'field_0': 1, + }), + }) +# --- +# name: test_snapshot_diff.1 + ... + - 'field_1': 'no_value', + + 'field_1': 'yes_value', + ... +# --- +# name: test_snapshot_diff.2 + ... + - 'field_1': 'no_value', + + 'field_1': 'yes_value', + ... + - 'field_0': 1, + + 'field_0': 2, + ... +# --- +# name: test_snapshot_diff_id.1 + ... + - 'field_1': 'no_value', + + 'field_1': 'yes_value', + ... + - True, + ... + - None, + + False, + ... + - 'no', + + 'yes', + - False, + + 0, + ... +# --- +# name: test_snapshot_diff_id[case3] + ... + - 'nested_0': dict({ + + 'nested_0': mappingproxy({ + - 'field_0': True, + + 'field_0': 2, + ... + - 'nested_1': dict({ + + 'nested_1': mappingproxy({ + - 'field_0': True, + + 'field_0': 2, + ... +# --- +# name: test_snapshot_diff_id[large snapshot] + dict({ + 'field_0': True, + 'field_1': 'no_value', + 'field_2': 0, + 'field_3': None, + 'field_4': 1, + 'field_5': False, + 'field_6': tuple( + True, + 'hey', + 2, + None, + ), + 'field_7': set({ + 'no', + False, + None, + }), + 'nested_0': dict({ + 'field_0': True, + 'field_1': 'no_value', + 'field_2': 0, + 'field_3': None, + 'field_4': 1, + 'field_5': False, + 'field_6': tuple( + True, + 'hey', + 2, + None, + ), + 'field_7': set({ + 'no', + False, + None, + }), + }), + 'nested_1': dict({ + 'field_0': True, + 'field_1': 'no_value', + 'field_2': 0, + 'field_3': None, + 'field_4': 1, + 'field_5': False, + 'field_6': tuple( + True, + 'hey', + 2, + None, + ), + 'field_7': set({ + 'no', + False, + None, + }), + }), + }) +# --- diff --git a/tests/syrupy/extensions/amber/test_amber_snapshot_diff.py b/tests/syrupy/extensions/amber/test_amber_snapshot_diff.py new file mode 100644 index 00000000..e5bac11a --- /dev/null +++ b/tests/syrupy/extensions/amber/test_amber_snapshot_diff.py @@ -0,0 +1,64 @@ +from types import MappingProxyType + +import pytest + + +def test_snapshot_diff(snapshot): + my_dict = { + "field_0": True, + "field_1": "no_value", + "nested": { + "field_0": 1, + }, + } + assert my_dict == snapshot + my_dict["field_1"] = "yes_value" + assert my_dict == snapshot(diff=0) + my_dict["nested"]["field_0"] = 2 + assert my_dict == snapshot(diff=0) + + +def test_snapshot_diff_id(snapshot): + my_dict = { + "field_0": True, + "field_1": "no_value", + "field_2": 0, + "field_3": None, + "field_4": 1, + "field_5": False, + "field_6": (True, "hey", 2, None), + "field_7": {False, "no", 0, None}, + } + dictLargeSnapshot = dict( + { + **my_dict, + "nested_0": dict(my_dict), + "nested_1": dict(my_dict), + } + ) + assert dictLargeSnapshot == snapshot(name="large snapshot") + dictDiffLargeSnapshot = dict( + { + **dictLargeSnapshot, + "field_1": "yes_value", + "field_6": ("hey", 2, False), + "field_7": {"yes", 0, None}, + } + ) + assert dictDiffLargeSnapshot == snapshot(diff="large snapshot") + dictCase3 = dict( + { + **dictLargeSnapshot, + "nested_0": MappingProxyType({**my_dict, "field_0": 2}), + "nested_1": MappingProxyType({**my_dict, "field_0": 2}), + } + ) + assert dictCase3 == snapshot(name="case3", diff="large snapshot") + + +def test_snapshot_no_diff_raises_exception(snapshot): + my_dict = { + "field_0": "value_0", + } + with pytest.raises(AssertionError, match="SnapshotDoesNotExist"): + assert my_dict == snapshot(diff="does not exist index")