Skip to content

✨ Integrate SAM as an Interactive Segmentation Tool #918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 58 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
ca37400
Merge pull request #1 from TissueImageAnalytics/develop
mbasheer04 Jan 17, 2025
dab2693
Merge branch 'TissueImageAnalytics:develop' into develop
mbasheer04 Jan 24, 2025
bc6afda
Squashed commit of the following:
mbasheer04 Jan 24, 2025
5f8032d
Merge branch 'develop' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Jan 24, 2025
a543460
Integrating SAM into bokeh
mbasheer04 Jan 30, 2025
c620ca4
Added save file for GeneralSegmentor output
mbasheer04 Jan 30, 2025
9217b69
Added on-click prompt segementation to TIAViz
mbasheer04 Feb 19, 2025
60f4d5f
Added multi-prompt segmentation & bounding-box
mbasheer04 Feb 20, 2025
09a79fd
Added scores to masks
mbasheer04 Feb 22, 2025
1a4a76c
Attempting to add resolution/window-based segmentation
mbasheer04 Feb 26, 2025
b322539
Successfully implemented window-based segmentation
mbasheer04 Feb 27, 2025
35237c4
Fixing issues
mbasheer04 Mar 4, 2025
6ad56fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2025
2cb7d9c
Merge branch 'develop' into sam-viz
shaneahmed Mar 21, 2025
7b403c9
Preparing for PR
mbasheer04 Mar 20, 2025
a8f2938
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Mar 27, 2025
6e8b859
Restoring notebook changes
mbasheer04 Mar 27, 2025
51ecbe9
Pre-commit fixes
mbasheer04 Mar 28, 2025
fd3a2c2
Added centroid extraction
mbasheer04 Apr 3, 2025
315d2c8
Merge branch 'develop' into sam-viz
shaneahmed Apr 4, 2025
a5d5971
Merge branch 'develop' into sam-viz
adamshephard Apr 10, 2025
ec9b76c
Improving Engine
mbasheer04 Apr 11, 2025
5db9657
Removing eval files
mbasheer04 Apr 11, 2025
e193646
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Apr 11, 2025
effdadc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
f69dfb7
Cleaning up code
mbasheer04 Apr 11, 2025
c02f153
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 Apr 11, 2025
c1cb059
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
71eedb2
Added docstrings
mbasheer04 Apr 16, 2025
0878a66
Added SAM to requirements.txt
mbasheer04 Apr 17, 2025
65d72aa
Added pretrained model and fixed requirements
mbasheer04 Apr 24, 2025
be52c04
Adding engine unit tests
mbasheer04 Apr 25, 2025
2517844
Improving unit tests
mbasheer04 Apr 25, 2025
210ad2d
Improving unit tests
mbasheer04 Apr 29, 2025
821a848
Add store from #926
mbasheer04 Apr 29, 2025
5bf382b
Added annotation-based save for engine
mbasheer04 May 1, 2025
b5afc18
Merge branch 'develop' into sam-viz
mbasheer04 May 2, 2025
fbfa884
Fixing DeepSource issues
mbasheer04 May 2, 2025
37939c3
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 May 2, 2025
673c318
Debugging engine
mbasheer04 May 7, 2025
63d8012
Switched to transformers
mbasheer04 May 8, 2025
5c3656a
Finishing unit tests
mbasheer04 May 9, 2025
fd0539e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2025
01af916
Fixing deepsource
mbasheer04 May 9, 2025
311295d
Merge branch 'develop' into sam-viz
mbasheer04 May 9, 2025
c992fc8
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 May 9, 2025
4765906
Fixing build errors
mbasheer04 May 9, 2025
73150f2
Merge branch 'develop' into sam-viz
mbasheer04 May 9, 2025
7f122b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2025
cc4f883
Fixing more build errors
mbasheer04 May 9, 2025
9c0e1bd
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 May 9, 2025
53374aa
Fixing arch test
mbasheer04 May 9, 2025
0e13290
Removed whole image mask generation
mbasheer04 May 10, 2025
6190647
Added missing exception tests
mbasheer04 May 10, 2025
94e09f1
Fixed image encoding to work with non-square images
mbasheer04 May 11, 2025
c24016c
Fixing issues with tile mode
mbasheer04 May 12, 2025
c7d3dc6
Fixed TIAViz issues
mbasheer04 May 12, 2025
85b1148
Added mask again
mbasheer04 May 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 197 additions & 23 deletions examples/07-advanced-modeling.ipynb

Large diffs are not rendered by default.

1,985 changes: 1,985 additions & 0 deletions examples/sam-architecture.ipynb

Large diffs are not rendered by default.

Binary file added examples/slides/glands.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/slides/sample_wsi.svs
Binary file not shown.
114 changes: 114 additions & 0 deletions examples/tiaviz-test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"|2025-02-25|13:18:57.668| [WARNING] /dcs/22/u2208490/.conda/envs/tiatoolbox-dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"\n",
"|2025-02-25|13:19:02.454| [WARNING] /dcs/22/u2208490/.conda/envs/tiatoolbox-dev/lib/python3.11/site-packages/albumentations/__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.4' (you have '2.0.1'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
" check_for_updates()\n",
"\n"
]
}
],
"source": [
"from tiatoolbox.models.architecture.sam import SAM\n",
"from tiatoolbox.models.engine.general_segmentor import GeneralSegmentor\n",
"\n",
"# abc = GeneralSegmentor(model=SAM())\n",
"# prompts = SAMPrompts([[100,100]])\n",
"# output = abc.predict(\"slides/sample_wsi.svs\", prompts, \"cpu\", \"abcdefg\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|2025-02-25|13:19:07.574| [INFO] Loaded checkpoint sucessfully\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"model = GeneralSegmentor(SAM())\n",
"\n",
"glands = \"slides/glands.png\"\n",
"slides = \"slides/sample_wsi.svs\"\n",
"glands_prompts = [(370, 270), (300, 400)]\n",
"slides_prompts = [[5792, 6018]]\n",
"slides_location = (5745, 5972)\n",
"slides_size = (200, 114)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(114, 200, 3)\n",
"(114, 200, 3)\n",
"(12000, 12000, 3)\n",
"|2025-02-25|13:20:23.353| [INFO] For numpy array image, we assume (HxWxC) format\n",
"|2025-02-25|13:20:23.364| [INFO] Computing image embeddings for the provided image...\n",
"|2025-02-25|13:20:24.098| [INFO] Image embeddings computed.\n",
"[]\n",
"Prediction stored at /dcs/22/u2208490/cs310/tiatoolbox/examples/overlays\n"
]
}
],
"source": [
"prompts = model.create_prompts(slides_prompts)\n",
"output = model.predict(\n",
" slides, prompts, \"cpu\", \"overlays\", slides_location, slides_size, 0.5\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_path = model.to_annotation(output[0][1], output[0][2], Path(\"overlays/sample_wsi\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tiatoolbox-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
49 changes: 49 additions & 0 deletions tests/models/test_arch_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Unit test package for SAM."""

from pathlib import Path
from typing import Callable

import pytest

from tiatoolbox.models import SAM
from tiatoolbox.models.architecture.sam import SAMPrompts
from tiatoolbox.utils import imread

ON_GPU = False

# Test pretrained Model =============================


def test_functional_sam(
remote_sample: Callable,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test for SAM."""
# convert to pathlib Path to prevent wsireader complaint
tile_path = Path(remote_sample("patch-extraction-vf"))
img = imread(tile_path)

# test creation
_ = SAM()

# test inference
# create prompts

prompts1 = SAMPrompts(point_coords=[[64, 64]])
prompts2 = SAMPrompts(point_coords=[[64, 64]], point_labels=[1])
prompts3 = SAMPrompts(box_coords=[[64, 64, 128, 128]])
prompts4 = SAMPrompts(
point_coords=[[64, 64]], point_labels=[1], box_coords=[[64, 64, 128, 128]]
)

model = SAM()

# load pretrained weights
# pretrained = torch.load(weights_path, map_location="cpu")
# model.load_state_dict(pretrained)

_ = model.infer_batch(model, img, on_gpu=ON_GPU) # no prompts
_ = model.infer_batch(model, img, prompts=prompts1, on_gpu=ON_GPU)
_ = model.infer_batch(model, img, prompts=prompts2, on_gpu=ON_GPU)
_ = model.infer_batch(model, img, prompts=prompts3, on_gpu=ON_GPU)
_ = model.infer_batch(model, img, prompts=prompts4, on_gpu=ON_GPU)
Empty file.
4 changes: 4 additions & 0 deletions tiatoolbox/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from .architecture.mapde import MapDe
from .architecture.micronet import MicroNet
from .architecture.nuclick import NuClick
from .architecture.sam import SAM
from .architecture.sccnn import SCCNN
from .engine.general_segmentor import GeneralSegmentor
from .engine.multi_task_segmentor import MultiTaskSegmentor
from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor
from .engine.patch_predictor import (
Expand All @@ -25,7 +27,9 @@
)

__all__ = [
"SAM",
"SCCNN",
"GeneralSegmentor",
"HoVerNet",
"HoVerNetPlus",
"IDaRS",
Expand Down
144 changes: 144 additions & 0 deletions tiatoolbox/models/architecture/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Define SAM architecture."""

from __future__ import annotations

import numpy as np
import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2, build_sam2_hf
from sam2.sam2_image_predictor import SAM2ImagePredictor

from tiatoolbox.models.models_abc import ModelABC


class SAMPrompts:
"""Structure of prompts for SAM."""

def __init__(self, point_coords=None, point_labels=None, box_coords=None):
self.point_coords = None if point_coords == [] else point_coords
self.box_coords = None if box_coords == [] else box_coords
if point_coords and point_labels is None:
self.point_labels = [1] * len(point_coords)
else:
self.point_labels = point_labels


class SAM(ModelABC):
def __init__(
self: SAM,
model_hf_path: str = "facebook/sam2-hiera-tiny",
checkpoint_path: str = None,
model_cfg_path: str = None,
) -> None:
"""Initialize :class:`SAM`."""
super().__init__()
self.net_name = "SAM"

if checkpoint_path is None or model_cfg_path is None:
self.model = build_sam2_hf(model_hf_path, device="cpu")
else:
self.model = build_sam2(model_cfg_path, checkpoint_path)

self.predictor = SAM2ImagePredictor(self.model)
self.generator = SAM2AutomaticMaskGenerator(self.model)

def forward(self: SAM, image: np.ndarray, prompts: SAMPrompts = None) -> np.ndarray:
"""Torch method, this contains logic for using layers defined in init."""
mask = self.generate_mask(self, image, prompts)
return mask

@staticmethod
def infer_batch(
model: torch.nn.Module,
batch_data: list,
prompts: SAMPrompts = None,
*,
device,
) -> np.ndarray:
"""Run inference on an input batch.

Contains logic for forward operation as well as I/O aggregation.

Args:
model (nn.Module):
PyTorch defined model.
batch_data (np.ndarray):
A batch of data generated by
`torch.utils.data.DataLoader`.
on_gpu (bool):
Whether to run inference on a GPU.

"""
model.eval()
model = model.to(device)

if isinstance(
batch_data, torch.Tensor
): # Move the tensor to the CPU if it's a PyTorch tensor
batch_data = batch_data.to(device).type(torch.float32)
batch_data = batch_data.cpu().numpy()

with torch.inference_mode():
batch_data = model.preproc(batch_data)
masks, scores = model(batch_data, prompts)
masks = model.postproc(masks)
return masks, scores

@staticmethod
def encode_image(self, image: np.ndarray) -> np.ndarray:
"""Encodes the image for feature extraction."""
self.predictor.set_image(image)

@staticmethod
def generate_mask(self, features: np.ndarray, prompts: SAMPrompts) -> np.ndarray:
"""Generates a segmentation mask using SAM 2, optionally guided by a prompt."""
if prompts:
self.encode_image(self, features)
masks, scores, _ = self.predictor.predict(
point_coords=prompts.point_coords,
point_labels=prompts.point_labels,
box=prompts.box_coords,
multimask_output=False,
)
sorted_ind = np.argsort(scores)[::-1]
masks = np.array(masks[sorted_ind], dtype=np.uint8)
scores = np.around(scores[sorted_ind], 2)
else:
masks = self.generator.generate(features)
scores = np.array([mask["predicted_iou"] for mask in masks])
return masks, scores

@staticmethod
def load_weights(self, checkpoint_path: str) -> None:
"""Loads model weights from specified checkpoint."""
self.model.load_state_dict(
torch.load(checkpoint_path, map_location=self.device)
)

@staticmethod
def preproc(image: np.ndarray) -> np.ndarray:
"""Pre-processes images - Converts them into a format accepted by SAM (HWC) from NCHW."""
if isinstance(
image, torch.Tensor
): # Move the tensor to the CPU if it's a PyTorch tensor
image = image.cpu().numpy()

# Handle different shapes
if image.ndim == 4 and image.shape == (1, 512, 512, 3): # Case 1: (N, H, W, C)
image = np.squeeze(image, axis=0) # Remove batch dimension
elif image.ndim == 4 and image.shape == (
1,
3,
512,
512,
): # Case 2: (N, C, H, W)
image = np.squeeze(image, axis=0) # Remove batch dimension
image = np.transpose(image, (1, 2, 0)) # (C, H, W) -> (H, W, C)

image = image[:, :, :3] # Remove alpha channel
return image

@staticmethod
def postproc(image: np.ndarray) -> np.ndarray:
"""Define the post-processing of this class of model."""
return image
Loading
Loading