From d3f891ea4e561cd1b182e9b2c5d0414821187cd7 Mon Sep 17 00:00:00 2001 From: Noah Date: Mon, 28 Aug 2023 18:28:53 -0400 Subject: [PATCH] feat: add include option to snapshots, similar to exclude (#797) --- README.md | 47 ++++++++++++++++--- src/syrupy/assertion.py | 9 +++- src/syrupy/extensions/amber/serializer.py | 17 +++++-- src/syrupy/extensions/base.py | 1 + src/syrupy/extensions/json/__init__.py | 14 +++++- src/syrupy/extensions/single_file.py | 1 + .../__snapshots__/test_amber_filters.ambr | 15 ++++++ .../extensions/amber/test_amber_filters.py | 12 +++++ .../test_include_simple.1.json | 4 ++ .../test_include_simple.json | 4 ++ .../extensions/json/test_json_filters.py | 13 +++++ 11 files changed, 126 insertions(+), 11 deletions(-) create mode 100644 tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.1.json create mode 100644 tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.json diff --git a/README.md b/README.md index 42830af5..23f3b3ca 100644 --- a/README.md +++ b/README.md @@ -99,20 +99,32 @@ If you want to limit what properties are serialized at a class type level you co ```py def limit_foo_attrs(prop, path): - allowed_foo_attrs = {"only", "serialize", "these", "attrs"} - return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs + allowed_foo_attrs = {"do", "not", "serialize", "these", "attrs"} + return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs def test_bar(snapshot): actual = Foo(...) assert actual == snapshot(exclude=limit_foo_attrs) ``` -**B**. Or override the `__dir__` implementation to control the attribute list. +**B**. Provide a filter function to the snapshot [include](#include) configuration option. + +```py +def limit_foo_attrs(prop, path): + allowed_foo_attrs = {"only", "serialize", "these", "attrs"} + return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs + +def test_bar(snapshot): + actual = Foo(...) + assert actual == snapshot(include=limit_foo_attrs) +``` + +**C**. Or override the `__dir__` implementation to control the attribute list. ```py class Foo: - def __dir__(self): - return ["only", "serialize", "these", "attrs"] + def __dir__(self): + return ["only", "serialize", "these", "attrs"] def test_bar(snapshot): actual = Foo(...) @@ -211,7 +223,7 @@ Only runs replacement for objects at a matching path where the value of the mapp This allows you to filter out object properties from the serialized snapshot. The exclude parameter takes a filter function that accepts two keyword arguments. -It should return `true` or `false` if the property should be excluded or included respectively. +It should return `true` if the property should be excluded, or `false` if the property should be included. | Argument | Description | | -------- | --------------------------------------------------------------------------------------------------------------------------------------------- | @@ -278,6 +290,29 @@ def test_bar(snapshot): # --- ``` +#### `include` + +This allows you filter an object's properties to a subset using a predicate. This is the opposite of [exclude](#exclude). All the same property filters supporterd by [exclude](#exclude) are supported for `include`. + +The include parameter takes a filter function that accepts two keyword arguments. +It should return `true` if the property should be include, or `false` if the property should not be included. + +| Argument | Description | +| -------- | --------------------------------------------------------------------------------------------------------------------------------------------- | +| `prop` | Current property on the object, could be any hashable value that can be used to retrieve a value e.g. `1`, `"prop_str"`, `SomeHashableObject` | +| `path` | Ordered path traversed to the current value e.g. `(("a", dict), ("b", dict))` from `{ "a": { "b": { "c": 1 } } }`} + +Note that `include` has some caveats which make it a bit more difficult to use than `exclude`. Both `include` and `exclude` are evaluated for each key of an object before traversing down nested paths. This means if you want to include a nested path, you must include all parents of the nested path, otherwise the nested child will never be reached to be evaluated against the include predicate. For example: + +```py +obj = { + "nested": { "key": True } +} +assert obj == snapshot(include=paths("nested", "nested.key")) +``` + +The extra "nested" is required, otherwise the nested dictionary will never be searched -- it'd get pruned too early. + #### `extension_class` This is a way to modify how the snapshot matches and serializes your data in a single assertion. diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 1eeb6a35..a3a7c0f8 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -65,6 +65,10 @@ class SnapshotAssertion: init=False, default=None, ) + _include: Optional["PropertyFilter"] = field( + init=False, + default=None, + ) _custom_index: Optional[str] = field( init=False, default=None, @@ -180,7 +184,7 @@ def assert_match(self, data: "SerializableData") -> None: def _serialize(self, data: "SerializableData") -> "SerializedData": return self.extension.serialize( - data, exclude=self._exclude, matcher=self.__matcher + data, exclude=self._exclude, include=self._include, matcher=self.__matcher ) def get_assert_diff(self) -> List[str]: @@ -233,6 +237,7 @@ def __call__( *, diff: Optional["SnapshotIndex"] = None, exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, matcher: Optional["PropertyMatcher"] = None, name: Optional["SnapshotIndex"] = None, @@ -242,6 +247,8 @@ def __call__( """ if exclude: self.__with_prop("_exclude", exclude) + if include: + self.__with_prop("_include", include) if extension_class: self.__with_prop("_extension", self.__init_extension(extension_class)) if matcher: diff --git a/src/syrupy/extensions/amber/serializer.py b/src/syrupy/extensions/amber/serializer.py index 30715a6e..07e25177 100644 --- a/src/syrupy/extensions/amber/serializer.py +++ b/src/syrupy/extensions/amber/serializer.py @@ -203,6 +203,7 @@ def serialize( data: "SerializableData", *, exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, ) -> str: """ @@ -211,7 +212,9 @@ def serialize( same new line control characters. Example snapshots generated on windows os should not break when running the tests on a unix based system and vice versa. """ - serialized = cls._serialize(data, exclude=exclude, matcher=matcher) + serialized = cls._serialize( + data, exclude=exclude, include=include, matcher=matcher + ) return serialized.replace("\r\n", "\n").replace("\r", "\n") @classmethod @@ -221,6 +224,7 @@ def _serialize( *, depth: int = 0, exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, path: "PropertyPath" = (), visited: Optional[Set[Any]] = None, @@ -235,6 +239,7 @@ def _serialize( "data": data, "depth": depth, "exclude": exclude, + "include": include, "matcher": matcher, "path": path, "visited": {*visited, data_id}, @@ -400,6 +405,7 @@ def serialize_custom_iterable( close_paren: Optional[str] = None, depth: int = 0, exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, path: "PropertyPath" = (), separator: Optional[str] = None, serialize_key: bool = False, @@ -414,7 +420,8 @@ def serialize_custom_iterable( key_values = ( (key, get_value(data, key)) for key in keys - if not exclude or not exclude(prop=key, path=path) + if (not exclude or not exclude(prop=key, path=path)) + and (not include or include(prop=key, path=path)) ) entries = ( entry @@ -433,7 +440,11 @@ def key_str(key: "PropertyName") -> str: def value_str(key: "PropertyName", value: "SerializableData") -> str: serialized = cls._serialize( - data=value, exclude=exclude, path=(*path, (key, type(value))), **kwargs + data=value, + exclude=exclude, + include=include, + path=(*path, (key, type(value))), + **kwargs, ) return serialized if separator is None else serialized.lstrip(cls._indent) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index b4a2b430..945cf20b 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -65,6 +65,7 @@ def serialize( data: "SerializableData", *, exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, ) -> "SerializedData": """ diff --git a/src/syrupy/extensions/json/__init__.py b/src/syrupy/extensions/json/__init__.py index 016ad393..5b52a8d5 100644 --- a/src/syrupy/extensions/json/__init__.py +++ b/src/syrupy/extensions/json/__init__.py @@ -55,6 +55,7 @@ def _filter( depth: int = 0, path: "PropertyPath", exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, visited: Optional[Set[Any]] = None, ) -> "SerializableData": @@ -80,6 +81,8 @@ def _filter( value = data[key] if exclude and exclude(prop=key, path=path): continue + if include and not include(prop=key, path=path): + continue if not isinstance(key, (str,)): continue filtered_dct[key] = cls._filter( @@ -87,6 +90,7 @@ def _filter( depth=depth + 1, path=(*path, (key, type(value))), exclude=exclude, + include=include, matcher=matcher, visited={*visited, data_id}, ) @@ -101,6 +105,7 @@ def _filter( depth=depth + 1, path=(*path, (key, type(value))), exclude=exclude, + include=include, matcher=matcher, visited={*visited, data_id}, ) @@ -118,6 +123,7 @@ def _filter( depth=depth + 1, path=(*path, (key, type(value))), exclude=exclude, + include=include, matcher=matcher, visited={*visited, data_id}, ) @@ -137,9 +143,15 @@ def serialize( data: "SerializableData", *, exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, ) -> "SerializedData": data = self._filter( - data=data, depth=0, path=(), exclude=exclude, matcher=matcher + data=data, + depth=0, + path=(), + exclude=exclude, + include=include, + matcher=matcher, ) return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n" diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index 19b9838d..0b216115 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -47,6 +47,7 @@ def serialize( data: "SerializableData", *, exclude: Optional["PropertyFilter"] = None, + include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, ) -> "SerializedData": return self.get_supported_dataclass()(data) diff --git a/tests/syrupy/extensions/amber/__snapshots__/test_amber_filters.ambr b/tests/syrupy/extensions/amber/__snapshots__/test_amber_filters.ambr index f35f2e8f..eea00b21 100644 --- a/tests/syrupy/extensions/amber/__snapshots__/test_amber_filters.ambr +++ b/tests/syrupy/extensions/amber/__snapshots__/test_amber_filters.ambr @@ -35,3 +35,18 @@ }), }) # --- +# name: test_only_includes_expected_props + dict({ + 'date': 'utc', + 0: 'some value', + }) +# --- +# name: test_only_includes_expected_props.1 + dict({ + 'date': 'utc', + 'nested': dict({ + 'id': 4, + }), + 0: 'some value', + }) +# --- diff --git a/tests/syrupy/extensions/amber/test_amber_filters.py b/tests/syrupy/extensions/amber/test_amber_filters.py index 726d95eb..6e317439 100644 --- a/tests/syrupy/extensions/amber/test_amber_filters.py +++ b/tests/syrupy/extensions/amber/test_amber_filters.py @@ -38,6 +38,18 @@ def test_filters_expected_props(snapshot): assert actual == snapshot(exclude=props("0", "date", "id")) +def test_only_includes_expected_props(snapshot): + actual = { + 0: "some value", + "date": "utc", + "nested": {"id": 4, "other": "value"}, + "list": [1, 2], + } + # Note that "id" won't get included because "nested" (its parent) is not included. + assert actual == snapshot(include=props("0", "date", "id")) + assert actual == snapshot(include=paths("0", "date", "nested", "nested.id")) + + @pytest.mark.parametrize( "predicate", [paths("exclude_me", "nested.exclude_me"), props("exclude_me")] ) diff --git a/tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.1.json b/tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.1.json new file mode 100644 index 00000000..d8f4abad --- /dev/null +++ b/tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.1.json @@ -0,0 +1,4 @@ +{ + "foo": "__SHOULD_BE_REMOVED_FROM_JSON__", + "id": 123456789 +} diff --git a/tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.json b/tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.json new file mode 100644 index 00000000..d8f4abad --- /dev/null +++ b/tests/syrupy/extensions/json/__snapshots__/test_json_filters/test_include_simple.json @@ -0,0 +1,4 @@ +{ + "foo": "__SHOULD_BE_REMOVED_FROM_JSON__", + "id": 123456789 +} diff --git a/tests/syrupy/extensions/json/test_json_filters.py b/tests/syrupy/extensions/json/test_json_filters.py index 744a563b..9095ca86 100644 --- a/tests/syrupy/extensions/json/test_json_filters.py +++ b/tests/syrupy/extensions/json/test_json_filters.py @@ -46,6 +46,19 @@ def test_exclude_simple(snapshot_json): assert snapshot_json(exclude=paths("id", "foo")) == content +def test_include_simple(snapshot_json): + content = { + "id": 123456789, + "foo": "__SHOULD_BE_REMOVED_FROM_JSON__", + "I'm": "still alive", + "nested": { + "foo": "is still alive", + }, + } + assert snapshot_json(include=props("id", "foo")) == content + assert snapshot_json(include=paths("id", "foo")) == content + + def test_exclude_nested(snapshot_json): content = { "a": "b",