Skip to content

Commit

Permalink
feat: support overriding the amber serializer class (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnu authored Feb 2, 2023
1 parent 4ca0716 commit 662c93f
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 58 deletions.
18 changes: 11 additions & 7 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
Any,
Optional,
Set,
Type,
)

from syrupy.data import SnapshotCollection
from syrupy.exceptions import TaintedSnapshotError
from syrupy.extensions.base import AbstractSyrupyExtension

from .serializer import AmberDataSerializerSorted # noqa: F401 # re-exported
from .serializer import AmberDataSerializer

if TYPE_CHECKING:
Expand All @@ -24,12 +26,14 @@ class AmberSnapshotExtension(AbstractSyrupyExtension):

_file_extension = "ambr"

serializer_class: Type["AmberDataSerializer"] = AmberDataSerializer

def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
"""
Returns the serialized form of 'data' to be compared
with the snapshot data written to disk.
"""
return AmberDataSerializer.serialize(data, **kwargs)
return self.serializer_class.serialize(data, **kwargs)

def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
Expand All @@ -39,19 +43,19 @@ def delete_snapshots(
snapshot_collection_to_update.remove(snapshot_name)

if snapshot_collection_to_update.has_snapshots:
AmberDataSerializer.write_file(snapshot_collection_to_update)
self.serializer_class.write_file(snapshot_collection_to_update)
else:
Path(snapshot_location).unlink()

def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
return AmberDataSerializer.read_file(snapshot_location)
return self.serializer_class.read_file(snapshot_location)

@staticmethod
@classmethod
@lru_cache()
def __cacheable_read_snapshot(
snapshot_location: str, cache_key: str
cls, snapshot_location: str, cache_key: str
) -> "SnapshotCollection":
return AmberDataSerializer.read_file(snapshot_location)
return cls.serializer_class.read_file(snapshot_location)

def _read_snapshot_data_from_location(
self, snapshot_location: str, snapshot_name: str, session_id: str
Expand All @@ -70,7 +74,7 @@ def _read_snapshot_data_from_location(
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
AmberDataSerializer.write_file(snapshot_collection, merge=True)
cls.serializer_class.write_file(snapshot_collection, merge=True)


__all__ = ["AmberSnapshotExtension", "AmberDataSerializer"]
59 changes: 37 additions & 22 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Dict,
Generator,
Iterable,
List,
NamedTuple,
Optional,
Set,
Expand Down Expand Up @@ -77,7 +76,13 @@ class MissingVersionError(Exception):


class AmberDataSerializer:
VERSION = 1
"""
If extending the serializer, change the VERSION property to some unique value
for your iteration of the serializer so as to force invalidation of existing
snapshots.
"""

VERSION = "1"

_indent: str = " "
_max_depth: int = 99
Expand All @@ -89,23 +94,8 @@ class Marker:
Divider = "---"

@classmethod
def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]:
try:
# cast to int only if the string is the exact representation of the int
# for example, '012' != str(int('012'))
i = int(part)
if str(i) == part:
return (1, i)
return (0, part)
except ValueError:
# the nested tuple is to prevent comparing a str to an int
return (0, part)

@classmethod
def __snapshot_sort_key(
cls, snapshot: "Snapshot"
) -> List[Tuple[int, Union[str, int]]]:
return [cls.__maybe_int(part) for part in snapshot.name.split(".")]
def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any:
return snapshot.name

@classmethod
def write_file(
Expand All @@ -123,7 +113,7 @@ def write_file(
with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
f.write(f"{cls._marker_prefix}{cls.Marker.Version}: {cls.VERSION}\n")
for snapshot in sorted(
snapshot_collection, key=lambda s: cls.__snapshot_sort_key(s)
snapshot_collection, key=lambda s: cls._snapshot_sort_key(s) # type: ignore # noqa: E501
):
snapshot_data = str(snapshot.data)
if snapshot_data is not None:
Expand Down Expand Up @@ -152,14 +142,14 @@ def __read_file_with_markers(
":", maxsplit=1
)
marker_key = marker_key.rstrip(" \r\n")
marker_value = marker_rest[0] if marker_rest else None
marker_value = marker_rest[0].strip() if marker_rest else None

if marker_key == cls.Marker.Version:
if line_no:
raise MalformedAmberFile(
"Version must be specified at the top of the file."
)
if not marker_value or int(marker_value) != cls.VERSION:
if not marker_value or marker_value != cls.VERSION:
tainted = True
continue
missing_version = False
Expand Down Expand Up @@ -457,3 +447,28 @@ def __serialize_lines(
formatted_open_tag = cls.with_indent(f"{maybe_obj_type}{open_tag}", depth)
formatted_close_tag = cls.with_indent(close_tag, depth)
return f"{formatted_open_tag}\n{lines}{lines_end}{formatted_close_tag}"


class AmberDataSerializerSorted(AmberDataSerializer):
"""
This is an experimental serializer with known performance issues.
"""

VERSION = f"{AmberDataSerializer.VERSION}-sorted"

@classmethod
def __maybe_int(cls, part: str) -> Tuple[int, Union[str, int]]:
try:
# cast to int only if the string is the exact representation of the int
# for example, '012' != str(int('012'))
i = int(part)
if str(i) == part:
return (1, i)
return (0, part)
except ValueError:
# the nested tuple is to prevent comparing a str to an int
return (0, part)

@classmethod
def _snapshot_sort_key(cls, snapshot: "Snapshot") -> Any:
return [cls.__maybe_int(part) for part in snapshot.name.split(".")]
10 changes: 5 additions & 5 deletions tests/syrupy/__snapshots__/test_doctest.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
obj_attr='test class attr',
)
# ---
# name: DocTestClass.1
DocTestClass(
obj_attr='test class attr',
)
# ---
# name: DocTestClass.NestedDocTestClass
NestedDocTestClass(
nested_obj_attr='nested doc test class attr',
Expand All @@ -15,11 +20,6 @@
# name: DocTestClass.doctest_method
'doc test method return value'
# ---
# name: DocTestClass.1
DocTestClass(
obj_attr='test class attr',
)
# ---
# name: doctest_fn
'doc test fn return value'
# ---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,30 +253,6 @@
# name: test_many_sorted.1
1
# ---
# name: test_many_sorted.2
2
# ---
# name: test_many_sorted.3
3
# ---
# name: test_many_sorted.4
4
# ---
# name: test_many_sorted.5
5
# ---
# name: test_many_sorted.6
6
# ---
# name: test_many_sorted.7
7
# ---
# name: test_many_sorted.8
8
# ---
# name: test_many_sorted.9
9
# ---
# name: test_many_sorted.10
10
# ---
Expand Down Expand Up @@ -307,6 +283,9 @@
# name: test_many_sorted.19
19
# ---
# name: test_many_sorted.2
2
# ---
# name: test_many_sorted.20
20
# ---
Expand All @@ -322,6 +301,27 @@
# name: test_many_sorted.24
24
# ---
# name: test_many_sorted.3
3
# ---
# name: test_many_sorted.4
4
# ---
# name: test_many_sorted.5
5
# ---
# name: test_many_sorted.6
6
# ---
# name: test_many_sorted.7
7
# ---
# name: test_many_sorted.8
8
# ---
# name: test_many_sorted.9
9
# ---
# name: test_multiline_string_in_dict
dict({
'value': '''
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# serializer version: 1-sorted
# name: test_many_sorted
0
# ---
# name: test_many_sorted.1
1
# ---
# name: test_many_sorted.2
2
# ---
# name: test_many_sorted.3
3
# ---
# name: test_many_sorted.4
4
# ---
# name: test_many_sorted.5
5
# ---
# name: test_many_sorted.6
6
# ---
# name: test_many_sorted.7
7
# ---
# name: test_many_sorted.8
8
# ---
# name: test_many_sorted.9
9
# ---
# name: test_many_sorted.10
10
# ---
# name: test_many_sorted.11
11
# ---
# name: test_many_sorted.12
12
# ---
# name: test_many_sorted.13
13
# ---
# name: test_many_sorted.14
14
# ---
# name: test_many_sorted.15
15
# ---
# name: test_many_sorted.16
16
# ---
# name: test_many_sorted.17
17
# ---
# name: test_many_sorted.18
18
# ---
# name: test_many_sorted.19
19
# ---
# name: test_many_sorted.20
20
# ---
# name: test_many_sorted.21
21
# ---
# name: test_many_sorted.22
22
# ---
# name: test_many_sorted.23
23
# ---
# name: test_many_sorted.24
24
# ---
20 changes: 20 additions & 0 deletions tests/syrupy/extensions/amber_sorted/test_amber_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from syrupy.extensions.amber import (
AmberDataSerializerSorted,
AmberSnapshotExtension,
)


class AmberSortedSnapshotExtension(AmberSnapshotExtension):
serializer_class = AmberDataSerializerSorted


@pytest.fixture
def snapshot(snapshot):
return snapshot.use_extension(AmberSortedSnapshotExtension)


def test_many_sorted(snapshot):
for i in range(25):
assert i == snapshot

1 comment on commit 662c93f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 662c93f Previous: 02abef5 Ratio
benchmarks/test_1000x.py::test_1000x_reads 0.5341319969507773 iter/sec (stddev: 0.07128995514696808) 0.8381195242511715 iter/sec (stddev: 0.04240394140227035) 1.57
benchmarks/test_1000x.py::test_1000x_writes 0.5217680846071134 iter/sec (stddev: 0.08274122550544077) 0.8626650008455868 iter/sec (stddev: 0.05153168408309042) 1.65
benchmarks/test_standard.py::test_standard 0.5065343890498211 iter/sec (stddev: 0.12248840331953859) 0.7465173870618954 iter/sec (stddev: 0.1502009356924296) 1.47

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.