Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
b271f3e
FIX: Update for multi-GPU support in models_abc
adamshephard Apr 16, 2025
cc5407a
UPD: Update code
adamshephard Apr 16, 2025
e124cd8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2025
698f16a
Merge branch 'develop' into models-abc-multigpu
adamshephard Apr 25, 2025
ee25842
Merge branch 'develop' into models-abc-multigpu
adamshephard May 9, 2025
1f15307
Merge branch 'develop' into models-abc-multigpu
adamshephard May 11, 2025
e7b0822
FIX: Fix to work on other machines
adamshephard May 12, 2025
0615636
FIX: Fix to work on other machines
adamshephard May 12, 2025
b830c11
Merge branch 'models-abc-multigpu' of https://github.com/TissueImageA…
adamshephard May 12, 2025
a1d7357
FIX: Fix to work on other machines
adamshephard May 12, 2025
73440c2
Merge branch 'develop' into models-abc-multigpu
adamshephard May 27, 2025
b1d80dc
Merge branch 'models-abc-multigpu' of https://github.com/TissueImageA…
adamshephard Jun 2, 2025
41a74aa
Merge branch 'develop' into models-abc-multigpu
adamshephard Jun 2, 2025
56df269
Merge branch 'develop' into models-abc-multigpu
shaneahmed Jun 9, 2025
9b7b24e
Merge branch 'models-abc-multigpu' of https://github.com/TissueImageA…
adamshephard Jun 12, 2025
f914933
Merge branch 'develop' into models-abc-multigpu
adamshephard Jun 13, 2025
409498c
Merge branch 'models-abc-multigpu' of github.com:TissueImageAnalytics…
adamshephard Jun 13, 2025
f1d7cc4
UPD: Comment out cuda for coverage
adamshephard Jun 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions tests/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,48 @@ def test_full_inference(
# ! else the output values will not exactly be the same (still < 1.0e-4
# ! of epsilon though)
assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1


@pytest.mark.skipif(
    toolbox_env.running_on_ci() or not ON_GPU,
    reason="Local test on machine with GPU.",
)
def test_multi_gpu_feature_extraction(remote_sample: Callable, tmp_path: Path) -> None:
    """Run WSI feature extraction on a GPU machine and check the saved outputs.

    The test is skipped on CI and whenever ``ON_GPU`` is false. It runs a
    ``DeepFeatureExtractor`` wrapping a pretrained ``TimmBackbone`` ("UNI")
    over one sample slide, then verifies that the saved position and feature
    arrays are both two-dimensional.

    Args:
        remote_sample (Callable):
            Fixture that resolves a sample key to a local file path.
        tmp_path (Path):
            Pytest temporary directory; outputs go to ``tmp_path / "output"``.

    """
    out_dir = tmp_path / "output"
    slide_path = Path(remote_sample("wsi4_1k_1k_svs"))
    # The engine is given a clean target: drop any leftover output directory.
    shutil.rmtree(out_dir, ignore_errors=True)

    # Device choice is delegated to select_device.
    # NOTE(review): presumably the multi-GPU path is taken only when more than
    # one CUDA device is visible - confirm on the target machine.
    target_device = select_device(on_gpu=ON_GPU)

    # Non-overlapping 224x224 patches read and written at 0.5 mpp.
    ioconfig = IOSegmentorConfig(
        input_resolutions=[{"units": "mpp", "resolution": 0.5}],
        patch_input_shape=[224, 224],
        output_resolutions=[{"units": "mpp", "resolution": 0.5}],
        patch_output_shape=[224, 224],
        stride_shape=[224, 224],
    )

    backbone = TimmBackbone(backbone="UNI", pretrained=True)
    engine = DeepFeatureExtractor(
        model=backbone,
        auto_generate_mask=True,
        batch_size=32,
        num_loader_workers=4,
        num_postproc_workers=4,
    )

    results = engine.predict(
        [slide_path],
        mode="wsi",
        device=target_device,
        ioconfig=ioconfig,
        crash_on_exception=True,
        save_dir=out_dir,
    )

    # Second element of the first result is the root path of the saved arrays.
    wsi_0_root_path = results[0][1]
    positions = np.load(f"{wsi_0_root_path}.position.npy")
    features = np.load(f"{wsi_0_root_path}.features.0.npy")

    assert positions.ndim == 2
    assert features.ndim == 2
9 changes: 9 additions & 0 deletions tiatoolbox/models/engine/semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import joblib
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as torch_mp
import torch.utils.data as torch_data
import tqdm
Expand Down Expand Up @@ -1421,6 +1422,14 @@ def predict( # noqa: PLR0913
logger.warning("Unable to remove %s", self._cache_dir)

self._memory_cleanup()
from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

if (
device == "cuda"
and torch.cuda.device_count() > 1
and is_torch_compile_compatible()
): # pragma: no cover
dist.destroy_process_group()
Copy link
Preview

Copilot AI Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Destroying the process group without first verifying that it was initialized may raise an error if no group exists. Add an `if dist.is_initialized():` check before calling `dist.destroy_process_group()`.

Suggested change
dist.destroy_process_group()
if dist.is_initialized():
dist.destroy_process_group()

Copilot uses AI. Check for mistakes.


return self._outputs

Expand Down
29 changes: 26 additions & 3 deletions tiatoolbox/models/models_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable

import torch
import torch._dynamo
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001

Expand Down Expand Up @@ -51,12 +56,30 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
The model after being moved to specified device.

"""
if device != "cpu":
torch_device = torch.device(device)

# Use DDP if multiple GPUs and not on CPU
if (
device == "cuda"
and torch.cuda.device_count() > 1
and is_torch_compile_compatible()
): # pragma: no cover
# This assumes a single-process DDP setup for inference
model = model.to(torch_device)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group(backend="nccl", rank=0, world_size=1)
model = DistributedDataParallel(model, device_ids=[torch_device.index])

elif device != "cpu":
# DataParallel work only for cuda
model = torch.nn.DataParallel(model)
model = model.to(torch_device)

torch_device = torch.device(device)
return model.to(torch_device)
else:
model = model.to(torch_device)

return model


class ModelABC(ABC, torch.nn.Module):
Expand Down