Skip to content

Commit

Permalink
refactor: remove usage of self.test_location
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah Negin-Ulster committed Dec 1, 2022
1 parent a5055c7 commit 6e7fc50
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 20 deletions.
6 changes: 4 additions & 2 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(index=self.index)
snapshot_location = self.extension.get_location(
test_location=self.test_location, index=self.index
)
snapshot_name = self.extension.get_snapshot_name(
test_location=self.test_location, index=self.index
)
Expand Down Expand Up @@ -301,7 +303,7 @@ def _post_assert(self) -> None:
def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(
test_location=self.extension.test_location,
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
)
Expand Down
28 changes: 18 additions & 10 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@ def get_snapshot_name(
index_suffix = f".{index}"
return f"{test_location.snapshot_name}{index_suffix}"

def get_location(self, *, index: "SnapshotIndex") -> str:
def get_location(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> str:
"""Returns full location where snapshot data is stored."""
basename = self._get_file_basename(index=index)
basename = self._get_file_basename(test_location=test_location, index=index)
fileext = f".{self._file_extension}" if self._file_extension else ""
return str(
Path(self.dirname(test_location=self.test_location)).joinpath(
Path(self.dirname(test_location=test_location)).joinpath(
f"{basename}{fileext}"
)
)
Expand Down Expand Up @@ -142,7 +144,7 @@ def read_snapshot(
This method is _final_, do not override. You can override
`_read_snapshot_data_from_location` in a subclass to change behaviour.
"""
snapshot_location = self.get_location(index=index)
snapshot_location = self.get_location(test_location=test_location, index=index)
snapshot_name = self.get_snapshot_name(test_location=test_location, index=index)
snapshot_data = self._read_snapshot_data_from_location(
snapshot_location=snapshot_location,
Expand All @@ -168,13 +170,13 @@ def write_snapshot(
# Amber extension.
locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list)
for data, index in snapshots:
location = self.get_location(index=index)
location = self.get_location(test_location=test_location, index=index)
snapshot_name = self.get_snapshot_name(
test_location=test_location, index=index
)
locations[location].append(Snapshot(name=snapshot_name, data=data))

self.__ensure_snapshot_dir(index=index)
self.__ensure_snapshot_dir(test_location=test_location, index=index)

for location, location_snapshots in locations.items():
snapshot_collection = SnapshotCollection(location=location)
Expand Down Expand Up @@ -249,16 +251,22 @@ def dirname(self, *, test_location: "PyTestLocation") -> str:
test_dir = Path(test_location.filepath).parent
return str(test_dir.joinpath(SNAPSHOT_DIRNAME))

def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
def _get_file_basename(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> str:
"""Returns file basename without extension. Used to create full filepath."""
return self.test_location.basename
return test_location.basename

def __ensure_snapshot_dir(self, *, index: "SnapshotIndex") -> None:
def __ensure_snapshot_dir(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> None:
"""
Ensures the folder path for the snapshot file exists.
"""
try:
Path(self.get_location(index=index)).parent.mkdir(parents=True)
Path(
self.get_location(test_location=test_location, index=index)
).parent.mkdir(parents=True)
except FileExistsError:
pass

Expand Down
6 changes: 4 additions & 2 deletions src/syrupy/extensions/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ def delete_snapshots(
) -> None:
Path(snapshot_location).unlink()

def _get_file_basename(self, *, index: "SnapshotIndex") -> str:
return self.get_snapshot_name(test_location=self.test_location, index=index)
def _get_file_basename(
self, *, test_location: "PyTestLocation", index: "SnapshotIndex"
) -> str:
return self.get_snapshot_name(test_location=test_location, index=index)

def dirname(self, *, test_location: "PyTestLocation") -> str:
original_dirname = super(SingleFileSnapshotExtension, self).dirname(
Expand Down
4 changes: 2 additions & 2 deletions src/syrupy/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def __post_init__(self) -> None:
# We only need to discover snapshots once per test file, not once per assertion.
locations_discovered: DefaultDict[str, Set[Any]] = defaultdict(set)
for assertion in self.assertions:
test_location = assertion.extension.test_location.filepath
test_location = assertion.test_location.filepath
extension_class = assertion.extension.__class__
if extension_class not in locations_discovered[test_location]:
locations_discovered[test_location].add(extension_class)
self.discovered.merge(
assertion.extension.discover_snapshots(
test_location=assertion.extension.test_location
test_location=assertion.test_location
)
)

Expand Down
4 changes: 2 additions & 2 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ def finish(self) -> int:
def register_request(self, assertion: "SnapshotAssertion") -> None:
self._assertions.append(assertion)

test_location = assertion.extension.test_location.filepath
test_location = assertion.test_location.filepath
extension_class = assertion.extension.__class__
if extension_class not in self._locations_discovered[test_location]:
self._locations_discovered[test_location].add(extension_class)
discovered_extensions = {
discovered.location: assertion.extension
for discovered in assertion.extension.discover_snapshots(
test_location=assertion.extension.test_location
test_location=assertion.test_location
)
if discovered.has_snapshots
}
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_snapshot_use_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def get_snapshot_name(cls, *, test_location, index):
testname = test_location.testname[::-1]
return f"{testname}.{index}"
def _get_file_basename(self, *, index):
return self.test_location.basename[::-1]
def _get_file_basename(self, *, test_location, index):
return test_location.basename[::-1]
@pytest.fixture
def snapshot_custom(snapshot):
Expand Down

0 comments on commit 6e7fc50

Please sign in to comment.