✨ Integrate SAM as an Interactive Segmentation Tool #918


Draft
wants to merge 32 commits into base: develop
Commits (32)
ca37400
Merge pull request #1 from TissueImageAnalytics/develop
mbasheer04 Jan 17, 2025
dab2693
Merge branch 'TissueImageAnalytics:develop' into develop
mbasheer04 Jan 24, 2025
bc6afda
Squashed commit of the following:
mbasheer04 Jan 24, 2025
5f8032d
Merge branch 'develop' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Jan 24, 2025
a543460
Integrating SAM into bokeh
mbasheer04 Jan 30, 2025
c620ca4
Added save file for GeneralSegmentor output
mbasheer04 Jan 30, 2025
9217b69
Added on-click prompt segmentation to TIAViz
mbasheer04 Feb 19, 2025
60f4d5f
Added multi-prompt segmentation & bounding-box
mbasheer04 Feb 20, 2025
09a79fd
Added scores to masks
mbasheer04 Feb 22, 2025
1a4a76c
Attempting to add resolution/window-based segmentation
mbasheer04 Feb 26, 2025
b322539
Successfully implemented window-based segmentation
mbasheer04 Feb 27, 2025
35237c4
Fixing issues
mbasheer04 Mar 4, 2025
6ad56fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2025
2cb7d9c
Merge branch 'develop' into sam-viz
shaneahmed Mar 21, 2025
7b403c9
Preparing for PR
mbasheer04 Mar 20, 2025
a8f2938
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Mar 27, 2025
6e8b859
Restoring notebook changes
mbasheer04 Mar 27, 2025
51ecbe9
Pre-commit fixes
mbasheer04 Mar 28, 2025
fd3a2c2
Added centroid extraction
mbasheer04 Apr 3, 2025
315d2c8
Merge branch 'develop' into sam-viz
shaneahmed Apr 4, 2025
a5d5971
Merge branch 'develop' into sam-viz
adamshephard Apr 10, 2025
ec9b76c
Improving Engine
mbasheer04 Apr 11, 2025
5db9657
Removing eval files
mbasheer04 Apr 11, 2025
e193646
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Apr 11, 2025
effdadc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
f69dfb7
Cleaning up code
mbasheer04 Apr 11, 2025
c02f153
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Apr 11, 2025
c1cb059
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
71eedb2
Added docstrings
mbasheer04 Apr 16, 2025
0878a66
Added SAM to requirements.txt
mbasheer04 Apr 17, 2025
65d72aa
Added pretrained model and fixed requirements
mbasheer04 Apr 24, 2025
be52c04
Adding engine unit tests
mbasheer04 Apr 25, 2025
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -25,6 +25,7 @@ requests>=2.28.1
scikit-image>=0.20
scikit-learn>=1.2.0
scipy>=1.8
segment-anything-py>=1.0.0
shapely>=2.0.0
SimpleITK>=2.2.1
sphinx>=5.3.0
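The new segment-anything-py dependency packages Meta's Segment Anything Model, which the sam.SAM architecture added in this PR wraps. As a point of reference, a minimal sketch of the upstream prompt-based API is shown below; it assumes segment-anything-py exposes the same interface as facebookresearch/segment-anything, and the placeholder image stands in for a real tile. The checkpoint file is the one the new pretrained_model.yaml entry downloads.

import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# ViT-B checkpoint; same file referenced by the segment_anything-base entry below.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # placeholder RGB tile
predictor.set_image(image)

# One positive point prompt; multimask_output returns several candidate masks
# together with their predicted quality scores.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[64, 64]]),
    point_labels=np.array([1]),
    multimask_output=True,
)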
64 changes: 64 additions & 0 deletions tests/models/test_arch_sam.py
@@ -0,0 +1,64 @@
"""Unit test package for SAM."""

from pathlib import Path
from typing import Callable

import numpy as np

from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.sam import SAM
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils import imread
from tiatoolbox.utils.misc import select_device

ON_GPU = toolbox_env.has_gpu()

# Test pretrained Model =============================


def test_functional_sam(
remote_sample: Callable,
) -> None:
"""Test for SAM."""
# convert to pathlib Path to prevent wsireader complaint
tile_path = Path(remote_sample("patch-extraction-vf"))
img = imread(tile_path)

weights_path = fetch_pretrained_weights("segment_anything-base")

# test creation

model = SAM()
model.load_state_dict(weights_path)

# test inference

# create image patch and prompts
patch = np.expand_dims(img[63:191, 750:878, :], axis=0)
patch = model.preproc(patch) # pre-process the image

points = np.array([[[64, 64]]], dtype=np.int32)
boxes = np.array([[64, 64, 128, 128]], dtype=np.int32)

mask_output, score_output = model.infer_batch(
model, patch, points, device=select_device(on_gpu=ON_GPU)
)

assert mask_output is not None, "Output should not be None"
assert len(mask_output) > 0, "Output should have at least one element"
assert len(score_output) > 0, "Output should have at least one element"

mask_output, score_output = model.infer_batch(
model, patch, box_coords=boxes, device=select_device(on_gpu=ON_GPU)
)

assert len(mask_output) > 0, "Output should have at least one element"
assert len(score_output) > 0, "Output should have at least one element"

mask_output, score_output = model.infer_batch(
model, patch, device=select_device(on_gpu=ON_GPU)
)

assert mask_output is not None, "Output should not be None"
assert len(mask_output) > 0, "Output should have at least one element"
assert len(score_output) > 0, "Output should have at least one element"
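Outside pytest, the architecture-level flow exercised above condenses to a few lines. This is a sketch based only on calls that appear in the test: prompts keep the shapes used there (points as N x P x 2, boxes as N x 4), the patch path is a placeholder, and infer_batch is assumed to remain callable as shown, with the model passed explicitly.

import numpy as np

from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.sam import SAM
from tiatoolbox.utils import imread

# Build the wrapper and load the registered ViT-B weights.
model = SAM()
model.load_state_dict(fetch_pretrained_weights("segment_anything-base"))

# "patch.png" is a placeholder for any small RGB tile.
patch = model.preproc(np.expand_dims(imread("patch.png"), axis=0))

points = np.array([[[64, 64]]], dtype=np.int32)  # one point prompt per image
boxes = np.array([[64, 64, 128, 128]], dtype=np.int32)  # box prompt, same layout as the test

# Point-prompted and box-prompted inference.
masks, scores = model.infer_batch(model, patch, points, device="cpu")
masks, scores = model.infer_batch(model, patch, box_coords=boxes, device="cpu")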
133 changes: 133 additions & 0 deletions tests/models/test_prompt_segmentor.py
@@ -0,0 +1,133 @@
"""Unit test package for Prompt Segmentor."""

from __future__ import annotations

import multiprocessing
import shutil
from pathlib import Path
from typing import Callable

import numpy as np

from tiatoolbox.models import PromptSegmentor
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.sam import SAM
from tiatoolbox.models.engine.semantic_segmentor import (
IOSegmentorConfig,
)
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils import imwrite
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = toolbox_env.has_gpu()
# The value is based on 2 TitanXP each with 12GB
BATCH_SIZE = 1 if not ON_GPU else 16
try:
NUM_LOADER_WORKERS = multiprocessing.cpu_count()
except NotImplementedError:
NUM_LOADER_WORKERS = 2


def test_functional_segmentor(
remote_sample: Callable,
tmp_path: Path,
) -> None:
"""Functional test for segmentor."""
save_dir = tmp_path / "dump"
    # convert to pathlib Path to prevent wsireader complaint
resolution = 2.0
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
reader = WSIReader.open(mini_wsi_svs)
thumb = reader.slide_thumbnail(resolution=resolution, units="mpp")
mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg"
imwrite(mini_wsi_jpg, thumb)
mini_wsi_msk = f"{tmp_path}/mini_mask.jpg"
imwrite(mini_wsi_msk, (thumb > 0).astype(np.uint8))

# preemptive clean up
shutil.rmtree(save_dir, ignore_errors=True)

model = SAM()
model.load_state_dict(fetch_pretrained_weights("segment_anything-base"))

prompt_segmentor = PromptSegmentor(model, BATCH_SIZE, NUM_LOADER_WORKERS)

ioconfig = IOSegmentorConfig(
input_resolutions=[
{"units": "mpp", "resolution": 2.0},
{"units": "mpp", "resolution": 1.0},
],
output_resolutions=[{"units": "mpp", "resolution": 2.0}],
patch_input_shape=[512, 512],
patch_output_shape=[512, 512],
)

points = np.array([[[64, 64], [100, 100], [64, 100], [100, 64]]]) # Random points

# Run on tile mode with multi-prompt
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_jpg],
mode="tile",
multi_prompt=True,
device=select_device(on_gpu=ON_GPU),
point_coords=points,
ioconfig=ioconfig,
crash_on_exception=False,
save_dir=save_dir,
)

assert len(output_list) == 1

# Run on tile mode with single-prompt
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_jpg],
mode="tile",
multi_prompt=False,
device=select_device(on_gpu=ON_GPU),
        point_coords=points,
ioconfig=ioconfig,
crash_on_exception=False,
save_dir=save_dir,
)

pred_1 = np.load(output_list[0][1] + ".raw.0.npy")
pred_2 = np.load(output_list[1][1] + ".raw.0.npy")
assert len(output_list) == 4
assert np.sum(pred_1 - pred_2) == 0
# due to overlapping merge and division, will not be
# exactly 1, but should be approximately so
assert np.sum((pred_1 - 1) > 1.0e-6) == 0
shutil.rmtree(save_dir, ignore_errors=True)

# * test running with mask and svs
# * also test merging prediction at designated resolution
ioconfig = IOSegmentorConfig(
input_resolutions=[{"units": "mpp", "resolution": resolution}],
output_resolutions=[{"units": "mpp", "resolution": resolution}],
save_resolution={"units": "mpp", "resolution": resolution},
patch_input_shape=[512, 512],
patch_output_shape=[256, 256],
stride_shape=[512, 512],
)
shutil.rmtree(save_dir, ignore_errors=True)
output_list = prompt_segmentor.predict(
[mini_wsi_svs],
masks=[mini_wsi_msk],
mode="wsi",
device=select_device(on_gpu=ON_GPU),
ioconfig=ioconfig,
crash_on_exception=True,
save_dir=f"{save_dir}/raw/",
)
reader = WSIReader.open(mini_wsi_svs)
expected_shape = reader.slide_dimensions(**ioconfig.save_resolution)
expected_shape = np.array(expected_shape)[::-1] # to YX
pred_1 = np.load(output_list[0][1] + ".raw.0.npy")
saved_shape = np.array(pred_1.shape[:2])
assert np.sum(expected_shape - saved_shape) == 0
assert np.sum((pred_1 - 1) > 1.0e-6) == 0
shutil.rmtree(save_dir, ignore_errors=True)
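At the engine level, the same workflow on a single tile reduces to the sketch below. It relies only on the PromptSegmentor.predict signature used in the test (mode, multi_prompt, point_coords, ioconfig, save_dir); the image path and output directory are placeholders, and the positional constructor arguments mirror the test (model, batch size, loader workers).

import numpy as np

from tiatoolbox.models import SAM, PromptSegmentor
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig

model = SAM()
model.load_state_dict(fetch_pretrained_weights("segment_anything-base"))
prompt_segmentor = PromptSegmentor(model, 1, 2)  # model, batch size, loader workers

# Resolutions and patch shapes as declared for segment_anything-base.
ioconfig = IOSegmentorConfig(
    input_resolutions=[{"units": "baseline", "resolution": 1.0}],
    output_resolutions=[{"units": "baseline", "resolution": 1.0}],
    patch_input_shape=[1024, 1024],
    patch_output_shape=[1024, 1024],
)

# One click prompt per input image, in input-image coordinates.
points = np.array([[[64, 64]]])

output_list = prompt_segmentor.predict(
    ["tile.jpg"],  # placeholder image path
    mode="tile",
    multi_prompt=True,
    point_coords=points,
    ioconfig=ioconfig,
    device="cpu",
    save_dir="sam_output",  # placeholder output directory
)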
50 changes: 50 additions & 0 deletions tests/test_app_bokeh.py
@@ -473,6 +473,56 @@ def test_hovernet_on_box(doc: Document, data_path: pytest.TempPathFactory) -> No
assert len(main.UI["type_column"].children) == 1


def test_sam_segment(doc: Document, data_path: pytest.TempPathFactory) -> None:
"""Test running hovernet on a box."""
slide_select = doc.get_model_by_name("slide_select0")
slide_select.value = [data_path["slide2"].name]
run_button = doc.get_model_by_name("to_model0")
assert len(main.UI["color_column"].children) == 0
slide_select.value = [data_path["slide1"].name]
# set up a box selection
main.UI["box_source"].data = {
"x": [1200],
"y": [-2000],
"width": [400],
"height": [400],
}

# select hovernet model and run it on box
model_select = doc.get_model_by_name("model_drop0")
model_select.value = "hovernet"

click = ButtonClick(run_button)
run_button._trigger_event(click)
im = get_tile("overlay", 4, 8, 4, show=False)
_, num = label(np.any(im[:, :, :3], axis=2))
# check there are multiple cells being detected
assert len(main.UI["color_column"].children) > 3
assert num > 10

# test save functionality
save_button = doc.get_model_by_name("save_button0")
click = ButtonClick(save_button)
save_button._trigger_event(click)
saved_path = (
data_path["base_path"]
/ "overlays"
/ (data_path["slide1"].stem + "_saved_anns.db")
)
assert saved_path.exists()

# load an overlay with different types
cprop_select = doc.get_model_by_name("cprop0")
cprop_select.value = ["prob"]
layer_drop = doc.get_model_by_name("layer_drop0")
click = MenuItemClick(layer_drop, str(data_path["dat_anns"]))
layer_drop._trigger_event(click)
assert main.UI["vstate"].types == ["annotation"]
# check the per-type ui controls have been updated
assert len(main.UI["color_column"].children) == 1
assert len(main.UI["type_column"].children) == 1


def test_alpha_sliders(doc: Document) -> None:
"""Test sliders for adjusting slide and overlay alpha."""
slide_alpha = doc.get_model_by_name("slide_alpha0")
18 changes: 18 additions & 0 deletions tiatoolbox/data/pretrained_model.yaml
@@ -934,3 +934,21 @@ nuclick_light-pannuke:
patch_input_shape: [128, 128]
patch_output_shape: [128, 128]
save_resolution: {'units': 'baseline', 'resolution': 1.0}

segment_anything-base:
url: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
architecture:
class: sam.SAM
kwargs:
model_type: "vit_b"
checkpoint_path: "sam_vit_b_01ec64.pth"
ioconfig:
class: semantic_segmentor.IOSegmentorConfig
kwargs:
input_resolutions:
- {'units': 'baseline', 'resolution': 1.0}
output_resolutions:
- {'units': 'baseline', 'resolution': 1.0}
patch_input_shape: [1024, 1024]
patch_output_shape: [1024, 1024]
save_resolution: {'units': 'baseline', 'resolution': 1.0}
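This registry entry makes the checkpoint resolvable by name. A short sketch, assuming the existing get_pretrained_model helper resolves the new entry like other registered architectures and returns the instantiated model together with the IOSegmentorConfig declared above:

from tiatoolbox.models.architecture import get_pretrained_model

# Downloads sam_vit_b_01ec64.pth on first use and builds sam.SAM(model_type="vit_b").
model, ioconfig = get_pretrained_model("segment_anything-base")
print(ioconfig.patch_input_shape)  # [1024, 1024], as declared in the yaml entry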
4 changes: 4 additions & 0 deletions tiatoolbox/models/__init__.py
@@ -8,6 +8,7 @@
from .architecture.mapde import MapDe
from .architecture.micronet import MicroNet
from .architecture.nuclick import NuClick
from .architecture.sam import SAM
from .architecture.sccnn import SCCNN
from .engine.multi_task_segmentor import MultiTaskSegmentor
from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor
@@ -17,6 +18,7 @@
PatchPredictor,
WSIPatchDataset,
)
from .engine.prompt_segmentor import PromptSegmentor
from .engine.semantic_segmentor import (
DeepFeatureExtractor,
IOSegmentorConfig,
@@ -25,6 +27,7 @@
)

__all__ = [
"SAM",
"SCCNN",
"HoVerNet",
"HoVerNetPlus",
@@ -35,5 +38,6 @@
"NuClick",
"NucleusInstanceSegmentor",
"PatchPredictor",
"PromptSegmentor",
"SemanticSegmentor",
]
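With these exports, both new components are importable from the top-level models namespace, e.g.:

from tiatoolbox.models import SAM, PromptSegmentor  # new public exports in this PR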