Skip to content

Commit

Permalink
fix: support ignoring fields when serializing (#262)
Browse files Browse the repository at this point in the history
* wip: property filter

* wip: property filter

* test: excludes property before accessing value

* test: iter index filter

* docs: show how exclude is used

* chore: update wording

* chore: update wording
  • Loading branch information
iamogbz authored Jun 12, 2020
1 parent c3ca255 commit f67268e
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 59 deletions.
46 changes: 45 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ It should return the replacement value to be serialized or the original unmutate

| Argument | Description |
| -------- | ------------------------------------------------------------------------------------------------------------------ |
| `data` | Current serializable value being matched on |
| `data` | Current serializable value being matched on |
| `path` | Ordered path traversed to the current value e.g. `(("a", dict), ("b", dict))` from `{ "a": { "b": { "c": 1 } } }`} |

**NOTE:** Do not mutate the value received as it could cause unintended side effects.
Expand Down Expand Up @@ -138,6 +138,50 @@ def test_bar(snapshot):
---
```

#### `exclude`

This allows you to filter out object properties from the serialized snapshot.

The exclude parameter takes a filter function that accepts two keyword arguments.
It should return `true` or `false` if the property should be excluded or included respectively.

| Argument | Description |
| -------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `prop` | Current property on the object, could be any hashable value that can be used to retrieve a value e.g. `1`, `"prop_str"`, `SomeHashableObject` |
| `path` | Ordered path traversed to the current value e.g. `(("a", dict), ("b", dict))` from `{ "a": { "b": { "c": 1 } } }`} |

##### Built-In Filters

Syrupy comes with built-in helpers that can be used to make easy work of using the filter options.

###### `paths(path_string, *path_strings)`

Easy way to build a filter that uses full path strings delimited with `.`.

Takes an argument list of path strings.

```py
from syrupy.filters import paths

def test_bar(snapshot):
actual = {
"date": datetime.now(),
"list": [1,2,3],
}
assert actual == snapshot(exclude=paths("date_created", "list.1"))
```

```ambr
# name: test_bar
<class 'dict'> {
'list': <class 'list'> [
1,
3,
],
}
---
```

#### `extension_class`

This is a way to modify how the snapshot matches and serializes your data in a single assertion.
Expand Down
31 changes: 16 additions & 15 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from gettext import gettext
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Expand All @@ -17,7 +18,7 @@
from .location import TestLocation
from .extensions.base import AbstractSyrupyExtension
from .session import SnapshotSession
from .types import PropertyMatcher, SerializableData, SerializedData
from .types import PropertyFilter, PropertyMatcher, SerializableData, SerializedData


@attr.s
Expand All @@ -44,10 +45,11 @@ class SnapshotAssertion:
_extension_class: Type["AbstractSyrupyExtension"] = attr.ib(kw_only=True)
_test_location: "TestLocation" = attr.ib(kw_only=True)
_update_snapshots: bool = attr.ib(kw_only=True)
_exclude: Optional["PropertyFilter"] = attr.ib(init=False, default=None)
_extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None)
_matcher: Optional["PropertyMatcher"] = attr.ib(init=False, default=None)
_executions: int = attr.ib(init=False, default=0)
_execution_results: Dict[int, "AssertionResult"] = attr.ib(init=False, factory=dict)
_matcher: Optional["PropertyMatcher"] = attr.ib(init=False, default=None)
_post_assert_actions: List[Callable[..., None]] = attr.ib(init=False, factory=list)

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -90,7 +92,9 @@ def assert_match(self, data: "SerializableData") -> None:
assert self == data

def _serialize(self, data: "SerializableData") -> "SerializedData":
return self.extension.serialize(data, matcher=self._matcher)
return self.extension.serialize(
data, exclude=self._exclude, matcher=self._matcher
)

def get_assert_diff(self) -> List[str]:
assertion_result = self._execution_results[self.num_executions - 1]
Expand All @@ -103,29 +107,26 @@ def get_assert_diff(self) -> List[str]:
diff.extend(self.extension.diff_lines(serialized_data, snapshot_data or ""))
return diff

def __with_prop(self, prop_name: str, prop_value: Any) -> None:
setattr(self, prop_name, prop_value)
self._post_assert_actions.append(lambda: setattr(self, prop_name, None))

def __call__(
self,
*,
exclude: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SnapshotAssertion":
"""
Modifies assertion instance options
"""
if exclude:
self.__with_prop("_exclude", exclude)
if extension_class:
self._extension = self.__init_extension(extension_class)

def clear_extension() -> None:
self._extension = None

self._post_assert_actions.append(clear_extension)
self.__with_prop("_extension", self.__init_extension(extension_class))
if matcher:
self._matcher = matcher

def clear_matcher() -> None:
self._matcher = None

self._post_assert_actions.append(clear_matcher)
self.__with_prop("_matcher", matcher)
return self

def __repr__(self) -> str:
Expand Down
9 changes: 4 additions & 5 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Optional,
Set,
)
Expand All @@ -12,7 +13,7 @@


if TYPE_CHECKING:
from syrupy.types import PropertyMatcher, SerializableData
from syrupy.types import SerializableData


class AmberSnapshotExtension(AbstractSyrupyExtension):
Expand All @@ -28,14 +29,12 @@ class AmberSnapshotExtension(AbstractSyrupyExtension):
```
"""

def serialize(
self, data: "SerializableData", *, matcher: Optional["PropertyMatcher"] = None
) -> str:
def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
"""
Returns the serialized form of 'data' to be compared
with the snapshot data written to disk.
"""
return DataSerializer.serialize(data, matcher=matcher)
return DataSerializer.serialize(data, **kwargs)

def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
Expand Down
81 changes: 54 additions & 27 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Optional,
Set,
Expand All @@ -17,12 +18,21 @@

if TYPE_CHECKING:
from syrupy.types import (
PropertyFilter,
PropertyMatcher,
PropertyName,
PropertyPath,
SerializableData,
)

PropertyValueFilter = Callable[[SerializableData], bool]
PropertyValueGetter = Callable[..., SerializableData]
IterableEntries = Tuple[
Iterable["PropertyName"],
"PropertyValueGetter",
Optional["PropertyValueFilter"],
]


class Repr:
def __init__(self, _repr: str):
Expand Down Expand Up @@ -90,6 +100,7 @@ def serialize(
data: "SerializableData",
*,
depth: int = 0,
exclude: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
path: "PropertyPath" = (),
visited: Optional[Set[Any]] = None,
Expand All @@ -103,6 +114,7 @@ def serialize(
serialize_kwargs = {
"data": data,
"depth": depth,
"exclude": exclude,
"matcher": matcher,
"path": path,
"visited": {*visited, data_id},
Expand Down Expand Up @@ -158,9 +170,10 @@ def serialize_iterable(cls, data: "SerializableData", **kwargs: Any) -> str:
}.items()
if isinstance(data, iter_type)
)
values = list(data)
return cls.__serialize_iterable(
data=data,
entries=enumerate(data),
resolve_entries=(range(len(values)), lambda _, p: values[p], None),
open_tag=open_paren,
close_tag=close_paren,
**kwargs,
Expand All @@ -170,7 +183,7 @@ def serialize_iterable(cls, data: "SerializableData", **kwargs: Any) -> str:
def serialize_set(cls, data: "SerializableData", **kwargs: Any) -> str:
return cls.__serialize_iterable(
data=data,
entries=((d, d) for d in cls.sort(data)),
resolve_entries=(cls.sort(data), lambda _, p: p, None),
open_tag="{",
close_tag="}",
**kwargs,
Expand All @@ -180,7 +193,7 @@ def serialize_set(cls, data: "SerializableData", **kwargs: Any) -> str:
def serialize_namedtuple(cls, data: "SerializableData", **kwargs: Any) -> str:
return cls.__serialize_iterable(
data=data,
entries=((name, getattr(data, name)) for name in cls.sort(data._fields)),
resolve_entries=(cls.sort(data._fields), getattr, None),
open_tag="(",
close_tag=")",
separator="=",
Expand All @@ -191,7 +204,7 @@ def serialize_namedtuple(cls, data: "SerializableData", **kwargs: Any) -> str:
def serialize_dict(cls, data: "SerializableData", **kwargs: Any) -> str:
return cls.__serialize_iterable(
data=data,
entries=((key, data[key]) for key in cls.sort(data.keys())),
resolve_entries=(cls.sort(data.keys()), lambda d, p: d[p], None),
open_tag="{",
close_tag="}",
separator=": ",
Expand All @@ -206,10 +219,10 @@ def serialize_unknown(cls, data: Any, *, depth: int = 0, **kwargs: Any) -> str:

return cls.__serialize_iterable(
data=data,
entries=(
(name, getattr(data, name))
for name in cls.sort(dir(data))
if not name.startswith("_") and not callable(getattr(data, name))
resolve_entries=(
(name for name in cls.sort(dir(data)) if not name.startswith("_")),
getattr,
lambda v: not callable(v),
),
depth=depth,
open_tag="{",
Expand Down Expand Up @@ -239,38 +252,35 @@ def __is_namedtuple(cls, obj: Any) -> bool:
type(n) == str for n in getattr(obj, "_fields", [None])
)

@classmethod
def __serialize_lines(
cls,
*,
data: "SerializableData",
lines: Iterable[str],
open_tag: str,
close_tag: str,
depth: int = 0,
include_type: bool = True,
ends: str = "\n",
) -> str:
return (
f"{cls.with_indent(cls.object_type(data), depth)} " if include_type else ""
) + f"{open_tag}\n{ends.join(lines)}\n{cls.with_indent(close_tag, depth)}"

@classmethod
def __serialize_iterable(
cls,
*,
data: "SerializableData",
entries: Iterable[Tuple["PropertyName", "SerializableData"]],
resolve_entries: "IterableEntries",
open_tag: str,
close_tag: str,
depth: int = 0,
exclude: Optional["PropertyFilter"] = None,
path: "PropertyPath" = (),
separator: Optional[str] = None,
serialize_key: bool = False,
**kwargs: Any,
) -> str:
kwargs["depth"] = depth + 1

keys, get_value, include_value = resolve_entries
key_values = (
(key, get_value(data, key))
for key in keys
if not exclude or not exclude(prop=key, path=path)
)
entries = (
entry
for entry in key_values
if not include_value or include_value(entry[1])
)

def key_str(key: "PropertyName") -> str:
if separator is None:
return ""
Expand All @@ -281,8 +291,9 @@ def key_str(key: "PropertyName") -> str:
) + separator

def value_str(key: "PropertyName", value: "SerializableData") -> str:
_path = (*path, (key, type(value)))
serialized = cls.serialize(data=value, path=_path, **kwargs)
serialized = cls.serialize(
data=value, exclude=exclude, path=(*path, (key, type(value))), **kwargs
)
return serialized if separator is None else serialized.lstrip(cls._indent)

return cls.__serialize_lines(
Expand All @@ -292,3 +303,19 @@ def value_str(key: "PropertyName", value: "SerializableData") -> str:
open_tag=open_tag,
close_tag=close_tag,
)

@classmethod
def __serialize_lines(
cls,
*,
data: "SerializableData",
lines: Iterable[str],
open_tag: str,
close_tag: str,
depth: int = 0,
include_type: bool = True,
ends: str = "\n",
) -> str:
return (
f"{cls.with_indent(cls.object_type(data), depth)} " if include_type else ""
) + f"{open_tag}\n{ends.join(lines)}\n{cls.with_indent(close_tag, depth)}"
13 changes: 11 additions & 2 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,22 @@

if TYPE_CHECKING:
from syrupy.location import TestLocation
from syrupy.types import PropertyMatcher, SerializableData, SerializedData
from syrupy.types import (
PropertyFilter,
PropertyMatcher,
SerializableData,
SerializedData,
)


class SnapshotSerializer(ABC):
@abstractmethod
def serialize(
self, data: "SerializableData", *, matcher: Optional["PropertyMatcher"] = None,
self,
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
"""
Serializes a python object / data structure into a string
Expand Down
Loading

0 comments on commit f67268e

Please sign in to comment.