Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add snapshot diffing support #526

Merged
merged 3 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,23 @@ def test_foo(snapshot):
assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)
```

#### `diff`

This is an option to snapshot only the diff between the actual object and a previous snapshot, with the `diff` argument being the previous snapshot `index`/`name`.

```py
def test_diff(snapshot):
actual0 = [1,2,3,4]
actual1 = [0,1,3,4]

assert actual0 == snapshot
assert actual1 == snapshot(diff=0)
# This is equivalent to the lines above
# Must use the index name to diff when given
assert actual0 == snapshot(name='snap_name')
assert actual1 == snapshot(diff='snap_name')
```

##### Built-In Extensions

Syrupy comes with a few built-in preset configurations for you to choose from. You should also feel free to extend the `AbstractSyrupyExtension` if your project has a need not captured by one our built-ins.
Expand Down Expand Up @@ -295,7 +312,7 @@ from syrupy.extensions.json import JSONSnapshotExtension

@pytest.fixture
def snapshot_json(snapshot):
return snapshot.use_extension(JSONSnapshotExtension)
return snapshot.use_extension(JSONSnapshotExtension)


def test_api_call(client, snapshot_json):
Expand Down Expand Up @@ -400,5 +417,4 @@ This section is automatically generated via tagging the all-contributors bot in

## License


Syrupy is licensed under [Apache License Version 2.0](https://github.com/tophat/syrupy/tree/master/LICENSE).
24 changes: 18 additions & 6 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
List,
Optional,
Type,
Union,
)

from .exceptions import SnapshotDoesNotExist
Expand All @@ -26,6 +25,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand Down Expand Up @@ -108,7 +108,7 @@ def executions(self) -> Dict[int, AssertionResult]:
return self._execution_results

@property
def index(self) -> Union[str, int]:
def index(self) -> "SnapshotIndex":
if self._custom_index:
return self._custom_index
return self.num_executions
Expand Down Expand Up @@ -169,10 +169,11 @@ def __with_prop(self, prop_name: str, prop_value: Any) -> None:
def __call__(
self,
*,
diff: Optional["SnapshotIndex"] = None,
exclude: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional[str] = None,
name: Optional["SnapshotIndex"] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
Expand All @@ -185,6 +186,8 @@ def __call__(
self.__with_prop("_matcher", matcher)
if name:
self.__with_prop("_custom_index", name)
if diff is not None:
self.__with_prop("_snapshot_diff", diff)
return self

def __dir__(self) -> List[str]:
Expand All @@ -202,8 +205,17 @@ def _assert(self, data: "SerializableData") -> bool:
assertion_success = False
assertion_exception = None
try:
snapshot_data = self._recall_data()
snapshot_data = self._recall_data(index=self.index)
serialized_data = self._serialize(data)
snapshot_diff = getattr(self, "_snapshot_diff", None)
if snapshot_diff is not None:
snapshot_data_diff = self._recall_data(index=snapshot_diff)
if snapshot_data_diff is None:
raise SnapshotDoesNotExist()
serialized_data = self.extension.diff_snapshots(
serialized_data=serialized_data,
snapshot_data=snapshot_data_diff,
)
matches = snapshot_data is not None and self.extension.matches(
serialized_data=serialized_data, snapshot_data=snapshot_data
)
Expand Down Expand Up @@ -241,8 +253,8 @@ def _post_assert(self) -> None:
while self._post_assert_actions:
self._post_assert_actions.pop()()

def _recall_data(self) -> Optional["SerializableData"]:
def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(index=self.index)
return self.extension.read_snapshot(index=index)
except SnapshotDoesNotExist:
return None
7 changes: 5 additions & 2 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import functools
import os
from types import GeneratorType
from types import (
GeneratorType,
MappingProxyType,
)
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -163,7 +166,7 @@ def _serialize(
serialize_method = cls.serialize_number
elif isinstance(data, (set, frozenset)):
serialize_method = cls.serialize_set
elif isinstance(data, dict):
elif isinstance(data, (dict, MappingProxyType)):
serialize_method = cls.serialize_dict
elif cls.__is_namedtuple(data):
serialize_method = cls.serialize_namedtuple
Expand Down
43 changes: 27 additions & 16 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
List,
Optional,
Set,
Union,
)

from syrupy.constants import (
DISABLE_COLOR_ENV_VAR,
SNAPSHOT_DIRNAME,
SYMBOL_CARRIAGE,
SYMBOL_ELLIPSIS,
Expand All @@ -40,7 +40,11 @@
snapshot_diff_style,
snapshot_style,
)
from syrupy.utils import walk_snapshot_dir
from syrupy.utils import (
env_context,
obj_attrs,
walk_snapshot_dir,
)

if TYPE_CHECKING:
from syrupy.location import PyTestLocation
Expand All @@ -49,6 +53,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand All @@ -74,7 +79,7 @@ class SnapshotFossilizer(ABC):
def test_location(self) -> "PyTestLocation":
raise NotImplementedError

def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str:
"""Get the snapshot name for the assertion index in a test location"""
index_suffix = ""
if isinstance(index, (str,)):
Expand All @@ -83,7 +88,7 @@ def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
index_suffix = f".{index}"
return f"{self.test_location.snapshot_name}{index_suffix}"

def get_location(self, *, index: Union[str, int]) -> str:
def get_location(self, *, index: "SnapshotIndex") -> str:
"""Returns full location where snapshot data is stored."""
basename = self._get_file_basename(index=index)
fileext = f".{self._file_extension}" if self._file_extension else ""
Expand All @@ -110,7 +115,7 @@ def discover_snapshots(self) -> "SnapshotFossils":

return discovered

def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
def read_snapshot(self, *, index: "SnapshotIndex") -> "SerializedData":
"""
Utility method for reading the contents of a snapshot assertion.
Will call `_pre_read`, then perform `read` and finally `post_read`,
Expand All @@ -132,7 +137,7 @@ def read_snapshot(self, *, index: Union[str, int]) -> "SerializedData":
finally:
self._post_read(index=index)

def write_snapshot(self, *, data: "SerializedData", index: Union[str, int]) -> None:
def write_snapshot(self, *, data: "SerializedData", index: "SnapshotIndex") -> None:
"""
Utility method for writing the contents of a snapshot assertion.
Will call `_pre_write`, then perform `write` and finally `_post_write`.
Expand Down Expand Up @@ -178,17 +183,17 @@ def delete_snapshots(
"""
raise NotImplementedError

def _pre_read(self, *, index: Union[str, int] = 0) -> None:
def _pre_read(self, *, index: "SnapshotIndex" = 0) -> None:
pass

def _post_read(self, *, index: Union[str, int] = 0) -> None:
def _post_read(self, *, index: "SnapshotIndex" = 0) -> None:
pass

def _pre_write(self, *, data: "SerializedData", index: Union[str, int] = 0) -> None:
def _pre_write(self, *, data: "SerializedData", index: "SnapshotIndex" = 0) -> None:
self.__ensure_snapshot_dir(index=index)

def _post_write(
self, *, data: "SerializedData", index: Union[str, int] = 0
self, *, data: "SerializedData", index: "SnapshotIndex" = 0
) -> None:
pass

Expand Down Expand Up @@ -225,11 +230,11 @@ def _dirname(self) -> str:
def _file_extension(self) -> str:
raise NotImplementedError

def _get_file_basename(self, *, index: Union[str, int]) -> str:
def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
"""Returns file basename without extension. Used to create full filepath."""
return self.test_location.filename

def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:
def __ensure_snapshot_dir(self, *, index: "SnapshotIndex") -> None:
"""
Ensures the folder path for the snapshot file exists.
"""
Expand All @@ -240,6 +245,16 @@ def __ensure_snapshot_dir(self, *, index: Union[str, int]) -> None:


class SnapshotReporter(ABC):
_context_line_count = 1

def diff_snapshots(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
) -> "SerializedData":
env = {DISABLE_COLOR_ENV_VAR: "true"}
attrs = {"_context_line_count": 0}
with env_context(**env), obj_attrs(self, attrs):
return "\n".join(self.diff_lines(serialized_data, snapshot_data))

def diff_lines(
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
) -> Iterator[str]:
Expand All @@ -250,10 +265,6 @@ def diff_lines(
def _ends(self) -> Dict[str, str]:
return {"\n": self._marker_new_line, "\r": self._marker_carriage}

@property
def _context_line_count(self) -> int:
return 1

@property
def _context_line_max(self) -> int:
return self._context_line_count * 2
Expand Down
6 changes: 3 additions & 3 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
TYPE_CHECKING,
Optional,
Set,
Union,
)
from unicodedata import category

Expand All @@ -21,6 +20,7 @@
PropertyMatcher,
SerializableData,
SerializedData,
SnapshotIndex,
)


Expand All @@ -34,7 +34,7 @@ def serialize(
) -> "SerializedData":
return bytes(data)

def get_snapshot_name(self, *, index: Union[str, int] = 0) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex" = 0) -> str:
return self.__clean_filename(
super(SingleFileSnapshotExtension, self).get_snapshot_name(index=index)
)
Expand All @@ -48,7 +48,7 @@ def delete_snapshots(
def _file_extension(self) -> str:
return "raw"

def _get_file_basename(self, *, index: Union[str, int]) -> str:
def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
return self.get_snapshot_name(index=index)

@property
Expand Down
1 change: 1 addition & 0 deletions src/syrupy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Union,
)

SnapshotIndex = Union[int, str]
SerializableData = Any
SerializedData = Union[str, bytes]
PropertyName = Hashable
Expand Down
15 changes: 15 additions & 0 deletions src/syrupy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import (
Any,
Dict,
Iterator,
)

Expand Down Expand Up @@ -66,3 +67,17 @@ def env_context(**kwargs: str) -> Iterator[None]:
finally:
os.environ.clear()
os.environ.update(prev_env)


def set_attrs(obj: Any, attrs: Dict[str, Any]) -> Any:
for k in attrs:
setattr(obj, k, attrs[k])


@contextmanager
def obj_attrs(obj: Any, attrs: Dict[str, Any]) -> Iterator[None]:
prev_attrs = {k: getattr(obj, k, None) for k in attrs}
try:
yield set_attrs(obj, attrs)
finally:
set_attrs(obj, prev_attrs)
5 changes: 2 additions & 3 deletions tests/examples/test_custom_snapshot_name.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""
Example: Custom Snapshot Name
"""
from typing import Union

import pytest

from syrupy.extensions.amber import AmberSnapshotExtension
from syrupy.types import SnapshotIndex


class CanadianNameExtension(AmberSnapshotExtension):
def get_snapshot_name(self, *, index: Union[str, int]) -> str:
def get_snapshot_name(self, *, index: "SnapshotIndex") -> str:
original_name = super(CanadianNameExtension, self).get_snapshot_name(
index=index
)
Expand Down
Loading