Skip to content

Commit 8f8c0cf

Browse files
author
Holger Kohr
committed
ENH: vectorize geometry methods
1 parent 1a5f3c3 commit 8f8c0cf

File tree

5 files changed

+179
-121
lines changed

5 files changed

+179
-121
lines changed

odl/tomo/backends/astra_setup.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -248,19 +248,18 @@ def astra_conebeam_3d_geom_to_vec(geometry):
248248
angles = geometry.angles
249249
vectors = np.zeros((angles.size, 12))
250250

251-
for ang_idx, angle in enumerate(angles):
252-
# Source position
253-
vectors[ang_idx, 0:3] = geometry.src_position(angle)
251+
# Source position
252+
vectors[:, 0:3] = geometry.src_position(angles)
254253

255-
# Center of detector in 3D space
256-
mid_pt = geometry.det_params.mid_pt
257-
vectors[ang_idx, 3:6] = geometry.det_point_position(angle, mid_pt)
254+
# Center of detector in 3D space
255+
mid_pt = geometry.det_params.mid_pt
256+
vectors[:, 3:6] = geometry.det_point_position(angles, mid_pt)
258257

259-
# Vectors from detector pixel (0, 0) to (1, 0) and (0, 0) to (0, 1)
260-
det_axes = geometry.det_axes(angle)
261-
px_sizes = geometry.det_partition.cell_sides
262-
vectors[ang_idx, 6:9] = det_axes[0] * px_sizes[0]
263-
vectors[ang_idx, 9:12] = det_axes[1] * px_sizes[1]
258+
# Vectors from detector pixel (0, 0) to (1, 0) and (0, 0) to (0, 1)
259+
det_axes = geometry.det_axes(angles)
260+
px_sizes = geometry.det_partition.cell_sides
261+
vectors[:, 6:9] = det_axes[0] * px_sizes[0]
262+
vectors[:, 9:12] = det_axes[1] * px_sizes[1]
264263

265264
# ASTRA has (z, y, x) axis convention, in contrast to (x, y, z) in ODL,
266265
# so we need to adapt to this by changing the order.
@@ -312,19 +311,19 @@ def astra_conebeam_2d_geom_to_vec(geometry):
312311
angles = geometry.angles
313312
vectors = np.zeros((angles.size, 6))
314313

315-
for ang_idx, angle in enumerate(angles):
316-
# Source position
317-
vectors[ang_idx, 0:2] = rot_minus_90.dot(geometry.src_position(angle))
314+
# Source position
315+
src_pos = geometry.src_position(angles)
316+
vectors[:, 0:2] = rot_minus_90.dot(src_pos.T).T # dot along 2nd axis
318317

319-
# Center of detector
320-
mid_pt = geometry.det_params.mid_pt
321-
vectors[ang_idx, 2:4] = rot_minus_90.dot(
322-
geometry.det_point_position(angle, mid_pt))
318+
# Center of detector
319+
mid_pt = geometry.det_params.mid_pt
320+
centers = geometry.det_point_position(angles, mid_pt)
321+
vectors[:, 2:4] = rot_minus_90.dot(centers.T).T
323322

324-
# Vector from detector pixel 0 to 1
325-
det_axis = rot_minus_90.dot(geometry.det_axis(angle))
326-
px_size = geometry.det_partition.cell_sides[0]
327-
vectors[ang_idx, 4:6] = det_axis * px_size
323+
# Vector from detector pixel 0 to 1
324+
det_axis = rot_minus_90.dot(geometry.det_axis(angles).T).T
325+
px_size = geometry.det_partition.cell_sides[0]
326+
vectors[:, 4:6] = det_axis * px_size
328327

329328
return vectors
330329

@@ -366,20 +365,19 @@ def astra_parallel_3d_geom_to_vec(geometry):
366365
angles = geometry.angles
367366
vectors = np.zeros((angles.shape[0], 12))
368367

369-
for ang_idx, angle in enumerate(angles):
370-
mid_pt = geometry.det_params.mid_pt
368+
mid_pt = geometry.det_params.mid_pt
371369

372-
# Ray direction = -(detector-to-source normal vector)
373-
vectors[ang_idx, 0:3] = -geometry.det_to_src(angle, mid_pt)
370+
# Ray direction = -(detector-to-source normal vector)
371+
vectors[:, 0:3] = -geometry.det_to_src(angles, mid_pt)
374372

375-
# Center of the detector in 3D space
376-
vectors[ang_idx, 3:6] = geometry.det_point_position(angle, mid_pt)
373+
# Center of the detector in 3D space
374+
vectors[:, 3:6] = geometry.det_point_position(angles, mid_pt)
377375

378-
# Vectors from detector pixel (0, 0) to (1, 0) and (0, 0) to (0, 1)
379-
det_axes = geometry.det_axes(angle)
380-
px_sizes = geometry.det_partition.cell_sides
381-
vectors[ang_idx, 6:9] = det_axes[0] * px_sizes[0]
382-
vectors[ang_idx, 9:12] = det_axes[1] * px_sizes[1]
376+
# Vectors from detector pixel (0, 0) to (1, 0) and (0, 0) to (0, 1)
377+
det_axes = geometry.det_axes(angles)
378+
px_sizes = geometry.det_partition.cell_sides
379+
vectors[:, 6:9] = det_axes[0] * px_sizes[0]
380+
vectors[:, 9:12] = det_axes[1] * px_sizes[1]
383381

384382
# ASTRA has (z, y, x) axis convention, in contrast to (x, y, z) in ODL,
385383
# so we need to adapt to this by changing the order.

odl/tomo/geometry/conebeam.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def __init__(self, apart, dpart, src_radius, det_radius,
183183
detector = Flat1dDetector(dpart, det_axis_init)
184184
translation = kwargs.pop('translation', None)
185185
super().__init__(ndim=2, motion_part=apart, detector=detector,
186-
translation=translation)
186+
translation=translation, **kwargs)
187187

188188
self.__src_radius = float(src_radius)
189189
if self.src_radius < 0:
@@ -202,11 +202,6 @@ def __init__(self, apart, dpart, src_radius, det_radius,
202202
raise ValueError('`apart` has dimension {}, expected 1'
203203
''.format(self.motion_partition.ndim))
204204

205-
# Make sure there are no leftover kwargs
206-
if kwargs:
207-
raise TypeError('got unexpected keyword arguments {}'
208-
''.format(kwargs))
209-
210205
@classmethod
211206
def frommatrix(cls, apart, dpart, src_radius, det_radius, init_matrix):
212207
"""Create an instance of `FanFlatGeometry` using a matrix.
@@ -357,16 +352,19 @@ def src_position(self, angle):
357352
>>> np.allclose(geom.src_position(np.pi / 2), [2, 0])
358353
True
359354
"""
360-
if angle not in self.motion_params:
361-
raise ValueError('`angle` {} is not in the valid range {}'
355+
if self.check_bounds and not self.motion_params.contains_all(angle):
356+
raise ValueError('`angle` {} not in the valid range {}'
362357
''.format(angle, self.motion_params))
363358

359+
angle = np.array(angle, dtype=float, copy=False, ndmin=1)
360+
364361
# Initial vector from the rotation center to the source. It can be
365362
# computed this way since source and detector are at maximum distance,
366363
# i.e. the connecting line passes the origin.
367364
center_to_src_init = -self.src_radius * self.src_to_det_init
368-
return (self.translation +
369-
self.rotation_matrix(angle).dot(center_to_src_init))
365+
pos_vec = (self.translation[None, :] +
366+
self.rotation_matrix(angle).dot(center_to_src_init))
367+
return pos_vec.squeeze()
370368

371369
def det_refpoint(self, angle):
372370
"""Return the detector reference point position at ``angle``.
@@ -407,16 +405,18 @@ def det_refpoint(self, angle):
407405
>>> np.allclose(geom.det_refpoint(np.pi / 2), [-5, 0])
408406
True
409407
"""
410-
if angle not in self.motion_params:
411-
raise ValueError('`angle` {} is not in the valid range {}'
408+
if self.check_bounds and not self.motion_params.contains_all(angle):
409+
raise ValueError('`angle` {} not in the valid range {}'
412410
''.format(angle, self.motion_params))
413411

412+
angle = np.array(angle, dtype=float, copy=False, ndmin=1)
414413
# Initial vector from the rotation center to the detector. It can be
415414
# computed this way since source and detector are at maximum distance,
416415
# i.e. the connecting line passes the origin.
417416
center_to_det_init = self.det_radius * self.src_to_det_init
418-
return (self.translation +
419-
self.rotation_matrix(angle).dot(center_to_det_init))
417+
refpt = (self.translation[None, :] +
418+
self.rotation_matrix(angle).dot(center_to_det_init))
419+
return refpt.squeeze()
420420

421421
def rotation_matrix(self, angle):
422422
"""Return the rotation matrix for ``angle``.
@@ -440,10 +440,10 @@ def rotation_matrix(self, angle):
440440
the local coordinate system of the detector reference point,
441441
expressed in the fixed system.
442442
"""
443-
angle = float(angle)
444-
if angle not in self.motion_params:
443+
if self.check_bounds and not self.motion_params.contains_all(angle):
445444
raise ValueError('`angle` {} not in the valid range {}'
446445
''.format(angle, self.motion_params))
446+
447447
return euler_matrix(angle)
448448

449449
def __repr__(self):
@@ -834,9 +834,9 @@ def angles(self):
834834
"""Discrete angles given in this geometry."""
835835
return self.motion_grid.coord_vectors[0]
836836

837-
def det_axes(self, angles):
837+
def det_axes(self, angle):
838838
"""Return the detector axes tuple at ``angle``."""
839-
return tuple(self.rotation_matrix(angles).dot(axis)
839+
return tuple(self.rotation_matrix(angle).dot(axis)
840840
for axis in self.det_axes_init)
841841

842842
def det_refpoint(self, angle):
@@ -881,22 +881,25 @@ def det_refpoint(self, angle):
881881
>>> np.allclose(geom.det_refpoint(np.pi / 2), [-10, 0, 0.5])
882882
True
883883
"""
884-
angle = float(angle)
885-
if angle not in self.motion_params:
886-
raise ValueError('`angle` {} is not in the valid range {}'
884+
if self.check_bounds and not self.motion_params.contains_all(angle):
885+
raise ValueError('`angle` {} not in the valid range {}'
887886
''.format(angle, self.motion_params))
888887

888+
angle = np.array(angle, dtype=float, copy=False, ndmin=1)
889+
889890
# Initial vector from center of rotation to detector.
890891
# It can be computed this way since source and detector are at
891892
# maximum distance, i.e. the connecting line passes the origin.
892893
center_to_det_init = self.det_radius * self.src_to_det_init
893894
circle_component = self.rotation_matrix(angle).dot(center_to_det_init)
894895

895896
# Increment along the rotation axis according to pitch and pitch_offset
896-
pitch_component = self.axis * (self.pitch_offset +
897-
self.pitch * angle / (2 * np.pi))
897+
shift_along_axis = (self.pitch_offset +
898+
self.pitch * angle / (2 * np.pi))
899+
pitch_component = self.axis[None, :] * shift_along_axis[:, None]
898900

899-
return self.translation + circle_component + pitch_component
901+
refpt = self.translation[None, :] + circle_component + pitch_component
902+
return refpt.squeeze()
900903

901904
def src_position(self, angle):
902905
"""Return the source position at ``angle``.
@@ -940,22 +943,25 @@ def src_position(self, angle):
940943
>>> np.allclose(geom.src_position(np.pi / 2), [5, 0, 0.5])
941944
True
942945
"""
943-
angle = float(angle)
944-
if angle not in self.motion_params:
945-
raise ValueError('`angle` {} is not in the valid range {}'
946+
if self.check_bounds and not self.motion_params.contains_all(angle):
947+
raise ValueError('`angle` {} not in the valid range {}'
946948
''.format(angle, self.motion_params))
947949

950+
angle = np.array(angle, dtype=float, copy=False, ndmin=1)
951+
948952
# Initial vector from 0 to the source (non-translated).
949953
# It can be computed this way since source and detector are at
950954
# maximum distance, i.e. the connecting line passes the origin.
951955
origin_to_src_init = -self.src_radius * self.src_to_det_init
952956
circle_component = self.rotation_matrix(angle).dot(origin_to_src_init)
953957

954-
# Increment by pitch
955-
pitch_component = self.axis * (self.pitch_offset +
956-
self.pitch * angle / (np.pi * 2))
958+
# Increment along the rotation axis according to pitch and pitch_offset
959+
shift_along_axis = (self.pitch_offset +
960+
self.pitch * angle / (2 * np.pi))
961+
pitch_component = self.axis[None, :] * shift_along_axis[:, None]
957962

958-
return self.translation + circle_component + pitch_component
963+
refpt = self.translation[None, :] + circle_component + pitch_component
964+
return refpt.squeeze()
959965

960966
def __repr__(self):
961967
"""Return ``repr(self)``."""

odl/tomo/geometry/geometry.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,26 @@ class Geometry(object):
4545
<https://odlgroup.github.io/odl/guide/geometry_guide.html>`_.
4646
"""
4747

48-
def __init__(self, ndim, motion_part, detector, translation=None):
48+
def __init__(self, ndim, motion_part, detector, translation=None,
49+
**kwargs):
4950
"""Initialize a new instance.
5051
5152
Parameters
5253
----------
5354
ndim : positive int
5455
Number of dimensions of this geometry, i.e. dimensionality
55-
of the physical space in which this geometry is embedded
56+
of the physical space in which this geometry is embedded.
5657
motion_part : `RectPartition`
57-
Partition for the set of "motion" parameters
58+
Partition for the set of "motion" parameters.
5859
detector : `Detector`
59-
The detector of this geometry
60+
The detector of this geometry.
61+
translation : `array-like`, optional
62+
Global translation of the geometry. This is added last in any
63+
method that computes an absolute vector, e.g., `det_refpoint`.
64+
check_bounds : bool, optional
65+
If ``True``, check if provided parameters for query functions
66+
like `det_refpoint` are in the valid range.
67+
Default: ``True``
6068
"""
6169
ndim, ndim_in = int(ndim), ndim
6270
if ndim != ndim_in or ndim <= 0:
@@ -72,6 +80,7 @@ def __init__(self, ndim, motion_part, detector, translation=None):
7280
self.__ndim = ndim
7381
self.__motion_partition = motion_part
7482
self.__detector = detector
83+
self.__check_bounds = bool(kwargs.pop('check_bounds', True))
7584

7685
if translation is None:
7786
self.__translation = np.zeros(self.ndim)
@@ -84,6 +93,11 @@ def __init__(self, ndim, motion_part, detector, translation=None):
8493

8594
self.__implementation_cache = {}
8695

96+
# Make sure there are no leftover kwargs
97+
if kwargs:
98+
raise TypeError('got unexpected keyword arguments {}'
99+
''.format(kwargs))
100+
87101
@property
88102
def ndim(self):
89103
"""Number of dimensions of the geometry."""
@@ -154,6 +168,11 @@ def translation(self):
154168
"""Shift of the origin of this geometry."""
155169
return self.__translation
156170

171+
@property
172+
def check_bounds(self):
173+
"""Whether to check if method parameters are in the valid range."""
174+
return self.__check_bounds
175+
157176
def det_refpoint(self, mpar):
158177
"""Detector reference point function.
159178
@@ -271,7 +290,7 @@ def src_position(self, mpar):
271290
"""
272291
raise NotImplementedError('abstract method')
273292

274-
def det_to_src(self, mpar, dpar, normalized=True):
293+
def det_to_src(self, mparams, dparams, normalized=True):
275294
"""Vector pointing from a detector location to the source.
276295
277296
A function of the motion and detector parameters.
@@ -294,20 +313,22 @@ def det_to_src(self, mpar, dpar, normalized=True):
294313
vec : `numpy.ndarray`, shape (`ndim`,)
295314
(Unit) vector pointing from the detector to the source.
296315
"""
297-
if mpar not in self.motion_params:
298-
raise ValueError('`mpar` {} not in the valid range {}'
299-
''.format(mpar, self.motion_params))
300-
if dpar not in self.det_params:
301-
raise ValueError('`dpar` {} not in the valid range {}'
302-
''.format(dpar, self.det_params))
316+
if self.check_bounds:
317+
if not self.motion_params.contains_all(mparams):
318+
raise ValueError('`mparams` {} not in the valid range {}'
319+
''.format(mparams, self.motion_params))
320+
if not self.det_params.contains_all(dparams):
321+
raise ValueError('`dparams` {} not in the valid range {}'
322+
''.format(dparams, self.det_params))
303323

304-
vec = self.src_position(mpar) - self.det_point_position(mpar, dpar)
324+
det_to_src_vec = (self.src_position(mparams) -
325+
self.det_point_position(mparams, dparams))
305326

306327
if normalized:
307328
# axis = -1 allows this to be vectorized
308-
vec /= np.linalg.norm(vec, axis=-1)
329+
det_to_src_vec /= np.linalg.norm(det_to_src_vec, axis=-1)
309330

310-
return vec
331+
return det_to_src_vec
311332

312333

313334
class AxisOrientedGeometry(object):
@@ -358,9 +379,8 @@ def rotation_matrix(self, angle):
358379
the local coordinate system of the detector reference point,
359380
expressed in the fixed system.
360381
"""
361-
angle = float(angle)
362-
if angle not in self.motion_params:
363-
raise ValueError('`angle` {} is not in the valid range {}'
382+
if self.check_bounds and not self.motion_params.contains_all(angle):
383+
raise ValueError('`angle` {} not in the valid range {}'
364384
''.format(angle, self.motion_params))
365385

366386
return axis_rotation_matrix(self.axis, angle)

0 commit comments

Comments
 (0)