Skip to content

Commit ca13e7f

Browse files
authored
♻️ Update Changes from New Engine Design (#882)
- Add changes from New engine design #578. This will not only simplify the PR but also keep the main repo up to date. - Refactor `model_to` to `model_abc` - Instead of `on_gpu` use `device` as an input in line with `PyTorch`. - `infer_batch` uses `device` as an input instead of `on_gpu`
1 parent 32cae0b commit ca13e7f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+342
-290
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,6 @@ ENV/
116116

117117
# vim/vi generated
118118
*.swp
119+
120+
# output zarr generated
121+
*.zarr

tests/models/test_abc.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77
import pytest
88
import torch
9+
import torchvision.models as torch_models
910
from torch import nn
1011

11-
from tiatoolbox import rcParam
12+
from tiatoolbox import rcParam, utils
1213
from tiatoolbox.models.architecture import (
1314
fetch_pretrained_weights,
1415
get_pretrained_model,
1516
)
16-
from tiatoolbox.models.models_abc import ModelABC
17+
from tiatoolbox.models.models_abc import ModelABC, model_to
1718
from tiatoolbox.utils import env_detection as toolbox_env
1819

1920
if TYPE_CHECKING:
@@ -149,3 +150,18 @@ def test_model_abc() -> None:
149150
weights_path = fetch_pretrained_weights("alexnet-kather100k")
150151
with pytest.raises(RuntimeError, match=r".*loading state_dict*"):
151152
_ = model.load_weights_from_file(weights_path)
153+
154+
155+
def test_model_to() -> None:
156+
"""Test for placing model on device."""
157+
# Test on GPU
158+
# no GPU on GitHub Actions so this will crash
159+
if not utils.env_detection.has_gpu():
160+
model = torch_models.resnet18()
161+
with pytest.raises((AssertionError, RuntimeError)):
162+
_ = model_to(device="cuda", model=model)
163+
164+
# Test on CPU
165+
model = torch_models.resnet18()
166+
model = model_to(device="cpu", model=model)
167+
assert isinstance(model, nn.Module)

tests/models/test_arch_mapde.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_functionality(remote_sample: Callable) -> None:
4545
model = _load_mapde(name="mapde-conic")
4646
patch = model.preproc(patch)
4747
batch = torch.from_numpy(patch)[None]
48-
model = model.to(select_device(on_gpu=ON_GPU))
49-
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
48+
model = model.to()
49+
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
5050
output = model.postproc(output[0])
5151
assert np.all(output[0:2] == [[19, 171], [53, 89]])

tests/models/test_arch_micronet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_functionality(
3939
model = model.to(map_location)
4040
pretrained = torch.load(weights_path, map_location=map_location)
4141
model.load_state_dict(pretrained)
42-
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
42+
output = model.infer_batch(model, batch, device=map_location)
4343
output, _ = model.postproc(output[0])
4444
assert np.max(np.unique(output)) == 46
4545

tests/models/test_arch_nuclick.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tiatoolbox.models import NuClick
1111
from tiatoolbox.models.architecture import fetch_pretrained_weights
1212
from tiatoolbox.utils import imread
13+
from tiatoolbox.utils.misc import select_device
1314

1415
ON_GPU = False
1516

@@ -53,7 +54,7 @@ def test_functional_nuclick(
5354
model = NuClick(num_input_channels=5, num_output_channels=1)
5455
pretrained = torch.load(weights_path, map_location="cpu")
5556
model.load_state_dict(pretrained)
56-
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
57+
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
5758
postproc_masks = model.postproc(
5859
output,
5960
do_reconstruction=True,

tests/models/test_arch_sccnn.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
import numpy as np
66
import torch
77

8-
from tiatoolbox import utils
98
from tiatoolbox.models import SCCNN
109
from tiatoolbox.models.architecture import fetch_pretrained_weights
10+
from tiatoolbox.utils import env_detection
11+
from tiatoolbox.utils.misc import select_device
1112
from tiatoolbox.wsicore.wsireader import WSIReader
1213

1314

1415
def _load_sccnn(name: str) -> torch.nn.Module:
1516
"""Loads SCCNN model with specified weights."""
1617
model = SCCNN()
1718
weights_path = fetch_pretrained_weights(name)
18-
map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu())
19+
map_location = select_device(on_gpu=env_detection.has_gpu())
1920
pretrained = torch.load(weights_path, map_location=map_location)
2021
model.load_state_dict(pretrained)
2122

@@ -40,11 +41,19 @@ def test_functionality(remote_sample: Callable) -> None:
4041
)
4142
batch = torch.from_numpy(patch)[None]
4243
model = _load_sccnn(name="sccnn-crchisto")
43-
output = model.infer_batch(model, batch, on_gpu=False)
44+
output = model.infer_batch(
45+
model,
46+
batch,
47+
device=select_device(on_gpu=env_detection.has_gpu()),
48+
)
4449
output = model.postproc(output[0])
4550
assert np.all(output == [[8, 7]])
4651

4752
model = _load_sccnn(name="sccnn-conic")
48-
output = model.infer_batch(model, batch, on_gpu=False)
53+
output = model.infer_batch(
54+
model,
55+
batch,
56+
device=select_device(on_gpu=env_detection.has_gpu()),
57+
)
4958
output = model.postproc(output[0])
5059
assert np.all(output == [[7, 8]])

tests/models/test_arch_unet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from tiatoolbox.models.architecture import fetch_pretrained_weights
1111
from tiatoolbox.models.architecture.unet import UNetModel
12+
from tiatoolbox.utils.misc import select_device
1213
from tiatoolbox.wsicore.wsireader import WSIReader
1314

1415
ON_GPU = False
@@ -48,7 +49,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
4849
model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3])
4950
pretrained = torch.load(pretrained_weights, map_location="cpu")
5051
model.load_state_dict(pretrained)
51-
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
52+
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
5253
_ = output[0]
5354

5455
# run untrained network to test for architecture
@@ -60,4 +61,4 @@ def test_functional_unet(remote_sample: Callable) -> None:
6061
encoder_levels=[32, 64],
6162
skip_type="concat",
6263
)
63-
_ = model.infer_batch(model, batch, on_gpu=ON_GPU)
64+
_ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))

tests/models/test_arch_vanilla.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import torch
66

77
from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
8-
from tiatoolbox.utils.misc import model_to
8+
from tiatoolbox.models.models_abc import model_to
99

1010
ON_GPU = False
1111
RNG = np.random.default_rng() # Numpy Random Generator
12+
device = "cuda" if ON_GPU else "cpu"
1213

1314

1415
def test_functional() -> None:
@@ -43,8 +44,8 @@ def test_functional() -> None:
4344
try:
4445
for backbone in backbones:
4546
model = CNNModel(backbone, num_classes=1)
46-
model_ = model_to(on_gpu=ON_GPU, model=model)
47-
model.infer_batch(model_, samples, on_gpu=ON_GPU)
47+
model_ = model_to(device=device, model=model)
48+
model.infer_batch(model_, samples, device=device)
4849
except ValueError as exc:
4950
msg = f"Model {backbone} failed."
5051
raise AssertionError(msg) from exc
@@ -70,8 +71,8 @@ def test_timm_functional() -> None:
7071
try:
7172
for backbone in backbones:
7273
model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
73-
model_ = model_to(on_gpu=ON_GPU, model=model)
74-
model.infer_batch(model_, samples, on_gpu=ON_GPU)
74+
model_ = model_to(device=device, model=model)
75+
model.infer_batch(model_, samples, device=device)
7576
except ValueError as exc:
7677
msg = f"Model {backbone} failed."
7778
raise AssertionError(msg) from exc

tests/models/test_feature_extractor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
IOSegmentorConfig,
1515
)
1616
from tiatoolbox.utils import env_detection as toolbox_env
17+
from tiatoolbox.utils.misc import select_device
1718
from tiatoolbox.wsicore.wsireader import WSIReader
1819

1920
ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()
@@ -35,7 +36,7 @@ def test_engine(remote_sample: Callable, tmp_path: Path) -> None:
3536
output_list = extractor.predict(
3637
[mini_wsi_svs],
3738
mode="wsi",
38-
on_gpu=ON_GPU,
39+
device=select_device(on_gpu=ON_GPU),
3940
crash_on_exception=True,
4041
save_dir=save_dir,
4142
)
@@ -82,7 +83,7 @@ def test_full_inference(
8283
[mini_wsi_svs],
8384
mode="wsi",
8485
ioconfig=ioconfig,
85-
on_gpu=ON_GPU,
86+
device=select_device(on_gpu=ON_GPU),
8687
crash_on_exception=True,
8788
save_dir=save_dir,
8889
)

tests/models/test_hovernet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ResidualBlock,
1515
TFSamepaddingLayer,
1616
)
17+
from tiatoolbox.utils.misc import select_device
1718
from tiatoolbox.wsicore.wsireader import WSIReader
1819

1920

@@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None:
3435
weights_path = fetch_pretrained_weights("hovernet_fast-pannuke")
3536
pretrained = torch.load(weights_path)
3637
model.load_state_dict(pretrained)
37-
output = model.infer_batch(model, batch, on_gpu=False)
38+
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
3839
output = [v[0] for v in output]
3940
output = model.postproc(output)
4041
assert len(output[1]) > 0, "Must have some nuclei."
@@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None:
5152
weights_path = fetch_pretrained_weights("hovernet_fast-monusac")
5253
pretrained = torch.load(weights_path)
5354
model.load_state_dict(pretrained)
54-
output = model.infer_batch(model, batch, on_gpu=False)
55+
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
5556
output = [v[0] for v in output]
5657
output = model.postproc(output)
5758
assert len(output[1]) > 0, "Must have some nuclei."
@@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None:
6869
weights_path = fetch_pretrained_weights("hovernet_original-consep")
6970
pretrained = torch.load(weights_path)
7071
model.load_state_dict(pretrained)
71-
output = model.infer_batch(model, batch, on_gpu=False)
72+
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
7273
output = [v[0] for v in output]
7374
output = model.postproc(output)
7475
assert len(output[1]) > 0, "Must have some nuclei."
@@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None:
8586
weights_path = fetch_pretrained_weights("hovernet_original-kumar")
8687
pretrained = torch.load(weights_path)
8788
model.load_state_dict(pretrained)
88-
output = model.infer_batch(model, batch, on_gpu=False)
89+
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
8990
output = [v[0] for v in output]
9091
output = model.postproc(output)
9192
assert len(output[1]) > 0, "Must have some nuclei."

0 commit comments

Comments
 (0)