Skip to content

Commit 6258617

Browse files
committed
First functional roundtrip for nifti images
1 parent 2e6ffc6 commit 6258617

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

tiledb/bioimg/converters/nifti.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import nibabel as nib
1818
import numpy as np
19-
from black.trans import defaultdict
2019
from nibabel import Nifti1Image
2120
from numpy._typing import NDArray
2221

@@ -51,7 +50,12 @@ def __init__(
5150
self._vfs = VFS(config=self._source_cfg, ctx=self._source_ctx)
5251
self._vfs_fh = self._vfs.open(input_path, mode="rb")
5352
self._nib_image = Nifti1Image.from_stream(self._vfs_fh)
54-
self._metadata: Dict[str, Any] = self._serialize_header(self.nifti1_hdr_2_dict())
53+
self._metadata: Dict[str, Any] = self._serialize_header(
54+
self.nifti1_hdr_2_dict()
55+
)
56+
self._binary_header = base64.b64encode(
57+
self._nib_image.header.binaryblock
58+
).decode("utf-8")
5559
self._mode = "".join(self._nib_image.dataobj.dtype.names)
5660

5761
def __enter__(self) -> NiftiReader:
@@ -74,7 +78,7 @@ def logger(self) -> Optional[logging.Logger]:
7478

7579
@property
7680
def group_metadata(self) -> Dict[str, Any]:
77-
writer_kwargs = dict(metadata=self._metadata)
81+
writer_kwargs = dict(metadata=self._metadata, binaryblock=self._binary_header)
7882
self._logger.debug(f"Group metadata: {writer_kwargs}")
7983
return {"json_write_kwargs": json.dumps(writer_kwargs)}
8084

@@ -176,6 +180,7 @@ def level_image(
176180
) -> np.ndarray:
177181

178182
unscaled_img = self._nib_image.dataobj.get_unscaled()
183+
self._metadata["original_mode"] = self._mode
179184
raw_data_contiguous = np.ascontiguousarray(unscaled_img)
180185
numerical_data = np.frombuffer(raw_data_contiguous, dtype=self.level_dtype())
181186
numerical_data = numerical_data.reshape(self.level_shape())
@@ -202,6 +207,7 @@ def optimal_reader(
202207
# raise ValueError("chunk_size must be set for chunked reading.")
203208
#
204209
# array = self._nib_image.get_fdata()
210+
# array = self._nib_image.get_fdata()
205211
# total_slices = array.shape[-1]
206212
# for i in range(0, total_slices, self.chunk_size):
207213
# chunk = array[..., i : i + self.chunk_size]
@@ -216,15 +222,15 @@ def nifti1_hdr_2_dict(self) -> Dict[str, Any]:
216222
}
217223

218224
@staticmethod
219-
def _serialize_header(header_dict: [Dict, Any]) -> Dict[str, Any]:
220-
serialized_header = defaultdict(dict)
221-
for k,v in header_dict.items():
222-
if isinstance(v, np.ndarray):
223-
serialized_header[k] = v.tolist()
224-
if isinstance(serialized_header[k], bytes):
225-
serialized_header[k] = base64.b64encode(serialized_header[k]).decode('utf-8')
226-
else:
227-
serialized_header[k] = v
225+
def _serialize_header(header_dict: Mapping[str, Any]) -> Dict[str, Any]:
226+
serialized_header = {
227+
k: (
228+
base64.b64encode(v.tolist()).decode("utf-8")
229+
if isinstance(v, np.ndarray) and isinstance(v.tolist(), bytes)
230+
else v.tolist() if isinstance(v, np.ndarray) else v
231+
)
232+
for k, v in header_dict.items()
233+
}
228234
return serialized_header
229235

230236

@@ -233,6 +239,8 @@ def __init__(self, output_path: str, logger: logging.Logger):
233239
self._logger = logger
234240
self._output_path = output_path
235241
self._group_metadata: Dict[str, Any] = {}
242+
self._nifti1header = partial(nib.Nifti1Header)
243+
self._original_mode = None
236244
self._writer = partial(nib.Nifti1Image)
237245

238246
def __enter__(self) -> NiftiWriter:
@@ -249,20 +257,33 @@ def compute_level_metadata(
249257
) -> Mapping[str, Any]:
250258

251259
writer_metadata: Dict[str, Any] = {}
252-
original_mode = group_metadata.get("original_mode", "RGB")
253-
writer_metadata["mode"] = original_mode
260+
self._original_mode = group_metadata.get("original_mode", "RGB")
261+
writer_metadata["mode"] = self._original_mode
254262
self._logger.debug(f"Writer metadata: {writer_metadata}")
255263
return writer_metadata
256264

257265
def write_group_metadata(self, metadata: Mapping[str, Any]) -> None:
258266
self._group_metadata = json.loads(metadata["json_write_kwargs"])
259267

268+
def _structured_dtype(self) -> np.dtype:
269+
if self._original_mode == "RGB":
270+
return np.dtype([("R", "u1"), ("G", "u1"), ("B", "u1")])
271+
260272
def write_level_image(
261273
self,
262274
image: np.ndarray,
263275
metadata: Mapping[str, Any],
264276
) -> None:
265-
nib_image = self._writer(image, metadata["affine"])
277+
header = self._nifti1header(
278+
binaryblock=base64.b64decode(self._group_metadata["binaryblock"])
279+
)
280+
contiguous_image = np.ascontiguousarray(image)
281+
structured_arr = contiguous_image.view(dtype=self._structured_dtype()).reshape(
282+
*image.shape[:-1]
283+
)
284+
nib_image = self._writer(
285+
structured_arr, header=header, affine=header.get_best_affine()
286+
)
266287
nib.save(nib_image, self._output_path)
267288

268289
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

0 commit comments

Comments
 (0)