Skip to content

Commit daf8699

Browse files
committed
move all delayed array creation into DaskImageDataStack
1 parent 2f11d90 commit daf8699

File tree

3 files changed

+51
-39
lines changed

3 files changed

+51
-39
lines changed

mantidimaging/eyes_tests/live_viewer_window_test.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import os
99
from mantidimaging.core.operations.loader import load_filter_packages
10-
from mantidimaging.gui.windows.live_viewer.model import Image_Data
10+
from mantidimaging.gui.windows.live_viewer.model import Image_Data, DaskImageDataStack
1111
from mantidimaging.test_helpers.unit_test_helper import FakeFSTestCase
1212
from pathlib import Path
1313
from mantidimaging.eyes_tests.base_eyes import BaseEyesTest
@@ -61,31 +61,34 @@ def test_live_view_opens_without_data(self, _mock_time, _mock_image_watcher):
6161
@mock.patch("time.time", return_value=4000.0)
6262
def test_live_view_opens_with_data(self, _mock_time, _mock_image_watcher, mock_load_image):
6363
file_list = self._make_simple_dir(self.live_directory)
64-
image_list = [Image_Data(path, create_delayed_array=False) for path in file_list]
64+
image_list = [Image_Data(path) for path in file_list]
65+
dask_image_stack = DaskImageDataStack(image_list, create_delayed_array=False)
6566
mock_load_image.return_value = self._generate_image()
6667
self.imaging.show_live_viewer(self.live_directory)
67-
self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list)
68+
self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list, dask_image_stack)
6869
self.check_target(widget=self.imaging.live_viewer)
6970

7071
@mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image')
7172
@mock.patch('mantidimaging.gui.windows.live_viewer.model.ImageWatcher')
7273
@mock.patch("time.time", return_value=4000.0)
7374
def test_live_view_opens_with_bad_data(self, _mock_time, _mock_image_watcher, mock_load_image):
7475
file_list = self._make_simple_dir(self.live_directory)
75-
image_list = [Image_Data(path, create_delayed_array=False) for path in file_list]
76+
image_list = [Image_Data(path) for path in file_list]
77+
dask_image_stack = DaskImageDataStack(image_list, create_delayed_array=False)
7678
mock_load_image.side_effect = ValueError
7779
self.imaging.show_live_viewer(self.live_directory)
78-
self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list)
80+
self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list, dask_image_stack)
7981
self.check_target(widget=self.imaging.live_viewer)
8082

8183
@mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image')
8284
@mock.patch('mantidimaging.gui.windows.live_viewer.model.ImageWatcher')
8385
@mock.patch("time.time", return_value=4000.0)
8486
def test_rotate_operation_rotates_image(self, _mock_time, _mock_image_watcher, mock_load_image):
8587
file_list = self._make_simple_dir(self.live_directory)
86-
image_list = [Image_Data(path, create_delayed_array=False) for path in file_list]
88+
image_list = [Image_Data(path) for path in file_list]
89+
dask_image_stack = DaskImageDataStack(image_list, create_delayed_array=False)
8790
mock_load_image.return_value = self._generate_image()
8891
self.imaging.show_live_viewer(self.live_directory)
89-
self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list)
92+
self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list, dask_image_stack)
9093
self.imaging.live_viewer.rotate_angles_group.actions()[1].trigger()
9194
self.check_target(widget=self.imaging.live_viewer)

mantidimaging/gui/windows/live_viewer/model.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,41 @@ class DaskImageDataStack:
2424
"""
2525
A Dask Image Data Stack Class to hold a delayed array of all the images in the Live Viewer Path
2626
"""
27-
delayed_stack: dask.array.Array | None = None
27+
delayed_stack: dask.array.Array
28+
image_list: list[Image_Data]
2829

29-
def __init__(self, image_list: list[Image_Data] | None):
30-
if image_list:
31-
if image_list[0].create_delayed_array:
30+
def __init__(self, image_list: list[Image_Data], create_delayed_array: bool = True):
31+
self.image_list = image_list
32+
33+
if image_list and create_delayed_array:
34+
arrays = self.get_delayed_arrays()
35+
if arrays:
3236
if image_list[0].image_path.suffix.lower() in [".tif", ".tiff"]:
33-
arrays = [image_data.delayed_array for image_data in image_list]
3437
self.delayed_stack = dask.array.stack(dask.array.array(arrays))
3538
elif image_list[0].image_path.suffix.lower() in [".fits"]:
3639
with fits.open(image_list[0].image_path.__str__()) as fit:
3740
sample = fit[0].data
38-
arrays = [image_data.delayed_array for image_data in image_list]
3941
lazy_arrays = [dask.array.from_delayed(x, shape=sample.shape, dtype=sample.dtype) for x in arrays]
4042
self.delayed_stack = dask.array.stack(lazy_arrays)
4143

4244
@property
4345
def shape(self):
4446
return self.delayed_stack.shape
4547

48+
def get_delayed_arrays(self) -> list[dask.array.Array] | None:
49+
if self.image_list[0].image_path.suffix.lower() in [".tif", ".tiff"]:
50+
return [dask_image.imread.imread(image_data.image_path)[0] for image_data in self.image_list]
51+
elif self.image_list[0].image_path.suffix.lower() == ".fits":
52+
return [dask.delayed(fits.open)(image_data.image_path)[0].data for image_data in self.image_list]
53+
else:
54+
return None
55+
56+
def get_delayed_image(self, index) -> dask.array.Array:
57+
return self.delayed_stack[index]
58+
59+
def get_image_data(self, index) -> Image_Data:
60+
return self.image_list[index]
61+
4662

4763
class Image_Data:
4864
"""
@@ -66,7 +82,7 @@ class Image_Data:
6682
delayed_array: dask.array.Array
6783
create_delayed_array: bool
6884

69-
def __init__(self, image_path: Path, create_delayed_array: bool = True):
85+
def __init__(self, image_path: Path):
7086
"""
7187
Constructor for Image_Data class.
7288
@@ -78,9 +94,6 @@ def __init__(self, image_path: Path, create_delayed_array: bool = True):
7894
self.image_path = image_path
7995
self.image_name = image_path.name
8096
self._stat = image_path.stat()
81-
self.create_delayed_array = create_delayed_array
82-
if self.create_delayed_array:
83-
self.set_delayed_array()
8497

8598
@property
8699
def stat(self) -> stat_result:
@@ -96,12 +109,6 @@ def image_modified_time_stamp(self) -> str:
96109
"""Return the image modified time as a string"""
97110
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.image_modified_time))
98111

99-
def set_delayed_array(self) -> None:
100-
if self.image_path.suffix.lower() in [".tif", ".tiff"]:
101-
self.delayed_array = dask_image.imread.imread(self.image_path)[0]
102-
elif self.image_path.suffix.lower() == ".fits":
103-
self.delayed_array = dask.delayed(fits.open)(self.image_path)[0].data
104-
105112

106113
class SubDirectory:
107114

@@ -145,7 +152,7 @@ def __init__(self, presenter: LiveViewerWindowPresenter):
145152
self._dataset_path: Path | None = None
146153
self.image_watcher: ImageWatcher | None = None
147154
self.images: list[Image_Data] = []
148-
self.image_stack: DaskImageDataStack | None
155+
self.image_stack: DaskImageDataStack
149156

150157
@property
151158
def path(self) -> Path | None:
@@ -159,9 +166,8 @@ def path(self, path: Path) -> None:
159166
self.image_watcher.recent_image_changed.connect(self.handle_image_modified)
160167
self.image_watcher._handle_notified_of_directry_change(str(path))
161168

162-
def _handle_image_changed_in_list(self,
163-
image_files: list[Image_Data],
164-
dask_image_stack: DaskImageDataStack | None = None) -> None:
169+
def _handle_image_changed_in_list(self, image_files: list[Image_Data],
170+
dask_image_stack: DaskImageDataStack) -> None:
165171
"""
166172
Handle an image changed event. Update the image in the view.
167173
This method is called when the image_watcher detects a change
@@ -243,7 +249,7 @@ def find_images(self, directory: Path) -> list[Image_Data]:
243249
for file_path in directory.iterdir():
244250
if self._is_image_file(file_path.name):
245251
try:
246-
image_obj = Image_Data(file_path, create_delayed_array=self.create_delayed_array)
252+
image_obj = Image_Data(file_path)
247253
image_files.append(image_obj)
248254
except FileNotFoundError:
249255
continue
@@ -315,7 +321,7 @@ def _handle_directory_change(self) -> None:
315321
break
316322

317323
images = self.sort_images_by_modified_time(images)
318-
dask_image_stack = DaskImageDataStack(images)
324+
dask_image_stack = DaskImageDataStack(images, create_delayed_array=self.create_delayed_array)
319325
self.update_recent_watcher(images[-1:])
320326
self.image_changed.emit(images, dask_image_stack)
321327

mantidimaging/gui/windows/live_viewer/presenter.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections.abc import Callable
88
from logging import getLogger
99
import numpy as np
10+
import dask.array
1011

1112
from imagecodecs._deflate import DeflateError
1213

@@ -40,6 +41,8 @@ def __init__(self, view: LiveViewerWindowView, main_window: MainWindowView):
4041
self.main_window = main_window
4142
self.model = LiveViewerWindowModel(self)
4243
self.selected_image: Image_Data | None = None
44+
self.selected_delayed_image: dask.array.Array
45+
4346
self.filters = {f.filter_name: f for f in load_filter_packages()}
4447

4548
def close(self) -> None:
@@ -73,20 +76,21 @@ def update_image_list(self, images_list: list[Image_Data]) -> None:
7376
self.view.set_image_index(len(images_list) - 1)
7477

7578
def select_image(self, index: int) -> None:
76-
if not self.model.images:
79+
self.selected_image = self.model.image_stack.get_image_data(index)
80+
self.selected_delayed_image = self.model.image_stack.get_delayed_image(index)
81+
if not self.selected_image:
7782
return
78-
self.selected_image = self.model.images[index]
7983
image_timestamp = self.selected_image.image_modified_time_stamp
8084
self.view.label_active_filename.setText(f"{self.selected_image.image_name} - {image_timestamp}")
8185

82-
self.display_image(self.selected_image)
86+
self.display_image(self.selected_image, self.selected_delayed_image)
8387

84-
def display_image(self, image_data_obj: Image_Data) -> None:
88+
def display_image(self, image_data_obj: Image_Data, delayed_image: dask.array.Array) -> None:
8589
"""
8690
Display image in the view after validating contents
8791
"""
8892
try:
89-
image_data = self.load_image(image_data_obj)
93+
image_data = self.load_image(delayed_image)
9094
except (OSError, KeyError, ValueError, DeflateError) as error:
9195
message = f"{type(error).__name__} reading image: {image_data_obj.image_path}: {error}"
9296
logger.error(message)
@@ -104,28 +108,27 @@ def display_image(self, image_data_obj: Image_Data) -> None:
104108
self.view.live_viewer.show_error(None)
105109

106110
@staticmethod
107-
def load_image(image_data_obj: Image_Data) -> np.ndarray:
111+
def load_image(delayed_image: dask.array.Array) -> np.ndarray:
108112
"""
109113
Load a .Tif, .Tiff or .Fits file only if it exists
110114
and returns as an ndarray
111115
"""
112-
if image_data_obj.image_path.suffix.lower() in [".tif", ".tiff", ".fits"]:
113-
image_data = image_data_obj.delayed_array.compute()
116+
image_data = delayed_image.compute()
114117
return image_data
115118

116119
def update_image_modified(self, image_path: Path) -> None:
117120
"""
118121
Update the displayed image when the file is modified
119122
"""
120123
if self.selected_image and image_path == self.selected_image.image_path:
121-
self.display_image(self.selected_image)
124+
self.display_image(self.selected_image, self.selected_delayed_image)
122125

123126
def update_image_operation(self) -> None:
124127
"""
125128
Reload the current image if an operation has been performed on the current image
126129
"""
127130
if self.selected_image is not None:
128-
self.display_image(self.selected_image)
131+
self.display_image(self.selected_image, self.selected_delayed_image)
129132

130133
def convert_image_to_imagestack(self, image_data) -> ImageStack:
131134
"""

0 commit comments

Comments
 (0)