Skip to content

Commit 5e6fadf

Browse files
Enhancing Test Coverage and Reliability for _find_shift Function (#2286)
2 parents ed6cdc8 + 0bc727c commit 5e6fadf

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

mantidimaging/core/rotation/test/polyfit_correlation_test.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import unittest
55
from unittest import mock
66
import numpy as np
7+
import numpy.testing as npt
78

89
from mantidimaging.test_helpers.start_qapplication import start_multiprocessing_pool
9-
from mantidimaging.test_helpers.unit_test_helper import generate_images, assert_not_equals
10+
from mantidimaging.test_helpers.unit_test_helper import generate_images
1011
from ..polyfit_correlation import do_calculate_correlation_err, get_search_range, find_center, _find_shift
1112
from ...data import ImageStack
1213
from ...utility.progress_reporting import Progress
@@ -36,33 +37,46 @@ def test_do_search(self):
3637
def test_find_center(self):
3738
images = generate_images((10, 10, 10))
3839
images.data[0] = np.identity(10)
39-
images.proj180deg = ImageStack(np.fliplr(images.data))
40+
images.proj180deg = ImageStack(np.fliplr(images.data[0:1]))
4041
mock_progress = mock.create_autospec(Progress)
4142
res_cor, res_tilt = find_center(images, mock_progress)
4243
assert mock_progress.update.call_count == 11
4344
assert res_cor.value == 5.0, f"Found {res_cor.value}"
4445
assert res_tilt.value == 0.0, f"Found {res_tilt.value}"
4546

46-
def test_find_shift(self):
47-
rng = np.random.default_rng()
47+
def test_find_center_offset(self):
4848
images = generate_images((10, 10, 10))
49-
search_range = get_search_range(images.width)
50-
min_correlation_error = rng.random((len(search_range), images.height))
51-
shift = np.zeros(images.height)
49+
images.data[0] = np.identity(10)
50+
images.proj180deg = ImageStack(np.fliplr(images.data[0:1]))
51+
self.crop_images(images, (2, 10, 0, 10))
52+
self.crop_images(images.proj180deg, (2, 10, 0, 10))
53+
mock_progress = mock.create_autospec(Progress)
54+
res_cor, res_tilt = find_center(images, mock_progress)
55+
assert res_cor.value == 4.0, f"Found {res_cor.value}"
56+
assert abs(res_tilt.value) < 1e-6, f"Found {res_tilt.value}"
57+
58+
def test_find_shift(self):
59+
images = mock.Mock(height=3)
60+
min_correlation_error = np.array([[1, 2, 2, 2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3, 3, 2, 3],
61+
[4, 4, 4, 4, 3, 4, 4, 4, 4, 4]]).T
62+
search_range = get_search_range(10)
63+
shift = np.zeros(3)
5264
_find_shift(images, search_range, min_correlation_error, shift)
53-
# check that the shift has been changed
54-
assert_not_equals(shift, np.zeros((images.height, )))
65+
npt.assert_array_equal(np.array([-5, 3, -1]), shift)
5566

5667
def test_find_shift_multiple_argmin(self):
57-
rng = np.random.default_rng()
58-
images = generate_images((10, 10, 10))
59-
search_range = get_search_range(images.width)
60-
min_correlation_error = rng.random((len(search_range), images.height))
61-
min_correlation_error.T[0][3] = min_correlation_error.T[0][4] = 0
62-
shift = np.zeros((images.height, ))
68+
images = mock.Mock(height=3)
69+
min_correlation_error = np.array([[1, 2, 2, 2, 2, 2, 2, 2, 2, 1], [3, 3, 3, 3, 3, 3, 3, 3, 2, 2],
70+
[4, 4, 4, 4, 3, 3, 4, 4, 4, 4]]).T
71+
search_range = get_search_range(10)
72+
shift = np.zeros(3)
6373
_find_shift(images, search_range, min_correlation_error, shift)
64-
# check that the shift has been changed
65-
assert_not_equals(shift, np.zeros((images.height, )))
74+
npt.assert_array_equal(np.array([-5, 3, -1]), shift)
75+
76+
@staticmethod
77+
def crop_images(images, crop_coords):
78+
x_start, x_end, y_start, y_end = crop_coords
79+
images.data = images.data[:, y_start:y_end, x_start:x_end]
6680

6781

6882
if __name__ == '__main__':

0 commit comments

Comments
 (0)