Skip to content

Commit c05a6b7

Browse files
🔧 mypy type check tools/ and utils/ (#931)
- Add explicit type annotations in visualization and transforms modules. - Revise type declarations and cast usage in the wsi_registration module. - Adjust input type handling in the patchextraction module. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ad66b2d commit c05a6b7

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

tiatoolbox/tools/patchextraction.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing_extensions import Unpack
1010

1111
from tiatoolbox import logger
12+
from tiatoolbox.annotation.storage import AnnotationStore
1213
from tiatoolbox.utils import misc
1314
from tiatoolbox.utils.exceptions import FileNotSupportedError, MethodNotSupportedError
1415
from tiatoolbox.utils.visualization import AnnotationRenderer
@@ -19,7 +20,6 @@
1920

2021
from pandas import DataFrame
2122

22-
from tiatoolbox.annotation.storage import AnnotationStore
2323
from tiatoolbox.type_hints import Resolution, Units
2424

2525

@@ -237,7 +237,9 @@ def __init__(
237237

238238
if input_mask is None:
239239
self.mask = None
240-
elif isinstance(input_mask, str) and input_mask.endswith(".db"):
240+
elif (isinstance(input_mask, str) and input_mask.endswith(".db")) or isinstance(
241+
input_mask, AnnotationStore
242+
):
241243
# input_mask is an annotation store
242244
renderer = AnnotationRenderer(
243245
max_scale=10000, edge_thickness=0, where=store_filter
@@ -670,7 +672,12 @@ def __init__( # noqa: PLR0913
670672
self: SlidingWindowPatchExtractor,
671673
input_img: str | Path | np.ndarray | wsireader.WSIReader,
672674
patch_size: int | tuple[int, int],
673-
input_mask: str | Path | np.ndarray | wsireader.VirtualWSIReader | None = None,
675+
input_mask: str
676+
| Path
677+
| np.ndarray
678+
| wsireader.VirtualWSIReader
679+
| AnnotationStore
680+
| None = None,
674681
resolution: Resolution = 0,
675682
units: Units = "level",
676683
stride: int | tuple[int, int] | None = None,

tiatoolbox/tools/registration/wsi_registration.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import itertools
6-
from typing import TYPE_CHECKING, Callable
6+
from typing import TYPE_CHECKING, Callable, cast
77

88
import cv2
99
import numpy as np
@@ -338,19 +338,24 @@ def __init__(self: DFBRFeatureExtractor) -> None:
338338
super().__init__()
339339
output_layers_id: list[str] = ["16", "23", "30"]
340340
output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
341-
self.features: dict = dict.fromkeys(output_layers_key, None)
342-
self.pretrained: torch.nn.Sequential = compile_model(
341+
self.features: dict[str, torch.Tensor] = dict.fromkeys(
342+
output_layers_key, torch.Tensor()
343+
)
344+
345+
compiled_model = compile_model(
343346
torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1),
344347
mode=rcParam["torch_compile_mode"],
345-
).features
348+
)
349+
self.pretrained = cast("torch.nn.Module", compiled_model.features)
350+
346351
self.f_hooks = [
347352
getattr(self.pretrained, layer).register_forward_hook(
348353
self.forward_hook(output_layers_key[i]),
349354
)
350355
for i, layer in enumerate(output_layers_id)
351356
]
352357

353-
def forward_hook(self: torch.nn.Module, layer_name: str) -> Callable:
358+
def forward_hook(self: DFBRFeatureExtractor, layer_name: str) -> Callable:
354359
"""Register a hook.
355360
356361
Args:
@@ -386,7 +391,7 @@ def hook(
386391

387392
return hook
388393

389-
def forward(self: torch.nn.Module, x: torch.Tensor) -> dict[str, torch.Tensor]:
394+
def forward(self: DFBRFeatureExtractor, x: torch.Tensor) -> dict[str, torch.Tensor]:
390395
"""Forward pass for feature extraction.
391396
392397
Args:

tiatoolbox/utils/transforms.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def imresize(
189189
img_channels = [
190190
cv2.resize(
191191
src=img[..., ch],
192-
dsize=output_size_array,
192+
dsize=(output_size_array[0], output_size_array[1]),
193193
interpolation=cv2_interpolation,
194194
)[
195195
...,
@@ -199,7 +199,11 @@ def imresize(
199199
]
200200
return np.concatenate(img_channels, axis=-1)
201201

202-
return cv2.resize(src=img, dsize=output_size_array, interpolation=cv2_interpolation)
202+
return cv2.resize(
203+
src=img,
204+
dsize=(output_size_array[0], output_size_array[1]),
205+
interpolation=cv2_interpolation,
206+
)
203207

204208

205209
def rgb2od(img: np.ndarray) -> np.ndarray:

tiatoolbox/utils/visualization.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,14 +475,15 @@ def overlay_prediction_contours(
475475
inst_colours_array = inst_colours_array.astype(np.uint8)
476476

477477
for idx, [_, inst_info] in enumerate(inst_dict.items()):
478-
inst_contour = inst_info["contour"]
478+
inst_contour: np.ndarray = inst_info["contour"]
479479
if "type" in inst_info and type_colours is not None:
480480
inst_colour = type_colours[inst_info["type"]][1]
481481
else:
482482
inst_colour = (inst_colours_array[idx]).tolist()
483+
contours: list[np.ndarray] = [np.array(inst_contour)]
483484
cv2.drawContours(
484485
overlay,
485-
[np.array(inst_contour)],
486+
contours,
486487
-1,
487488
inst_colour,
488489
line_thickness,
@@ -881,9 +882,10 @@ def render_line(
881882
top_left,
882883
scale,
883884
)
885+
pts: list[np.ndarray] = [np.array(cnt)]
884886
cv2.polylines(
885887
tile,
886-
[np.array(cnt)],
888+
pts,
887889
isClosed=False,
888890
color=col,
889891
thickness=3,

0 commit comments

Comments
 (0)