Skip to content

Commit

Permalink
🐛 Add mypy Type Check tools/wsi_registration.py (#831)
Browse files Browse the repository at this point in the history
- Fix type errors
  • Loading branch information
Jiaqi-Lv authored Jul 12, 2024
1 parent 647d30b commit 54fa32a
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tiatoolbox/tools/registration/wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ class DFBRFeatureExtractor(torch.nn.Module):
"""

def __init__(self: torch.nn.Module) -> None:
def __init__(self: DFBRFeatureExtractor) -> None:
"""Initialize :class:`DFBRFeatureExtractor`."""
super().__init__()
output_layers_id: list[str] = ["16", "23", "30"]
Expand Down Expand Up @@ -434,8 +434,8 @@ class DFBRegister:
def __init__(self: DFBRegister, patch_size: tuple[int, int] = (224, 224)) -> None:
"""Initialize :class:`DFBRegister`."""
self.patch_size = patch_size
self.x_scale: list[float] = []
self.y_scale: list[float] = []
self.x_scale: np.ndarray
self.y_scale: np.ndarray
self.feature_extractor = DFBRFeatureExtractor()

# Make this function private when full pipeline is implemented.
Expand Down Expand Up @@ -796,7 +796,7 @@ def find_points_inside_boundary(mask: np.ndarray, points: np.ndarray) -> np.ndar
return PatchExtractor.filter_coordinates(
mask_reader,
bbox_coord,
mask.shape[::-1],
(mask.shape[1], mask.shape[0]),
)

def filtering_matching_points(
Expand Down Expand Up @@ -1521,21 +1521,21 @@ def get_patch_dimensions(
"""
width, height = size[0], size[1]

x = [
x_info = [
np.linspace(1, width, width, endpoint=True),
np.ones(height) * width,
np.linspace(1, width, width, endpoint=True),
np.ones(height),
]
x = np.array(list(itertools.chain.from_iterable(x)))
x = np.array(list(itertools.chain.from_iterable(x_info)))

y = [
y_info = [
np.ones(width),
np.linspace(1, height, height, endpoint=True),
np.ones(width) * height,
np.linspace(1, height, height, endpoint=True),
]
y = np.array(list(itertools.chain.from_iterable(y)))
y = np.array(list(itertools.chain.from_iterable(y_info)))

points = np.array([x, y]).transpose()
transform = transform * [[1, 1, 0], [1, 1, 0], [1, 1, 1]] # remove translation
Expand Down

0 comments on commit 54fa32a

Please sign in to comment.