-
Notifications
You must be signed in to change notification settings - Fork 91
✨ Integrate SAM as an Interactive Segmentation Tool #918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mbasheer04
wants to merge
58
commits into
TissueImageAnalytics:develop
Choose a base branch
from
mbasheer04:sam-viz
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
58 commits
Select commit
Hold shift + click to select a range
ca37400
Merge pull request #1 from TissueImageAnalytics/develop
mbasheer04 dab2693
Merge branch 'TissueImageAnalytics:develop' into develop
mbasheer04 bc6afda
Squashed commit of the following:
mbasheer04 5f8032d
Merge branch 'develop' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 a543460
Integrating SAM into bokeh
mbasheer04 c620ca4
Added save file for GeneralSegmentor output
mbasheer04 9217b69
Added on-click prompt segementation to TIAViz
mbasheer04 60f4d5f
Added multi-prompt segmentation & bounding-box
mbasheer04 09a79fd
Added scores to masks
mbasheer04 1a4a76c
Attempting to add resolution/window-based segmentation
mbasheer04 b322539
Successfully implemented window-based segmentation
mbasheer04 35237c4
Fixing issues
mbasheer04 6ad56fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2cb7d9c
Merge branch 'develop' into sam-viz
shaneahmed 7b403c9
Preparing for PR
mbasheer04 a8f2938
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 6e8b859
Restoring notebook changes
mbasheer04 51ecbe9
Pre-commit fixes
mbasheer04 fd3a2c2
Added centroid extraction
mbasheer04 315d2c8
Merge branch 'develop' into sam-viz
shaneahmed a5d5971
Merge branch 'develop' into sam-viz
adamshephard ec9b76c
Improving Engine
mbasheer04 5db9657
Removing eval files
mbasheer04 e193646
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 effdadc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f69dfb7
Cleaning up code
mbasheer04 c02f153
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 c1cb059
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 71eedb2
Added docstrings
mbasheer04 0878a66
Added SAM to requirements.txt
mbasheer04 65d72aa
Added pretrained model and fixed requirements
mbasheer04 be52c04
Adding engine unit tests
mbasheer04 2517844
Improving unit tests
mbasheer04 210ad2d
Improving unit tests
mbasheer04 821a848
Add store from #926
mbasheer04 5bf382b
Added annotation-based save for engine
mbasheer04 b5afc18
Merge branch 'develop' into sam-viz
mbasheer04 fbfa884
Fixing DeepSource issues
mbasheer04 37939c3
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 673c318
Debugging engine
mbasheer04 63d8012
Switched to transformers
mbasheer04 5c3656a
Finishing unit tests
mbasheer04 fd0539e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 01af916
Fixing deepsource
mbasheer04 311295d
Merge branch 'develop' into sam-viz
mbasheer04 c992fc8
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 4765906
Fixing build errors
mbasheer04 73150f2
Merge branch 'develop' into sam-viz
mbasheer04 7f122b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cc4f883
Fixing more build errors
mbasheer04 9c0e1bd
Merge branch 'sam-viz' of https://github.com/mbasheer04/tiatoolbox in…
mbasheer04 53374aa
Fixing arch test
mbasheer04 0e13290
Removed whole image mask generation
mbasheer04 6190647
Added missing exception tests
mbasheer04 94e09f1
Fixed image encoding to work with non-square images
mbasheer04 c24016c
Fixing issues with tile mode
mbasheer04 c7d3dc6
Fixed TIAViz issues
mbasheer04 85b1148
Added mask again
mbasheer04 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Unit test package for SAM.""" | ||
|
||
from pathlib import Path | ||
from typing import Callable | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from tiatoolbox.models.architecture.sam import SAM | ||
from tiatoolbox.utils import env_detection as toolbox_env | ||
from tiatoolbox.utils import imread | ||
from tiatoolbox.utils.misc import select_device | ||
|
||
ON_GPU = toolbox_env.has_gpu() | ||
|
||
# Test pretrained Model ============================= | ||
|
||
|
||
def test_functional_sam( | ||
remote_sample: Callable, | ||
) -> None: | ||
"""Test for SAM.""" | ||
# convert to pathlib Path to prevent wsireader complaint | ||
tile_path = Path(remote_sample("patch-extraction-vf")) | ||
img = imread(tile_path) | ||
|
||
# test creation | ||
|
||
model = SAM(device=select_device(on_gpu=ON_GPU)) | ||
|
||
# create image patch and prompts | ||
patch = img[63:191, 750:878, :] | ||
|
||
points = [[[64, 64]]] | ||
boxes = [[[64, 64, 128, 128]]] | ||
|
||
# test preproc | ||
tensor = torch.from_numpy(img) | ||
patch = np.expand_dims(model.preproc(tensor), axis=0) | ||
patch = model.preproc(patch) | ||
|
||
# test inference | ||
|
||
mask_output, score_output = model.infer_batch( | ||
model, patch, points, device=select_device(on_gpu=ON_GPU) | ||
) | ||
|
||
assert mask_output is not None, "Output should not be None" | ||
assert len(mask_output) > 0, "Output should have at least one element" | ||
assert len(score_output) > 0, "Output should have at least one element" | ||
|
||
mask_output, score_output = model.infer_batch( | ||
model, patch, box_coords=boxes, device=select_device(on_gpu=ON_GPU) | ||
) | ||
|
||
assert len(mask_output) > 0, "Output should have at least one element" | ||
assert len(score_output) > 0, "Output should have at least one element" | ||
|
||
mask_output, score_output = model.infer_batch( | ||
model, patch, device=select_device(on_gpu=ON_GPU) | ||
) | ||
|
||
assert mask_output is not None, "Output should not be None" | ||
assert len(mask_output) > 0, "Output should have at least one element" | ||
assert len(score_output) > 0, "Output should have at least one element" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
"""Unit test package for Prompt Segmentor.""" | ||
|
||
from __future__ import annotations | ||
|
||
# ! The garbage collector | ||
import multiprocessing | ||
import shutil | ||
from pathlib import Path | ||
from typing import Callable | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from tiatoolbox.models import PromptSegmentor | ||
from tiatoolbox.models.architecture.sam import SAM | ||
from tiatoolbox.models.engine.semantic_segmentor import ( | ||
IOSegmentorConfig, | ||
) | ||
from tiatoolbox.utils import env_detection as toolbox_env | ||
from tiatoolbox.utils import imwrite | ||
from tiatoolbox.utils.misc import select_device | ||
from tiatoolbox.wsicore.wsireader import WSIReader | ||
|
||
ON_GPU = toolbox_env.has_gpu() | ||
BATCH_SIZE = 1 if not ON_GPU else 2 | ||
try: | ||
NUM_LOADER_WORKERS = multiprocessing.cpu_count() | ||
except NotImplementedError: | ||
NUM_LOADER_WORKERS = 2 | ||
|
||
|
||
def test_functional_segmentor( | ||
remote_sample: Callable, | ||
tmp_path: Path, | ||
) -> None: | ||
"""Functional test for segmentor.""" | ||
save_dir = tmp_path / "dump" | ||
# # convert to pathlib Path to prevent wsireader complaint | ||
resolution = 2.0 | ||
mini_wsi_svs = Path(remote_sample("patch-extraction-vf")) | ||
reader = WSIReader.open(mini_wsi_svs, resolution) | ||
thumb = reader.slide_thumbnail(resolution=resolution, units="mpp") | ||
thumb = thumb[63:191, 750:878, :] | ||
mini_wsi_jpg = f"{tmp_path}/mini_svs.jpg" | ||
imwrite(mini_wsi_jpg, thumb) | ||
|
||
# preemptive clean up | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
|
||
model = SAM() | ||
|
||
# test engine setup | ||
|
||
_ = PromptSegmentor(None, BATCH_SIZE, NUM_LOADER_WORKERS) | ||
|
||
prompt_segmentor = PromptSegmentor(model, BATCH_SIZE, NUM_LOADER_WORKERS) | ||
|
||
ioconfig = IOSegmentorConfig( | ||
input_resolutions=[ | ||
{"units": "mpp", "resolution": 4.0}, | ||
], | ||
output_resolutions=[{"units": "mpp", "resolution": 4.0}], | ||
patch_input_shape=[512, 512], | ||
patch_output_shape=[512, 512], | ||
stride_shape=[512, 512], | ||
) | ||
|
||
# test inference | ||
|
||
points = np.array([[[64, 64]], [[64, 64]]]) # Point on nuclei | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why two of the same points? Does it not make more sense to have two different points here? |
||
|
||
# Run on tile mode with multi-prompt | ||
# Test running with multiple images | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
output_list = prompt_segmentor.predict( | ||
[mini_wsi_jpg, mini_wsi_jpg], | ||
mode="tile", | ||
multi_prompt=True, | ||
device=select_device(on_gpu=ON_GPU), | ||
point_coords=points, | ||
ioconfig=ioconfig, | ||
crash_on_exception=False, | ||
save_dir=save_dir, | ||
) | ||
|
||
pred_1 = np.load(output_list[0][1] + "/0.raw.0.npy") | ||
pred_2 = np.load(output_list[1][1] + "/0.raw.0.npy") | ||
assert len(output_list) == 2 | ||
assert np.sum(pred_1 - pred_2) == 0 | ||
|
||
points = np.array([[[64, 64], [100, 40], [100, 70]]]) # Points on nuclei | ||
boxes = np.array([[[10, 10, 50, 50], [80, 80, 110, 110]]]) # Boxes on nuclei | ||
|
||
# Run on tile mode with single-prompt | ||
# Also tests boxes | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
output_list = prompt_segmentor.predict( | ||
[mini_wsi_jpg], | ||
mode="tile", | ||
multi_prompt=False, | ||
device=select_device(on_gpu=ON_GPU), | ||
point_coords=points, | ||
box_coords=boxes, | ||
ioconfig=ioconfig, | ||
crash_on_exception=False, | ||
save_dir=save_dir, | ||
) | ||
|
||
total_prompts = points.shape[1] + boxes.shape[1] | ||
preds = [ | ||
np.load(output_list[0][1] + f"/{i}.raw.0.npy") for i in range(total_prompts) | ||
] | ||
|
||
assert len(output_list) == 1 | ||
assert len(preds) == total_prompts | ||
|
||
# Generate mask | ||
mask = np.zeros((thumb.shape[0], thumb.shape[1]), dtype=np.uint8) | ||
mask[32:120, 32:120] = 1 | ||
mini_wsi_msk = f"{tmp_path}/mini_svs_mask.jpg" | ||
imwrite(mini_wsi_msk, mask) | ||
|
||
ioconfig = IOSegmentorConfig( | ||
input_resolutions=[ | ||
{"units": "baseline", "resolution": 1.0}, | ||
], | ||
output_resolutions=[{"units": "baseline", "resolution": 1.0}], | ||
patch_input_shape=[512, 512], | ||
patch_output_shape=[512, 512], | ||
stride_shape=[512, 512], | ||
save_resolution={"units": "baseline", "resolution": 1.0}, | ||
) | ||
|
||
# Only point within mask should generate a segmentation | ||
points = np.array([[[64, 64], [100, 40]]]) | ||
save_dir = tmp_path / "dump" | ||
|
||
# Run on wsi mode with multi-prompt | ||
# Also tests masks | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
output_list = prompt_segmentor.predict( | ||
[mini_wsi_jpg], | ||
masks=[mini_wsi_msk], | ||
mode="wsi", | ||
multi_prompt=True, | ||
device=select_device(on_gpu=ON_GPU), | ||
point_coords=points, | ||
ioconfig=ioconfig, | ||
crash_on_exception=False, | ||
save_dir=save_dir, | ||
) | ||
|
||
# Check if db exists | ||
assert Path(output_list[0][1] + ".0.db").exists() | ||
|
||
points = np.array([[[10, 30]]]) | ||
boxes = np.array([[[10, 10, 30, 30]]]) | ||
# Test no prompts within mask | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
output_list = prompt_segmentor.predict( | ||
[mini_wsi_jpg], | ||
masks=[mini_wsi_msk], | ||
mode="wsi", | ||
multi_prompt=True, | ||
device=select_device(on_gpu=ON_GPU), | ||
point_coords=points, | ||
box_coords=boxes, | ||
ioconfig=ioconfig, | ||
crash_on_exception=False, | ||
save_dir=save_dir, | ||
) | ||
# Check if db exists | ||
assert Path(output_list[0][1] + ".0.db").exists() | ||
|
||
# Run on wsi mode with single-prompt | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
output_list = prompt_segmentor.predict( | ||
[mini_wsi_jpg], | ||
mode="wsi", | ||
multi_prompt=False, | ||
device=select_device(on_gpu=ON_GPU), | ||
point_coords=points, | ||
ioconfig=ioconfig, | ||
crash_on_exception=False, | ||
save_dir=save_dir, | ||
) | ||
|
||
# Check if db exists | ||
assert Path(output_list[0][1] + ".0.db").exists() | ||
|
||
|
||
def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: | ||
"""Functional crash tests for segmentor.""" | ||
# # convert to pathlib Path to prevent wsireader complaint | ||
mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) | ||
mini_wsi_msk = Path(remote_sample("wsi2_4k_4k_msk")) | ||
|
||
save_dir = tmp_path / "test_crash_segmentor" | ||
prompt_segmentor = PromptSegmentor(batch_size=BATCH_SIZE) | ||
|
||
# * test basic crash | ||
with pytest.raises(TypeError, match=r".*`mask_reader`.*"): | ||
prompt_segmentor.filter_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) | ||
with pytest.raises(TypeError, match=r".*`mask_reader`.*"): | ||
prompt_segmentor.get_mask_bounds(mini_wsi_msk) | ||
with pytest.raises(TypeError, match=r".*mask_reader.*"): | ||
prompt_segmentor.clip_coordinates(mini_wsi_msk, np.array(["a", "b", "c"])) | ||
|
||
with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): | ||
prompt_segmentor.filter_coordinates( | ||
WSIReader.open(mini_wsi_msk), | ||
np.array([1.0, 2.0]), | ||
) | ||
with pytest.raises(ValueError, match=r".*ndarray.*integer.*"): | ||
prompt_segmentor.clip_coordinates( | ||
WSIReader.open(mini_wsi_msk), | ||
np.array([1.0, 2.0]), | ||
) | ||
prompt_segmentor.get_reader(mini_wsi_svs, None, "wsi", auto_get_mask=True) | ||
with pytest.raises(ValueError, match=r".*must be a valid file path.*"): | ||
prompt_segmentor.get_reader( | ||
mini_wsi_msk, | ||
"not_exist", | ||
"wsi", | ||
auto_get_mask=True, | ||
) | ||
|
||
shutil.rmtree(save_dir, ignore_errors=True) # default output dir test | ||
with pytest.raises(ValueError, match=r".*valid mode.*"): | ||
prompt_segmentor.predict([], mode="abc") | ||
|
||
crash_segmentor = PromptSegmentor() | ||
|
||
# * test crash segmentor | ||
def _predict_one_wsi( | ||
*args: dict, | ||
**kwargs: dict, | ||
) -> tuple[WSIReader, str]: | ||
"""Override the predict function to test crash segmentor.""" | ||
msg = f"Test crash segmentor:{args} {kwargs}" | ||
raise RuntimeError(msg) | ||
|
||
crash_segmentor._predict_one_wsi = _predict_one_wsi | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
with pytest.raises( | ||
RuntimeError, | ||
match=r"Test crash segmentor:\(.*\) \{.*\}", | ||
): | ||
crash_segmentor.predict( | ||
[mini_wsi_svs], | ||
mode="wsi", | ||
multi_prompt=True, | ||
device=select_device(on_gpu=ON_GPU), | ||
patch_input_shape=[512, 512], | ||
resolution=2.0, | ||
units="mpp", | ||
crash_on_exception=True, | ||
save_dir=save_dir, | ||
) | ||
|
||
# test ignore crash | ||
shutil.rmtree(save_dir, ignore_errors=True) | ||
crash_segmentor.predict( | ||
[mini_wsi_svs], | ||
mode="wsi", | ||
multi_prompt=True, | ||
device=select_device(on_gpu=ON_GPU), | ||
patch_input_shape=[512, 512], | ||
resolution=2.0, | ||
units="mpp", | ||
crash_on_exception=False, | ||
save_dir=save_dir, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on points, it should have just one element from the one point? Maybe correct the wording of this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same for below bounding boxes