Skip to content

Commit 3745388

Browse files
committed
Add first batch of roundtrip tests
1 parent e84ace6 commit 3745388

File tree

6 files changed

+134
-30
lines changed

6 files changed

+134
-30
lines changed

tests/data/nifti/example4d.nii

1.13 MB
Binary file not shown.

tests/data/nifti/functional.nii

42.2 KB
Binary file not shown.

tests/data/nifti/standard.nii

492 Bytes
Binary file not shown.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import nibabel as nib
2+
import numpy as np
3+
import pytest
4+
5+
import tiledb
6+
from tests import get_path
7+
from tiledb.bioimg.converters.nifti import NiftiConverter
8+
9+
10+
def compare_nifti_images(file1, file2):
11+
img1 = nib.load(file1)
12+
img2 = nib.load(file2)
13+
14+
# Compare the headers (metadata)
15+
if img1.header != img2.header:
16+
return False
17+
18+
# Compare the affine matrices (spatial information)
19+
if not np.array_equal(img1.affine, img2.affine):
20+
return False
21+
22+
# Compare the image data (voxel data)
23+
data1 = img1.get_fdata()
24+
data2 = img2.get_fdata()
25+
if not np.array_equal(data1, data2):
26+
return False
27+
return True
28+
29+
30+
@pytest.mark.parametrize(
31+
"filename", ["nifti/example4d.nii", "nifti/functional.nii", "nifti/standard.nii"]
32+
)
33+
@pytest.mark.parametrize("preserve_axes", [False, True])
34+
@pytest.mark.parametrize("chunked", [False])
35+
@pytest.mark.parametrize(
36+
"compressor, lossless",
37+
[
38+
(tiledb.ZstdFilter(level=0), True),
39+
# WEBP is not supported for Grayscale images
40+
],
41+
)
42+
def test_nifti_converter_roundtrip(
43+
tmp_path, preserve_axes, chunked, compressor, lossless, filename
44+
):
45+
# For lossy WEBP we cannot use random generated images as they have so much noise
46+
input_path = str(get_path(filename))
47+
tiledb_path = str(tmp_path / "to_tiledb")
48+
output_path = str(tmp_path / "from_tiledb.nii")
49+
50+
NiftiConverter.to_tiledb(
51+
input_path,
52+
tiledb_path,
53+
preserve_axes=preserve_axes,
54+
chunked=chunked,
55+
compressor=compressor,
56+
log=False,
57+
)
58+
# Store it back to PNG
59+
NiftiConverter.from_tiledb(tiledb_path, output_path)
60+
compare_nifti_images(input_path, output_path)

tiledb/bioimg/converters/nifti.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import nibabel as nib
1818
import numpy as np
1919
from nibabel import Nifti1Image
20+
from nibabel.analyze import _dtdefs
2021
from numpy._typing import NDArray
2122

2223
from tiledb import VFS, Config, Ctx
@@ -56,7 +57,11 @@ def __init__(
5657
self._binary_header = base64.b64encode(
5758
self._nib_image.header.binaryblock
5859
).decode("utf-8")
59-
self._mode = "".join(self._nib_image.dataobj.dtype.names)
60+
self._mode = (
61+
"".join(self._nib_image.dataobj.dtype.names)
62+
if self._nib_image.dataobj.dtype.names is not None
63+
else ""
64+
)
6065

6166
def __enter__(self) -> NiftiReader:
6267
return self
@@ -100,10 +105,40 @@ def image_metadata(self) -> Dict[str, Any]:
100105

101106
@property
102107
def axes(self) -> Axes:
103-
if self._mode == "L":
104-
axes = Axes(["X", "Y", "Z"])
108+
header_dict = self.nifti1_hdr_2_dict()
109+
# The 0-index holds information about the number of dimensions
110+
# according the spec https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h
111+
dims_number = header_dict["dim"][0]
112+
if dims_number == 4:
113+
# According to standard the 4th dimension corresponds to 'T' time
114+
# but in special cases can be degnerate into channels
115+
if header_dict["dim"][dims_number] == 1:
116+
# The time dimension does not correspond to time
117+
if self._mode == "RGB" or self._mode == "RGBA":
118+
# [..., ..., ..., 1, 3] or [..., ..., ..., 1, 4]
119+
axes = Axes(["X", "Y", "Z", "T", "C"])
120+
else:
121+
# The image is single-channel with 1 value in Temporal dimension
122+
# instead of channel. So we map T to be channel.
123+
# [..., ..., ..., 1]
124+
axes = Axes(["X", "Y", "Z", "C"])
125+
else:
126+
# The time dimension does correspond to time
127+
axes = Axes(["X", "Y", "Z", "T"])
128+
elif dims_number < 4:
129+
# Only spatial dimensions
130+
if self._mode == "RGB" or self._mode == "RGBA":
131+
axes = Axes(["X", "Y", "Z", "C"])
132+
else:
133+
axes = Axes(["X", "Y", "Z"])
105134
else:
106-
axes = Axes(["X", "Y", "Z", "C"])
135+
# Has more dimensions that belong to spatial-temporal unknown attributes
136+
# TODO: investigate sample images of this format.
137+
if self._mode == "RGB" or self._mode == "RGBA":
138+
axes = Axes(["X", "Y", "Z", "C"])
139+
else:
140+
axes = Axes(["X", "Y", "Z"])
141+
107142
self._logger.debug(f"Reader axes: {axes}")
108143
return axes
109144

@@ -124,7 +159,6 @@ def channels(self) -> Sequence[str]:
124159
"G": "GREEN",
125160
"B": "BLUE",
126161
"A": "ALPHA",
127-
"L": "GRAYSCALE",
128162
}
129163
# Use list comprehension to convert the short form to full form
130164
rgb_full = [color_map[color] for color in self._mode]
@@ -139,12 +173,11 @@ def level_count(self) -> int:
139173
def level_dtype(self, level: int = 0) -> np.dtype:
140174
header_dict = self.nifti1_hdr_2_dict()
141175

142-
# Check the header first
143-
if (dtype := header_dict["data_type"].dtype) == np.dtype("S10"):
176+
dtype = self.get_dtype_from_code(header_dict["datatype"])
177+
if dtype == np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")]):
144178
dtype = np.uint8
145-
146179
# TODO: Compare with the dtype of fields
147-
# dict(self._nib_image.dataobj.dtype.fields)
180+
148181
self._logger.debug(f"Level {level} dtype: {dtype}")
149182
return dtype
150183

@@ -153,15 +186,17 @@ def level_shape(self, level: int = 0) -> Tuple[int, ...]:
153186
return ()
154187

155188
original_shape = self._nib_image.shape
156-
fields = self._nib_image.dataobj.dtype.fields
157-
if len(fields) == 3:
158-
# RGB convert the shape from to stack 3 channels
159-
l_shape = (*original_shape[:-1], 3)
160-
elif len(fields) == 4:
161-
# RGBA
162-
l_shape = (*original_shape[:-1], 4)
189+
if (fields := self._nib_image.dataobj.dtype.fields) is not None:
190+
if len(fields) == 3:
191+
# RGB convert the shape from to stack 3 channels
192+
l_shape = (*original_shape, 3)
193+
elif len(fields) == 4:
194+
# RGBA
195+
l_shape = (*original_shape, 4)
196+
else:
197+
# Grayscale
198+
l_shape = original_shape
163199
else:
164-
# Grayscale
165200
l_shape = original_shape
166201
self._logger.debug(f"Level {level} shape: {l_shape}")
167202
return l_shape
@@ -221,6 +256,13 @@ def nifti1_hdr_2_dict(self) -> Dict[str, Any]:
221256
for field in structured_header_arr.dtype.names
222257
}
223258

259+
# Function to find and return the third value based on the first value
260+
def get_dtype_from_code(self, dtype_code: int) -> np.dtype:
261+
for item in _dtdefs:
262+
if item[0] == dtype_code: # Check if the first value matches the input code
263+
return item[2] # Return the third value (dtype)
264+
return None # Return None if the code is not foun
265+
224266
@staticmethod
225267
def _serialize_header(header_dict: Mapping[str, Any]) -> Dict[str, Any]:
226268
serialized_header = {
@@ -265,9 +307,13 @@ def compute_level_metadata(
265307
def write_group_metadata(self, metadata: Mapping[str, Any]) -> None:
266308
self._group_metadata = json.loads(metadata["json_write_kwargs"])
267309

268-
def _structured_dtype(self) -> np.dtype:
310+
def _structured_dtype(self) -> Optional[np.dtype]:
269311
if self._original_mode == "RGB":
270312
return np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")])
313+
elif self._original_mode == "RGBA":
314+
return np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1"), ("A", "u1")])
315+
else:
316+
return None
271317

272318
def write_level_image(
273319
self,
@@ -278,9 +324,14 @@ def write_level_image(
278324
binaryblock=base64.b64decode(self._group_metadata["binaryblock"])
279325
)
280326
contiguous_image = np.ascontiguousarray(image)
281-
structured_arr = contiguous_image.view(dtype=self._structured_dtype()).reshape(
282-
*image.shape[:-1]
327+
structured_arr = contiguous_image.view(
328+
dtype=self._structured_dtype() if self._structured_dtype() else image.dtype
283329
)
330+
if len(image.shape) > 3:
331+
# If temporal is 1 and extra dim for channels RGB/RGBA
332+
if image.shape[3] == 1 and (image.shape[4] == 3 or 4):
333+
structured_arr = structured_arr.reshape(*image.shape[:4])
334+
284335
nib_image = self._writer(
285336
structured_arr, header=header, affine=header.get_best_affine()
286337
)

tiledb/bioimg/openslide.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import json
1515

1616
import tiledb
17-
from tiledb import Config, Ctx, TileDBError
17+
from tiledb import Config, Ctx
1818
from tiledb.highlevel import _get_ctx
1919

2020
from . import ATTR_NAME
@@ -84,7 +84,7 @@ def levels(self) -> Sequence[int]:
8484

8585
@property
8686
def dimensions(self) -> Tuple[int, ...]:
87-
"""A (width, height, depth - (if exists)) tuple for level 0 of the slide."""
87+
"""A (width, height) tuple for level 0 of the slide."""
8888
return self._levels[0].dimensions
8989

9090
@property
@@ -196,14 +196,7 @@ def dimensions(self) -> Tuple[int, ...]:
196196
dims = list(a.domain)
197197
width = a.shape[dims.index(a.dim("X"))]
198198
height = a.shape[dims.index(a.dim("Y"))]
199-
try:
200-
depth = a.shape[dims.index(a.dim("Z"))]
201-
# The Z dim does not exist
202-
except TileDBError:
203-
depth = None
204-
d1, d2 = width // self._pixel_depth, height
205-
dimensions = (d1, d2) if depth is None else (d1, d2, depth)
206-
return dimensions
199+
return width // self._pixel_depth, height
207200

208201
@property
209202
def properties(self) -> Mapping[str, Any]:

0 commit comments

Comments
 (0)