|
4 | 4 | import unittest
|
5 | 5 | from unittest import mock
|
6 | 6 | import numpy as np
|
| 7 | +import numpy.testing as npt |
7 | 8 |
|
8 | 9 | from mantidimaging.test_helpers.start_qapplication import start_multiprocessing_pool
|
9 |
| -from mantidimaging.test_helpers.unit_test_helper import generate_images, assert_not_equals |
| 10 | +from mantidimaging.test_helpers.unit_test_helper import generate_images |
10 | 11 | from ..polyfit_correlation import do_calculate_correlation_err, get_search_range, find_center, _find_shift
|
11 | 12 | from ...data import ImageStack
|
12 | 13 | from ...utility.progress_reporting import Progress
|
@@ -36,33 +37,46 @@ def test_do_search(self):
|
36 | 37 | def test_find_center(self):
|
37 | 38 | images = generate_images((10, 10, 10))
|
38 | 39 | images.data[0] = np.identity(10)
|
39 |
| - images.proj180deg = ImageStack(np.fliplr(images.data)) |
| 40 | + images.proj180deg = ImageStack(np.fliplr(images.data[0:1])) |
40 | 41 | mock_progress = mock.create_autospec(Progress)
|
41 | 42 | res_cor, res_tilt = find_center(images, mock_progress)
|
42 | 43 | assert mock_progress.update.call_count == 11
|
43 | 44 | assert res_cor.value == 5.0, f"Found {res_cor.value}"
|
44 | 45 | assert res_tilt.value == 0.0, f"Found {res_tilt.value}"
|
45 | 46 |
|
46 |
| - def test_find_shift(self): |
47 |
| - rng = np.random.default_rng() |
| 47 | + def test_find_center_offset(self): |
48 | 48 | images = generate_images((10, 10, 10))
|
49 |
| - search_range = get_search_range(images.width) |
50 |
| - min_correlation_error = rng.random((len(search_range), images.height)) |
51 |
| - shift = np.zeros(images.height) |
| 49 | + images.data[0] = np.identity(10) |
| 50 | + images.proj180deg = ImageStack(np.fliplr(images.data[0:1])) |
| 51 | + self.crop_images(images, (2, 10, 0, 10)) |
| 52 | + self.crop_images(images.proj180deg, (2, 10, 0, 10)) |
| 53 | + mock_progress = mock.create_autospec(Progress) |
| 54 | + res_cor, res_tilt = find_center(images, mock_progress) |
| 55 | + assert res_cor.value == 4.0, f"Found {res_cor.value}" |
| 56 | + assert abs(res_tilt.value) < 1e-6, f"Found {res_tilt.value}" |
| 57 | + |
| 58 | + def test_find_shift(self): |
| 59 | + images = mock.Mock(height=3) |
| 60 | + min_correlation_error = np.array([[1, 2, 2, 2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3, 3, 2, 3], |
| 61 | + [4, 4, 4, 4, 3, 4, 4, 4, 4, 4]]).T |
| 62 | + search_range = get_search_range(10) |
| 63 | + shift = np.zeros(3) |
52 | 64 | _find_shift(images, search_range, min_correlation_error, shift)
|
53 |
| - # check that the shift has been changed |
54 |
| - assert_not_equals(shift, np.zeros((images.height, ))) |
| 65 | + npt.assert_array_equal(np.array([-5, 3, -1]), shift) |
55 | 66 |
|
56 | 67 | def test_find_shift_multiple_argmin(self):
|
57 |
| - rng = np.random.default_rng() |
58 |
| - images = generate_images((10, 10, 10)) |
59 |
| - search_range = get_search_range(images.width) |
60 |
| - min_correlation_error = rng.random((len(search_range), images.height)) |
61 |
| - min_correlation_error.T[0][3] = min_correlation_error.T[0][4] = 0 |
62 |
| - shift = np.zeros((images.height, )) |
| 68 | + images = mock.Mock(height=3) |
| 69 | + min_correlation_error = np.array([[1, 2, 2, 2, 2, 2, 2, 2, 2, 1], [3, 3, 3, 3, 3, 3, 3, 3, 2, 2], |
| 70 | + [4, 4, 4, 4, 3, 3, 4, 4, 4, 4]]).T |
| 71 | + search_range = get_search_range(10) |
| 72 | + shift = np.zeros(3) |
63 | 73 | _find_shift(images, search_range, min_correlation_error, shift)
|
64 |
| - # check that the shift has been changed |
65 |
| - assert_not_equals(shift, np.zeros((images.height, ))) |
| 74 | + npt.assert_array_equal(np.array([-5, 3, -1]), shift) |
| 75 | + |
| 76 | + @staticmethod |
| 77 | + def crop_images(images, crop_coords): |
| 78 | + x_start, x_end, y_start, y_end = crop_coords |
| 79 | + images.data = images.data[:, y_start:y_end, x_start:x_end] |
66 | 80 |
|
67 | 81 |
|
68 | 82 | if __name__ == '__main__':
|
|
0 commit comments