Skip to content

✨ Integrate SAM as an Interactive Segmentation Tool #918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 58 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
ca37400
Merge pull request #1 from TissueImageAnalytics/develop
mbasheer04 Jan 17, 2025
dab2693
Merge branch 'TissueImageAnalytics:develop' into develop
mbasheer04 Jan 24, 2025
bc6afda
Squashed commit of the following:
mbasheer04 Jan 24, 2025
5f8032d
Merge branch 'develop' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Jan 24, 2025
a543460
Integrating SAM into bokeh
mbasheer04 Jan 30, 2025
c620ca4
Added save file for GeneralSegmentor output
mbasheer04 Jan 30, 2025
9217b69
Added on-click prompt segementation to TIAViz
mbasheer04 Feb 19, 2025
60f4d5f
Added multi-prompt segmentation & bounding-box
mbasheer04 Feb 20, 2025
09a79fd
Added scores to masks
mbasheer04 Feb 22, 2025
1a4a76c
Attempting to add resolution/window-based segmentation
mbasheer04 Feb 26, 2025
b322539
Successfully implemented window-based segmentation
mbasheer04 Feb 27, 2025
35237c4
Fixing issues
mbasheer04 Mar 4, 2025
6ad56fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2025
2cb7d9c
Merge branch 'develop' into sam-viz
shaneahmed Mar 21, 2025
7b403c9
Preparing for PR
mbasheer04 Mar 20, 2025
a8f2938
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Mar 27, 2025
6e8b859
Restoring notebook changes
mbasheer04 Mar 27, 2025
51ecbe9
Pre-commit fixes
mbasheer04 Mar 28, 2025
fd3a2c2
Added centroid extraction
mbasheer04 Apr 3, 2025
315d2c8
Merge branch 'develop' into sam-viz
shaneahmed Apr 4, 2025
a5d5971
Merge branch 'develop' into sam-viz
adamshephard Apr 10, 2025
ec9b76c
Improving Engine
mbasheer04 Apr 11, 2025
5db9657
Removing eval files
mbasheer04 Apr 11, 2025
e193646
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Apr 11, 2025
effdadc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
f69dfb7
Cleaning up code
mbasheer04 Apr 11, 2025
c02f153
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Apr 11, 2025
c1cb059
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
71eedb2
Added docstrings
mbasheer04 Apr 16, 2025
0878a66
Added SAM to requirements.txt
mbasheer04 Apr 17, 2025
65d72aa
Added pretrained model and fixed requirements
mbasheer04 Apr 24, 2025
be52c04
Adding engine unit tests
mbasheer04 Apr 25, 2025
2517844
Improving unit tests
mbasheer04 Apr 25, 2025
210ad2d
Improving unit tests
mbasheer04 Apr 29, 2025
821a848
Add store from #926
mbasheer04 Apr 29, 2025
5bf382b
Added annotation-based save for engine
mbasheer04 May 1, 2025
b5afc18
Merge branch 'develop' into sam-viz
mbasheer04 May 2, 2025
fbfa884
Fixing DeepSource issues
mbasheer04 May 2, 2025
37939c3
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 May 2, 2025
673c318
Debugging engine
mbasheer04 May 7, 2025
63d8012
Switched to transformers
mbasheer04 May 8, 2025
5c3656a
Finishing unit tests
mbasheer04 May 9, 2025
fd0539e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2025
01af916
Fixing deepsource
mbasheer04 May 9, 2025
311295d
Merge branch 'develop' into sam-viz
mbasheer04 May 9, 2025
c992fc8
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 May 9, 2025
4765906
Fixing build errors
mbasheer04 May 9, 2025
73150f2
Merge branch 'develop' into sam-viz
mbasheer04 May 9, 2025
7f122b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2025
cc4f883
Fixing more build errors
mbasheer04 May 9, 2025
9c0e1bd
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 May 9, 2025
53374aa
Fixing arch test
mbasheer04 May 9, 2025
0e13290
Removed whole image mask generation
mbasheer04 May 10, 2025
6190647
Added missing exception tests
mbasheer04 May 10, 2025
94e09f1
Fixed image encoding to work with non-square images
mbasheer04 May 11, 2025
c24016c
Fixing issues with tile mode
mbasheer04 May 12, 2025
c7d3dc6
Fixed TIAViz issues
mbasheer04 May 12, 2025
85b1148
Added mask again
mbasheer04 May 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ timm>=1.0.3
torch>=2.1.0
torchvision>=0.15.0
tqdm>=4.64.1
transformers>=4.51.1
umap-learn>=0.5.3
wsidicom>=0.18.0
zarr>=2.13.3, <3.0.0
65 changes: 65 additions & 0 deletions tests/models/test_arch_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Unit test package for SAM."""

from pathlib import Path
from typing import Callable

import numpy as np
import torch

from tiatoolbox.models.architecture.sam import SAM
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils import imread
from tiatoolbox.utils.misc import select_device

ON_GPU = toolbox_env.has_gpu()

# Test pretrained Model =============================


def test_functional_sam(
remote_sample: Callable,
) -> None:
"""Test for SAM."""
# convert to pathlib Path to prevent wsireader complaint
tile_path = Path(remote_sample("patch-extraction-vf"))
img = imread(tile_path)

# test creation

model = SAM(device=select_device(on_gpu=ON_GPU))

# create image patch and prompts
patch = img[63:191, 750:878, :]

points = [[[64, 64]]]
boxes = [[[64, 64, 128, 128]]]

# test preproc
tensor = torch.from_numpy(img)
patch = np.expand_dims(model.preproc(tensor), axis=0)
patch = model.preproc(patch)

# test inference

mask_output, score_output = model.infer_batch(
model, patch, points, device=select_device(on_gpu=ON_GPU)
)

assert mask_output is not None, "Output should not be None"
assert len(mask_output) > 0, "Output should have at least one element"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on points, it should have just one element from the one point? Maybe correct the wording of this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for below bounding boxes

assert len(score_output) > 0, "Output should have at least one element"

mask_output, score_output = model.infer_batch(
model, patch, box_coords=boxes, device=select_device(on_gpu=ON_GPU)
)

assert len(mask_output) > 0, "Output should have at least one element"
assert len(score_output) > 0, "Output should have at least one element"

mask_output, score_output = model.infer_batch(
model, patch, device=select_device(on_gpu=ON_GPU)
)

assert mask_output is not None, "Output should not be None"
assert len(mask_output) > 0, "Output should have at least one element"
assert len(score_output) > 0, "Output should have at least one element"
273 changes: 273 additions & 0 deletions tests/models/test_prompt_segmentor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
"""Unit test package for Prompt Segmentor."""

from __future__ import annotations

# ! The garbage collector
import multiprocessing
import shutil
from pathlib import Path
from typing import Callable

import numpy as np
import pytest

from tiatoolbox.models import PromptSegmentor
from tiatoolbox.models.architecture.sam import SAM
from tiatoolbox.models.engine.semantic_segmentor import (
IOSegmentorConfig,
)
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils import imwrite
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = toolbox_env.has_gpu()
BATCH_SIZE = 1 if not ON_GPU else 2
try:
NUM_LOADER_WORKERS = multiprocessing.cpu_count()
except NotImplementedError:
NUM_LOADER_WORKERS = 2


def test_functional_segmentor(
remote_sample: Callable,
tmp_path: Path,
) -> None:
"""Functional test for segmentor."""
save_dir = tmp_path / "dump"
# # convert to pathlib Path to prevent wsireader complaint
resolution = 2.0
mini_wsi_svs = Path(remote_sample("patch-extraction-vf"))
reader = WSIReader.open(mini_wsi_svs, resolution)
thumb = reader.slide_thumbnail(resolution=resolution, units="mpp")
thumb = thumb[63:191, 750:878, :]
mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg"
imwrite(mini_wsi_jpg, thumb)

# preemptive clean up
shutil.rmtree(save_dir, ignore_errors=True)

model = SAM()

# test engine setup

_ = PromptSegmentor(None, BATCH_SIZE, NUM_LOADER_WORKERS)

prompt_segmentor = PromptSegmentor(model, BATCH_SIZE, NUM_LOADER_WORKERS)

ioconfig = IOSegmentorConfig(
input_resolutions=[
{"units": "mpp", "resolution": 4.0},
],
output_resolutions=[{"units": "mpp", "resolution": 4.0}],
patch_input_shape=[512, 512],
patch_output_shape=[512, 512],
stride_shape=[512, 512],
)

# test inference

points = np.array([[[64, 64]], [[64, 64]]]) # Point on nuclei
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why two of the same points? Does it not make more sense to have two different points here?


# Run on tile mode with multi-prompt
# Test running with multiple images
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_jpg, mini_wsi_jpg],
mode="tile",
multi_prompt=True,
device=select_device(on_gpu=ON_GPU),
point_coords=points,
ioconfig=ioconfig,
crash_on_exception=False,
save_dir=save_dir,
)

pred_1 = np.load(output_list[0][1] + "/0.raw.0.npy")
pred_2 = np.load(output_list[1][1] + "/0.raw.0.npy")
assert len(output_list) == 2
assert np.sum(pred_1 - pred_2) == 0

points = np.array([[[64, 64], [100, 40], [100, 70]]]) # Points on nuclei
boxes = np.array([[[10, 10, 50, 50], [80, 80, 110, 110]]]) # Boxes on nuclei

# Run on tile mode with single-prompt
# Also tests boxes
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_jpg],
mode="tile",
multi_prompt=False,
device=select_device(on_gpu=ON_GPU),
point_coords=points,
box_coords=boxes,
ioconfig=ioconfig,
crash_on_exception=False,
save_dir=save_dir,
)

total_prompts = points.shape[1] + boxes.shape[1]
preds = [
np.load(output_list[0][1] + f"/{i}.raw.0.npy") for i in range(total_prompts)
]

assert len(output_list) == 1
assert len(preds) == total_prompts

# Generate mask
mask = np.zeros((thumb.shape[0], thumb.shape[1]), dtype=np.uint8)
mask[32:120, 32:120] = 1
mini_wsi_msk = f"{tmp_path}/mini_svs_mask.jpg"
imwrite(mini_wsi_msk, mask)

ioconfig = IOSegmentorConfig(
input_resolutions=[
{"units": "baseline", "resolution": 1.0},
],
output_resolutions=[{"units": "baseline", "resolution": 1.0}],
patch_input_shape=[512, 512],
patch_output_shape=[512, 512],
stride_shape=[512, 512],
save_resolution={"units": "baseline", "resolution": 1.0},
)

# Only point within mask should generate a segmentation
points = np.array([[[64, 64], [100, 40]]])
save_dir = tmp_path / "dump"

# Run on wsi mode with multi-prompt
# Also tests masks
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_jpg],
masks=[mini_wsi_msk],
mode="wsi",
multi_prompt=True,
device=select_device(on_gpu=ON_GPU),
point_coords=points,
ioconfig=ioconfig,
crash_on_exception=False,
save_dir=save_dir,
)

# Check if db exists
assert Path(output_list[0][1] + ".0.db").exists()

points = np.array([[[10, 30]]])
boxes = np.array([[[10, 10, 30, 30]]])
# Test no prompts within mask
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_jpg],
masks=[mini_wsi_msk],
mode="wsi",
multi_prompt=True,
device=select_device(on_gpu=ON_GPU),
point_coords=points,
box_coords=boxes,
ioconfig=ioconfig,
crash_on_exception=False,
save_dir=save_dir,
)
# Check if db exists
assert Path(output_list[0][1] + ".0.db").exists()

# Run on wsi mode with single-prompt
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_jpg],
mode="wsi",
multi_prompt=False,
device=select_device(on_gpu=ON_GPU),
point_coords=points,
ioconfig=ioconfig,
crash_on_exception=False,
save_dir=save_dir,
)

# Check if db exists
assert Path(output_list[0][1] + ".0.db").exists()


def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
"""Functional crash tests for segmentor."""
# # convert to pathlib Path to prevent wsireader complaint
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs"))
mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk"))

save_dir = tmp_path / "test_crash_segmentor"
prompt_segmentor = PromptSegmentor(batch_size=BATCH_SIZE)

# * test basic crash
with pytest.raises(TypeError, match=r".*`mask_reader`.*"):
prompt_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"]))
with pytest.raises(TypeError, match=r".*`mask_reader`.*"):
prompt_segmentor.get_mask_bounds(mini_wsi_msk)
with pytest.raises(TypeError, match=r".*mask_reader.*"):
prompt_segmentor.clip_coordinates(mini_wsi_msk, np.array(["a", "b", "c"]))

with pytest.raises(ValueError, match=r".*ndarray.*integer.*"):
prompt_segmentor.filter_coordinates(
WSIReader.open(mini_wsi_msk),
np.array([1.0, 2.0]),
)
with pytest.raises(ValueError, match=r".*ndarray.*integer.*"):
prompt_segmentor.clip_coordinates(
WSIReader.open(mini_wsi_msk),
np.array([1.0, 2.0]),
)
prompt_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True)
with pytest.raises(ValueError, match=r".*must be a valid file path.*"):
prompt_segmentor.get_reader(
mini_wsi_msk,
"not_exist",
"wsi",
auto_get_mask=True,
)

shutil.rmtree(save_dir, ignore_errors=True) # default output dir test
with pytest.raises(ValueError, match=r".*valid mode.*"):
prompt_segmentor.predict([], mode="abc")

crash_segmentor = PromptSegmentor()

# * test crash segmentor
def _predict_one_wsi(
*args: dict,
**kwargs: dict,
) -> tuple[WSIReader, str]:
"""Override the predict function to test crash segmentor."""
msg = f"Test crash segmentor:{args} {kwargs}"
raise RuntimeError(msg)

crash_segmentor._predict_one_wsi = _predict_one_wsi
shutil.rmtree(save_dir, ignore_errors=True)
with pytest.raises(
RuntimeError,
match=r"Test crash segmentor:\(.*\) \{.*\}",
):
crash_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
multi_prompt=True,
device=select_device(on_gpu=ON_GPU),
patch_input_shape=[512, 512],
resolution=2.0,
units="mpp",
crash_on_exception=True,
save_dir=save_dir,
)

# test ignore crash
shutil.rmtree(save_dir, ignore_errors=True)
crash_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
multi_prompt=True,
device=select_device(on_gpu=ON_GPU),
patch_input_shape=[512, 512],
resolution=2.0,
units="mpp",
crash_on_exception=False,
save_dir=save_dir,
)
44 changes: 44 additions & 0 deletions tests/test_app_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,50 @@ def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> No
assert len(main.UI["type_column"].children) == 1


def test_sam_segment(doc: Document, data_path: pytest.TempPathFactory) -> None:
"""Test running SAM on points and a box."""
slide_select = doc.get_model_by_name("slide_select0")
slide_select.value = [data_path["slide2"].name]
run_button = doc.get_model_by_name("to_model0")
assert len(main.UI["color_column"].children) == 0
slide_select.value = [data_path["slide1"].name]
# set up a box selection
main.UI["box_source"].data = {
"x": [1200],
"y": [-2000],
"width": [400],
"height": [400],
}

# select SAM model and run it on box
model_select = doc.get_model_by_name("model_drop0")
model_select.value = "SAM"

click = ButtonClick(run_button)
run_button._trigger_event(click)
assert len(main.UI["color_column"].children) > 0

# test save functionality
save_button = doc.get_model_by_name("save_button0")
click = ButtonClick(save_button)
save_button._trigger_event(click)
saved_path = (
data_path["base_path"] / "overlays" / (data_path["slide1"].stem + ".db")
)
assert saved_path.exists()

# load an overlay with different types
cprop_select = doc.get_model_by_name("cprop0")
cprop_select.value = ["prob"]
layer_drop = doc.get_model_by_name("layer_drop0")
click = MenuItemClick(layer_drop, str(data_path["dat_anns"]))
layer_drop._trigger_event(click)
assert main.UI["vstate"].types == ["annotation"]
# check the per-type ui controls have been updated
assert len(main.UI["color_column"].children) == 1
assert len(main.UI["type_column"].children) == 1


def test_alpha_sliders(doc: Document) -> None:
"""Test sliders for adjusting slide and overlay alpha."""
slide_alpha = doc.get_model_by_name("slide_alpha0")
Expand Down
Loading
Loading