Skip to content

Commit ee25842

Browse files
authored
Merge branch 'develop' into models-abc-multigpu
2 parents 698f16a + c05a6b7 commit ee25842

File tree

5 files changed

+72
-15
lines changed

5 files changed

+72
-15
lines changed

tiatoolbox/tools/patchextraction.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing_extensions import Unpack
1010

1111
from tiatoolbox import logger
12+
from tiatoolbox.annotation.storage import AnnotationStore
1213
from tiatoolbox.utils import misc
1314
from tiatoolbox.utils.exceptions import FileNotSupportedError, MethodNotSupportedError
1415
from tiatoolbox.utils.visualization import AnnotationRenderer
@@ -19,7 +20,6 @@
1920

2021
from pandas import DataFrame
2122

22-
from tiatoolbox.annotation.storage import AnnotationStore
2323
from tiatoolbox.type_hints import Resolution, Units
2424

2525

@@ -237,7 +237,9 @@ def __init__(
237237

238238
if input_mask is None:
239239
self.mask = None
240-
elif isinstance(input_mask, str) and input_mask.endswith(".db"):
240+
elif (isinstance(input_mask, str) and input_mask.endswith(".db")) or isinstance(
241+
input_mask, AnnotationStore
242+
):
241243
# input_mask is an annotation store
242244
renderer = AnnotationRenderer(
243245
max_scale=10000, edge_thickness=0, where=store_filter
@@ -670,7 +672,12 @@ def __init__( # noqa: PLR0913
670672
self: SlidingWindowPatchExtractor,
671673
input_img: str | Path | np.ndarray | wsireader.WSIReader,
672674
patch_size: int | tuple[int, int],
673-
input_mask: str | Path | np.ndarray | wsireader.VirtualWSIReader | None = None,
675+
input_mask: str
676+
| Path
677+
| np.ndarray
678+
| wsireader.VirtualWSIReader
679+
| AnnotationStore
680+
| None = None,
674681
resolution: Resolution = 0,
675682
units: Units = "level",
676683
stride: int | tuple[int, int] | None = None,

tiatoolbox/tools/registration/wsi_registration.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import itertools
6-
from typing import TYPE_CHECKING, Callable
6+
from typing import TYPE_CHECKING, Callable, cast
77

88
import cv2
99
import numpy as np
@@ -338,19 +338,24 @@ def __init__(self: DFBRFeatureExtractor) -> None:
338338
super().__init__()
339339
output_layers_id: list[str] = ["16", "23", "30"]
340340
output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
341-
self.features: dict = dict.fromkeys(output_layers_key, None)
342-
self.pretrained: torch.nn.Sequential = compile_model(
341+
self.features: dict[str, torch.Tensor] = dict.fromkeys(
342+
output_layers_key, torch.Tensor()
343+
)
344+
345+
compiled_model = compile_model(
343346
torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1),
344347
mode=rcParam["torch_compile_mode"],
345-
).features
348+
)
349+
self.pretrained = cast("torch.nn.Module", compiled_model.features)
350+
346351
self.f_hooks = [
347352
getattr(self.pretrained, layer).register_forward_hook(
348353
self.forward_hook(output_layers_key[i]),
349354
)
350355
for i, layer in enumerate(output_layers_id)
351356
]
352357

353-
def forward_hook(self: torch.nn.Module, layer_name: str) -> Callable:
358+
def forward_hook(self: DFBRFeatureExtractor, layer_name: str) -> Callable:
354359
"""Register a hook.
355360
356361
Args:
@@ -386,7 +391,7 @@ def hook(
386391

387392
return hook
388393

389-
def forward(self: torch.nn.Module, x: torch.Tensor) -> dict[str, torch.Tensor]:
394+
def forward(self: DFBRFeatureExtractor, x: torch.Tensor) -> dict[str, torch.Tensor]:
390395
"""Forward pass for feature extraction.
391396
392397
Args:

tiatoolbox/utils/metrics.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ def pair_coordinates(
4141
- :class:`numpy.ndarray` - Unpaired B:
4242
Indices of unpaired points in set B.
4343
44+
45+
Examples:
46+
>>> from tiatoolbox.utils.metrics import pair_coordinates
47+
>>> # Generate two random example sets; replace with your own data
48+
>>> import numpy as np
49+
>>> np.random.seed(6)
50+
>>> set_a_num_points = np.random.randint(low=10, high=30)
51+
>>> set_b_num_points = np.random.randint(low=10, high=30)
52+
>>> set_a = np.random.randint(low=0, high=25, size=(set_a_num_points, 2))
53+
>>> set_b = np.random.randint(low=0, high=25, size=(set_b_num_points, 2))
54+
>>> radius = 2.0
55+
>>> # Example usage of pair_coordinates
56+
>>> pairing, unpaired_a, unpaired_b = pair_coordinates(set_a, set_b, radius)
57+
4458
"""
4559
# * Euclidean distance as the cost matrix
4660
pair_distance = distance.cdist(set_a, set_b, metric="euclidean")
@@ -65,7 +79,22 @@ def pair_coordinates(
6579

6680

6781
def f1_detection(true: np.ndarray, pred: np.ndarray, radius: float) -> float:
68-
"""Calculate the F1-score for predicted set of coordinates."""
82+
"""Calculate the F1-score for predicted set of coordinates.
83+
84+
Examples:
85+
>>> from tiatoolbox.utils.metrics import f1_detection
86+
>>> # Generate two random example sets; replace with your own data
87+
>>> import numpy as np
88+
>>> np.random.seed(6)
89+
>>> true_num_points = np.random.randint(low=10, high=30)
90+
>>> pred_num_points = np.random.randint(low=10, high=30)
91+
>>> true = np.random.randint(low=0, high=25, size=(true_num_points, 2))
92+
>>> pred = np.random.randint(low=0, high=25, size=(pred_num_points, 2))
93+
>>> radius = 2.0
94+
>>> # Example usage of f1_detection
95+
>>> f1_score = f1_detection(true, pred, radius)
96+
97+
"""
6998
(paired_true, unpaired_true, unpaired_pred) = pair_coordinates(true, pred, radius)
7099

71100
tp = len(paired_true)
@@ -94,6 +123,16 @@ def dice(gt_mask: np.ndarray, pred_mask: np.ndarray) -> float:
94123
:class:`float`:
95124
An estimate of Sørensen-Dice coefficient value.
96125
126+
Examples:
127+
>>> from tiatoolbox.utils.metrics import dice
128+
>>> # Generate two random example masks; replace with your own data
129+
>>> import numpy as np
130+
>>> np.random.seed(6)
131+
>>> gt_mask = (np.random.rand(256, 256) > 0.8).astype(np.uint8)
132+
>>> pred_mask = (np.random.rand(256, 256) > 0.8).astype(np.uint8)
133+
>>> # Example usage of dice
134+
>>> dice_score = dice(gt_mask, pred_mask)
135+
97136
"""
98137
if gt_mask.shape != pred_mask.shape:
99138
msg = f"{'Shape mismatch between the two masks.'}"

tiatoolbox/utils/transforms.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def imresize(
189189
img_channels = [
190190
cv2.resize(
191191
src=img[..., ch],
192-
dsize=output_size_array,
192+
dsize=(output_size_array[0], output_size_array[1]),
193193
interpolation=cv2_interpolation,
194194
)[
195195
...,
@@ -199,7 +199,11 @@ def imresize(
199199
]
200200
return np.concatenate(img_channels, axis=-1)
201201

202-
return cv2.resize(src=img, dsize=output_size_array, interpolation=cv2_interpolation)
202+
return cv2.resize(
203+
src=img,
204+
dsize=(output_size_array[0], output_size_array[1]),
205+
interpolation=cv2_interpolation,
206+
)
203207

204208

205209
def rgb2od(img: np.ndarray) -> np.ndarray:

tiatoolbox/utils/visualization.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,14 +475,15 @@ def overlay_prediction_contours(
475475
inst_colours_array = inst_colours_array.astype(np.uint8)
476476

477477
for idx, [_, inst_info] in enumerate(inst_dict.items()):
478-
inst_contour = inst_info["contour"]
478+
inst_contour: np.ndarray = inst_info["contour"]
479479
if "type" in inst_info and type_colours is not None:
480480
inst_colour = type_colours[inst_info["type"]][1]
481481
else:
482482
inst_colour = (inst_colours_array[idx]).tolist()
483+
contours: list[np.ndarray] = [np.array(inst_contour)]
483484
cv2.drawContours(
484485
overlay,
485-
[np.array(inst_contour)],
486+
contours,
486487
-1,
487488
inst_colour,
488489
line_thickness,
@@ -881,9 +882,10 @@ def render_line(
881882
top_left,
882883
scale,
883884
)
885+
pts: list[np.ndarray] = [np.array(cnt)]
884886
cv2.polylines(
885887
tile,
886-
[np.array(cnt)],
888+
pts,
887889
isClosed=False,
888890
color=col,
889891
thickness=3,

0 commit comments

Comments
 (0)