Skip to content

Commit 1f15307

Browse files
authored
Merge branch 'develop' into models-abc-multigpu
2 parents ee25842 + 3eea490 commit 1f15307

File tree

5 files changed

+242
-4
lines changed

5 files changed

+242
-4
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
sudo apt update
3131
sudo apt-get install -y libopenslide-dev openslide-tools libopenjp2-7 libopenjp2-tools
3232
python -m pip install --upgrade pip
33-
python -m pip install ruff==0.11.4 pytest pytest-cov pytest-runner
33+
python -m pip install ruff==0.11.8 pytest pytest-cov pytest-runner
3434
pip install -r requirements/requirements.txt
3535
- name: Cache tiatoolbox static assets
3636
uses: actions/cache@v3

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ repos:
6060
- id: rst-inline-touching-normal # Detect mistake of inline code touching normal text in rst.
6161
- repo: https://github.com/astral-sh/ruff-pre-commit
6262
# Ruff version.
63-
rev: v0.11.4
63+
rev: v0.11.8
6464
hooks:
6565
- id: ruff
6666
args: [--fix, --exit-non-zero-on-fix]

requirements/requirements_dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pytest>=7.2.0
1010
pytest-cov>=4.0.0
1111
pytest-runner>=6.0
1212
pytest-xdist[psutil]
13-
ruff==0.11.4 # This will be updated by pre-commit bot to latest version
13+
ruff==0.11.8 # This will be updated by pre-commit bot to latest version
1414
toml>=0.10.2
1515
twine>=4.0.1
1616
wheel>=0.37.1

tests/test_utils.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
import numpy as np
1313
import pandas as pd
1414
import pytest
15+
import tifffile
1516
import torch
17+
import zarr
18+
from defusedxml import ElementTree as ET # noqa: N817
1619
from PIL import Image
1720
from requests import HTTPError
1821
from shapely.geometry import Polygon
22+
from tifffile import TiffFile
1923

2024
from tests.test_annotation_stores import cell_polygon
2125
from tiatoolbox import rcParam, utils
@@ -1858,3 +1862,123 @@ def test_torch_compile_compatibility(caplog: pytest.LogCaptureFixture) -> None:
18581862

18591863
is_torch_compile_compatible()
18601864
assert "torch.compile" in caplog.text
1865+
1866+
1867+
# Tests for OME tiff writer
1868+
1869+
1870+
def get_ome_metadata(tiff_path: Path) -> str | None:
1871+
"""Extracts the OME metadata string from a TIFF file."""
1872+
with TiffFile(tiff_path) as tif:
1873+
if tif.ome_metadata:
1874+
return tif.ome_metadata
1875+
return None
1876+
1877+
1878+
def assert_ome_metadata_value(
1879+
ome_xml: ET.Element, tag: str, expected_value: str
1880+
) -> None:
1881+
"""Asserts the value of a specific OME metadata tag (as an attribute)."""
1882+
namespace = "{http://www.openmicroscopy.org/Schemas/OME/2016-06}"
1883+
image_elements = ome_xml.findall(f".//{namespace}Image")
1884+
if image_elements:
1885+
pixels_elements = image_elements[0].findall(f"./{namespace}Pixels")
1886+
if pixels_elements:
1887+
actual_value = pixels_elements[0].get(tag)
1888+
assert actual_value == expected_value, (
1889+
f"Expected attribute '{tag}' to be '{expected_value}', "
1890+
f"but got '{actual_value}'."
1891+
)
1892+
return
1893+
1894+
# If we reach here, the tag or attribute was not found
1895+
pytest.fail(f"Attribute '{tag}' not found in OME metadata.")
1896+
1897+
1898+
def test_iwrite_probability_heatmap_as_ome_tiff_errors(tmp_path: Path) -> None:
1899+
"""Test expected errors in `write_probability_heatmap_as_ome_tiff`."""
1900+
probability = np.zeros(shape=(256, 256, 3))
1901+
1902+
# Input image must have 2 (CY) dimensions.
1903+
with pytest.raises(ValueError, match=r".*must have 2 \(YX\).*"):
1904+
misc.write_probability_heatmap_as_ome_tiff(
1905+
image_path=tmp_path / "failed_test.tif",
1906+
probability=probability,
1907+
)
1908+
1909+
probability = np.zeros(shape=(256, 256, 3))
1910+
probability = torch.from_numpy(probability)
1911+
1912+
# Input image must be a NumPy array or a Zarr array.
1913+
with pytest.raises(TypeError, match=r".*must be a NumPy array or a Zarr.*"):
1914+
misc.write_probability_heatmap_as_ome_tiff(
1915+
image_path=tmp_path / "failed_test.tif",
1916+
probability=probability,
1917+
)
1918+
1919+
1920+
def test_save_numpy_array_proability_ome_tiff(
1921+
tmp_path: Path, source_image: Path
1922+
) -> None:
1923+
"""Tests saving a basic NumPy array."""
1924+
image_path = tmp_path / "numpy_image.ome.tif"
1925+
probability = utils.imread(source_image)
1926+
probability_0 = probability[:, :, 0]
1927+
misc.write_probability_heatmap_as_ome_tiff(
1928+
image_path=image_path,
1929+
probability=probability_0,
1930+
tile_size=(64, 64),
1931+
mpp=(0.5, 0.5),
1932+
levels=2,
1933+
colormap=cv2.COLORMAP_JET,
1934+
)
1935+
assert image_path.is_file()
1936+
saved_img = tifffile.imread(image_path)
1937+
assert probability.shape == saved_img.shape
1938+
assert probability.dtype == saved_img.dtype
1939+
ome_xml = ET.fromstring(get_ome_metadata(image_path))
1940+
assert ome_xml is not None
1941+
1942+
assert_ome_metadata_value(ome_xml, "SizeY", str(probability.shape[0]))
1943+
assert_ome_metadata_value(ome_xml, "SizeX", str(probability.shape[1]))
1944+
assert_ome_metadata_value(ome_xml, "SizeC", str(3))
1945+
assert_ome_metadata_value(ome_xml, "DimensionOrder", "XYCZT")
1946+
assert_ome_metadata_value(ome_xml, "PhysicalSizeX", "0.5")
1947+
assert_ome_metadata_value(ome_xml, "PhysicalSizeY", "0.5")
1948+
assert_ome_metadata_value(ome_xml, "PhysicalSizeXUnit", "µm")
1949+
assert_ome_metadata_value(ome_xml, "PhysicalSizeYUnit", "µm")
1950+
1951+
1952+
def test_save_zarr_array_probability_ome_tiff(
1953+
tmp_path: Path, source_image: Path
1954+
) -> None:
1955+
"""Tests saving a Zarr array with uint8 dtype."""
1956+
image_path = tmp_path / "zarr_uint8_image.ome.tif"
1957+
1958+
img = utils.imread(source_image)
1959+
probability = img[:, 0:200, 0]
1960+
img_zarr = zarr.zeros(shape=probability.shape, dtype=np.uint8)
1961+
img_zarr[:] = probability
1962+
1963+
misc.write_probability_heatmap_as_ome_tiff(
1964+
image_path,
1965+
img_zarr,
1966+
tile_size=(32, 32),
1967+
levels=2,
1968+
colormap=cv2.COLORMAP_INFERNO,
1969+
)
1970+
assert image_path.is_file()
1971+
saved_img = tifffile.imread(image_path, squeeze=True)
1972+
assert img_zarr.shape == saved_img.shape[0:2]
1973+
assert img_zarr.dtype == saved_img.dtype
1974+
ome_xml = ET.fromstring(get_ome_metadata(image_path))
1975+
assert ome_xml is not None
1976+
1977+
assert_ome_metadata_value(ome_xml, "SizeY", str(img_zarr.shape[0]))
1978+
assert_ome_metadata_value(ome_xml, "SizeX", str(img_zarr.shape[1]))
1979+
assert_ome_metadata_value(ome_xml, "SizeC", str(3))
1980+
assert_ome_metadata_value(ome_xml, "DimensionOrder", "XYCZT")
1981+
assert_ome_metadata_value(ome_xml, "PhysicalSizeX", "0.25")
1982+
assert_ome_metadata_value(ome_xml, "PhysicalSizeY", "0.25")
1983+
assert_ome_metadata_value(ome_xml, "PhysicalSizeXUnit", "µm")
1984+
assert_ome_metadata_value(ome_xml, "PhysicalSizeYUnit", "µm")

tiatoolbox/utils/misc.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,22 @@
1616
import numpy as np
1717
import pandas as pd
1818
import requests
19+
import tifffile
1920
import yaml
2021
import zarr
2122
from filelock import FileLock
2223
from shapely.affinity import translate
2324
from shapely.geometry import Polygon
2425
from shapely.geometry import shape as feature2geometry
2526
from skimage import exposure
27+
from tqdm import trange
2628

2729
from tiatoolbox import logger
2830
from tiatoolbox.annotation.storage import Annotation, AnnotationStore, SQLiteStore
2931
from tiatoolbox.utils.exceptions import FileNotSupportedError
3032

3133
if TYPE_CHECKING: # pragma: no cover
34+
from collections.abc import Iterator
3235
from os import PathLike
3336

3437
from shapely import geometry
@@ -160,7 +163,7 @@ def imwrite(image_path: PathLike, img: np.ndarray) -> None:
160163

161164

162165
def imread(image_path: PathLike, as_uint8: bool | None = None) -> np.ndarray:
163-
"""Read an image as numpy array.
166+
"""Read an image as a NumPy array.
164167
165168
Args:
166169
image_path (PathLike):
@@ -1283,6 +1286,117 @@ def dict_to_store(
12831286
return store
12841287

12851288

1289+
def _tiles(
1290+
in_img: np.ndarray | zarr.core.Array,
1291+
tile_size: tuple[int, int],
1292+
colormap: int = cv2.COLORMAP_JET,
1293+
level: int = 0,
1294+
) -> Iterator[np.ndarray]:
1295+
for y in trange(0, in_img.shape[0], tile_size[0]):
1296+
for x in range(0, in_img.shape[1], tile_size[1]):
1297+
in_img_ = in_img[
1298+
y : y + tile_size[0] : 2**level, x : x + tile_size[1] : 2**level
1299+
]
1300+
yield cv2.applyColorMap(in_img_, colormap)
1301+
1302+
1303+
def write_probability_heatmap_as_ome_tiff(
1304+
image_path: Path,
1305+
probability: np.ndarray | zarr.core.Array,
1306+
tile_size: tuple[int, int] = (64, 64),
1307+
levels: int = 2,
1308+
mpp: tuple[float, float] = (0.25, 0.25),
1309+
colormap: int = cv2.COLORMAP_JET,
1310+
) -> None:
1311+
"""Saves output probability maps from segmentation models as heatmaps.
1312+
1313+
This function converts the probability maps from individual classes to heatmaps
1314+
and saves them as pyramidal ome tiffs.
1315+
1316+
Args:
1317+
image_path (Path):
1318+
File path (including extension) to save image to.
1319+
probability (np.ndarray or zarr.core.Array):
1320+
The input image data in YXC (Height, Width, Channels) format.
1321+
tile_size (tuple):
1322+
Tile/Chunk size (YX/HW) for writing the tiff file.
1323+
Only allows tile shapes allowed by tifffile. Default is (64, 64).
1324+
levels (int):
1325+
Number of levels for saving pyramidal ome tiffs. Default is 2.
1326+
mpp (tuple[float, float]):
1327+
Tuple of mpp values in y and x (YX/HW). Default is (0.25, 0.25).
1328+
colormap (int):
1329+
Colormap to save the heatmaps. Default is 2 (cv2.COLORMAP_JET).
1330+
1331+
Raises:
1332+
TypeError:
1333+
If the input `img` is not a NumPy or Zarr array or does not have 3
1334+
dimensions.
1335+
ValueError:
1336+
If input dimensions is not 3 (HWC) dimensions.
1337+
1338+
Examples:
1339+
>>> probability_map = imread("path/to/probability_map")
1340+
>>> write_probability_heatmap_as_ome_tiff(
1341+
... image_path=image_path,
1342+
... probability=probability_map,
1343+
... tile_size=(64, 64),
1344+
... class_name="tumor",
1345+
... levels=2,
1346+
... mpp=(0.5, 0.5),
1347+
... colormap=cv2.COLORMAP_JET,
1348+
... )
1349+
1350+
"""
1351+
if not isinstance(probability, (zarr.core.Array, np.ndarray)):
1352+
msg = "Input 'probability' must be a NumPy array or a Zarr array."
1353+
raise TypeError(msg)
1354+
1355+
if probability.ndim != 2: # noqa: PLR2004
1356+
msg = "Input 'probability' must have 2 (YX) dimensions."
1357+
raise ValueError(msg)
1358+
1359+
ome_metadata = {
1360+
"axes": "YXC",
1361+
"PhysicalSizeX": mpp[1],
1362+
"PhysicalSizeXUnit": "µm",
1363+
"PhysicalSizeY": mpp[0],
1364+
"PhysicalSizeYUnit": "µm",
1365+
}
1366+
1367+
h = probability.shape[0]
1368+
w = probability.shape[1]
1369+
1370+
with tifffile.TiffWriter(image_path, bigtiff=True, ome=True) as tif:
1371+
tif.write(
1372+
_tiles(in_img=probability, tile_size=tile_size, colormap=colormap),
1373+
dtype="uint8",
1374+
shape=(h, w, 3),
1375+
tile=tile_size,
1376+
compression="jpeg",
1377+
metadata=ome_metadata,
1378+
subifds=levels - 1,
1379+
)
1380+
1381+
for level_ in range(1, levels):
1382+
tif.write(
1383+
_tiles(
1384+
in_img=probability,
1385+
tile_size=tile_size,
1386+
colormap=colormap,
1387+
level=level_,
1388+
),
1389+
dtype="uint8",
1390+
shape=(h // 2**level_, w // 2**level_, 3),
1391+
tile=(tile_size[0] // 2**level_, tile_size[1] // 2**level_),
1392+
compression="jpeg",
1393+
subfiletype=0,
1394+
)
1395+
1396+
msg = f"Image saved as OME-TIFF to {image_path}."
1397+
logger.info(msg)
1398+
1399+
12861400
def dict_to_zarr(
12871401
raw_predictions: dict,
12881402
save_path: Path,

0 commit comments

Comments
 (0)