Skip to content

Commit b0243d4

Browse files
authored
Merge pull request #187 from eigenvivek/trilinear-renderer
Trilinear renderer
2 parents f707cc0 + 5a0d035 commit b0243d4

File tree

6 files changed

+709
-68
lines changed

6 files changed

+709
-68
lines changed

diffdrr/_modidx.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,20 @@
104104
'diffdrr.pose.so3_relative_angle': ('api/pose.html#so3_relative_angle', 'diffdrr/pose.py'),
105105
'diffdrr.pose.so3_rotation_angle': ('api/pose.html#so3_rotation_angle', 'diffdrr/pose.py'),
106106
'diffdrr.pose.standardize_quaternion': ('api/pose.html#standardize_quaternion', 'diffdrr/pose.py')},
107-
'diffdrr.siddon': { 'diffdrr.siddon._get_alpha_minmax': ('api/siddon.html#_get_alpha_minmax', 'diffdrr/siddon.py'),
108-
'diffdrr.siddon._get_alphas': ('api/siddon.html#_get_alphas', 'diffdrr/siddon.py'),
109-
'diffdrr.siddon._get_index': ('api/siddon.html#_get_index', 'diffdrr/siddon.py'),
110-
'diffdrr.siddon._get_voxel': ('api/siddon.html#_get_voxel', 'diffdrr/siddon.py'),
111-
'diffdrr.siddon.siddon_raycast': ('api/siddon.html#siddon_raycast', 'diffdrr/siddon.py')},
107+
'diffdrr.renderers': { 'diffdrr.renderers.Siddon': ('api/renderers.html#siddon', 'diffdrr/renderers.py'),
108+
'diffdrr.renderers.Siddon.__init__': ('api/renderers.html#siddon.__init__', 'diffdrr/renderers.py'),
109+
'diffdrr.renderers.Siddon.dims': ('api/renderers.html#siddon.dims', 'diffdrr/renderers.py'),
110+
'diffdrr.renderers.Siddon.forward': ('api/renderers.html#siddon.forward', 'diffdrr/renderers.py'),
111+
'diffdrr.renderers.Siddon.maxidx': ('api/renderers.html#siddon.maxidx', 'diffdrr/renderers.py'),
112+
'diffdrr.renderers.Trilinear': ('api/renderers.html#trilinear', 'diffdrr/renderers.py'),
113+
'diffdrr.renderers.Trilinear.__init__': ( 'api/renderers.html#trilinear.__init__',
114+
'diffdrr/renderers.py'),
115+
'diffdrr.renderers.Trilinear.dims': ('api/renderers.html#trilinear.dims', 'diffdrr/renderers.py'),
116+
'diffdrr.renderers.Trilinear.forward': ('api/renderers.html#trilinear.forward', 'diffdrr/renderers.py'),
117+
'diffdrr.renderers._get_alpha_minmax': ('api/renderers.html#_get_alpha_minmax', 'diffdrr/renderers.py'),
118+
'diffdrr.renderers._get_alphas': ('api/renderers.html#_get_alphas', 'diffdrr/renderers.py'),
119+
'diffdrr.renderers._get_index': ('api/renderers.html#_get_index', 'diffdrr/renderers.py'),
120+
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py')},
112121
'diffdrr.utils': { 'diffdrr.utils.get_focal_length': ('api/utils.html#get_focal_length', 'diffdrr/utils.py'),
113122
'diffdrr.utils.get_principal_point': ('api/utils.html#get_principal_point', 'diffdrr/utils.py'),
114123
'diffdrr.utils.parse_intrinsic_matrix': ('api/utils.html#parse_intrinsic_matrix', 'diffdrr/utils.py')},

diffdrr/drr.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from fastcore.basics import patch
1010

1111
from .detector import Detector
12-
from .siddon import siddon_raycast
12+
from .renderers import Siddon, Trilinear
1313

1414
# %% auto 0
1515
__all__ = ['DRR', 'Registration']
@@ -34,6 +34,8 @@ def __init__(
3434
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
3535
patch_size: int | None = None, # Render patches of the DRR in series
3636
bone_attenuation_multiplier: float = 1.0, # Contrast ratio of bone to soft tissue
37+
renderer: str = "siddon", # Rendering backend, either "siddon" or "trilinear"
38+
**renderer_kwargs, # Kwargs for the renderer
3739
):
3840
super().__init__()
3941

@@ -69,6 +71,14 @@ def __init__(
6971
self.bone = torch.where(350 < self.volume)
7072
self.bone_attenuation_multiplier = bone_attenuation_multiplier
7173

74+
# Initialize the renderer
75+
if renderer == "siddon":
76+
self.renderer = Siddon(**renderer_kwargs)
77+
elif renderer == "trilinear":
78+
self.renderer = Trilinear(**renderer_kwargs)
79+
else:
80+
raise ValueError(f"renderer must be 'siddon', not {renderer}")
81+
7282
def reshape_transform(self, img, batch_size):
7383
if self.reshape:
7484
if self.detector.n_subsample is None:
@@ -101,6 +111,7 @@ def forward(
101111
parameterization: str = None, # Specifies the representation of the rotation
102112
convention: str = None, # If parameterization is Euler angles, specify convention
103113
bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue
114+
**kwargs, # Passed to the renderer
104115
):
105116
"""Generate DRR with rotational and translational parameters."""
106117
if not hasattr(self, "density"):
@@ -119,11 +130,11 @@ def forward(
119130
img = []
120131
for idx in range(self.n_patches):
121132
t = target[:, idx * n_points : (idx + 1) * n_points]
122-
partial = siddon_raycast(source, t, self.density, self.spacing)
133+
partial = self.renderer(self.density, self.spacing, source, t, **kwargs)
123134
img.append(partial)
124135
img = torch.cat(img, dim=1)
125136
else:
126-
img = siddon_raycast(source, target, self.density, self.spacing)
137+
img = self.renderer(self.density, self.spacing, source, target, **kwargs)
127138
return self.reshape_transform(img, batch_size=len(pose))
128139

129140
# %% ../notebooks/api/00_drr.ipynb 11

diffdrr/siddon.py renamed to diffdrr/renderers.py

Lines changed: 83 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,44 @@
1-
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/01_siddon.ipynb.
1+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/01_renderers.ipynb.
22

33
# %% auto 0
4-
__all__ = ['siddon_raycast']
4+
__all__ = ['Siddon', 'Trilinear']
55

6-
# %% ../notebooks/api/01_siddon.ipynb 3
6+
# %% ../notebooks/api/01_renderers.ipynb 3
77
import torch
88

9-
# %% ../notebooks/api/01_siddon.ipynb 6
10-
def siddon_raycast(
11-
source: torch.Tensor,
12-
target: torch.Tensor,
13-
volume: torch.Tensor,
14-
spacing: torch.Tensor,
15-
eps: float = 1e-8,
16-
):
17-
"""An auto-differentiable implementation of the raycasting algorithm known as Siddon's method."""
18-
maxidx = volume.numel() - 1
19-
dims = torch.tensor(volume.shape).to(source) + 1
20-
alphas = _get_alphas(source, target, spacing, dims, eps)
21-
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
22-
voxels = _get_voxel(alphamid, source, target, volume, spacing, dims, maxidx, eps)
23-
24-
# Step length for alphas out of range will be nan
25-
# These nans cancel out voxels convereted to 0 index
26-
step_length = torch.diff(alphas, dim=-1)
27-
weighted_voxels = voxels * step_length
28-
29-
drr = torch.nansum(weighted_voxels, dim=-1)
30-
raylength = (target - source + eps).norm(dim=-1)
31-
drr *= raylength
32-
return drr
33-
34-
# %% ../notebooks/api/01_siddon.ipynb 8
9+
# %% ../notebooks/api/01_renderers.ipynb 6
10+
class Siddon(torch.nn.Module):
11+
def __init__(self, eps=1e-8):
12+
super().__init__()
13+
self.eps = eps
14+
15+
def dims(self, volume):
16+
return torch.tensor(volume.shape).to(volume) + 1
17+
18+
def maxidx(self, volume):
19+
return volume.numel() - 1
20+
21+
def forward(self, volume, spacing, source, target):
22+
dims = self.dims(volume)
23+
maxidx = self.maxidx(volume)
24+
25+
alphas = _get_alphas(source, target, spacing, dims, self.eps)
26+
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
27+
voxels = _get_voxel(
28+
alphamid, source, target, volume, spacing, dims, maxidx, self.eps
29+
)
30+
31+
# Step length for alphas out of range will be nan
32+
# These nans cancel out voxels convereted to 0 index
33+
step_length = torch.diff(alphas, dim=-1)
34+
weighted_voxels = voxels * step_length
35+
36+
drr = torch.nansum(weighted_voxels, dim=-1)
37+
raylength = (target - source + self.eps).norm(dim=-1)
38+
drr *= raylength
39+
return drr
40+
41+
# %% ../notebooks/api/01_renderers.ipynb 8
3542
def _get_alphas(source, target, spacing, dims, eps):
3643
# Get the CT sizing and spacing parameters
3744
dx, dy, dz = spacing
@@ -100,3 +107,50 @@ def _get_index(alpha, source, target, spacing, dims, maxidx, eps):
100107
idxs[idxs < 0] = 0
101108
idxs[idxs > maxidx] = maxidx
102109
return idxs
110+
111+
# %% ../notebooks/api/01_renderers.ipynb 10
112+
from torch.nn.functional import grid_sample
113+
114+
115+
class Trilinear(torch.nn.Module):
116+
def __init__(
117+
self,
118+
near=0.0,
119+
far=1.0,
120+
eps=1e-8,
121+
mode="bilinear",
122+
):
123+
super().__init__()
124+
self.near = near
125+
self.far = far
126+
self.eps = eps
127+
self.mode = mode
128+
129+
def dims(self, volume):
130+
return torch.tensor(volume.shape).to(volume) + 1
131+
132+
def forward(
133+
self, volume, spacing, source, target, n_points=100, align_corners=True
134+
):
135+
# Reorder array to match torch conventions
136+
volume = volume.permute(2, 1, 0)
137+
spacing = spacing[[2, 1, 0]]
138+
139+
# Get the raylength and reshape sources
140+
raylength = (source - target + self.eps).norm(dim=-1)
141+
source = source[:, None, :, None, :]
142+
target = target[:, None, :, None, :]
143+
144+
# Sample points along the rays and rescale to [-1, 1]
145+
alphas = torch.linspace(self.near, self.far, n_points).to(volume)
146+
alphas = alphas[None, None, None, :, None]
147+
rays = source + alphas * (target - source)
148+
rays = 2 * rays / (spacing * self.dims(volume)) - 1
149+
150+
# Render the DRR
151+
batch_size = len(rays)
152+
vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)
153+
drr = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)
154+
drr = drr[:, 0, 0].sum(dim=-1)
155+
drr *= raylength
156+
return drr

notebooks/api/00_drr.ipynb

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"from fastcore.basics import patch\n",
5050
"\n",
5151
"from diffdrr.detector import Detector\n",
52-
"from diffdrr.siddon import siddon_raycast"
52+
"from diffdrr.renderers import Siddon, Trilinear"
5353
]
5454
},
5555
{
@@ -129,6 +129,8 @@
129129
" reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n",
130130
" patch_size: int | None = None, # Render patches of the DRR in series\n",
131131
" bone_attenuation_multiplier: float = 1.0, # Contrast ratio of bone to soft tissue\n",
132+
" renderer: str = \"siddon\", # Rendering backend, either \"siddon\" or \"trilinear\"\n",
133+
" **renderer_kwargs, # Kwargs for the renderer\n",
132134
" ):\n",
133135
" super().__init__()\n",
134136
"\n",
@@ -164,6 +166,14 @@
164166
" self.bone = torch.where(350 < self.volume)\n",
165167
" self.bone_attenuation_multiplier = bone_attenuation_multiplier\n",
166168
"\n",
169+
" # Initialize the renderer\n",
170+
" if renderer == \"siddon\":\n",
171+
" self.renderer = Siddon(**renderer_kwargs)\n",
172+
" elif renderer == \"trilinear\":\n",
173+
" self.renderer = Trilinear(**renderer_kwargs)\n",
174+
" else:\n",
175+
" raise ValueError(f\"renderer must be 'siddon', not {renderer}\")\n",
176+
"\n",
167177
" def reshape_transform(self, img, batch_size):\n",
168178
" if self.reshape:\n",
169179
" if self.detector.n_subsample is None:\n",
@@ -220,6 +230,7 @@
220230
" parameterization: str = None, # Specifies the representation of the rotation\n",
221231
" convention: str = None, # If parameterization is Euler angles, specify convention\n",
222232
" bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue\n",
233+
" **kwargs, # Passed to the renderer\n",
223234
"):\n",
224235
" \"\"\"Generate DRR with rotational and translational parameters.\"\"\"\n",
225236
" if not hasattr(self, \"density\"):\n",
@@ -238,11 +249,11 @@
238249
" img = []\n",
239250
" for idx in range(self.n_patches):\n",
240251
" t = target[:, idx * n_points : (idx + 1) * n_points]\n",
241-
" partial = siddon_raycast(source, t, self.density, self.spacing)\n",
252+
" partial = self.renderer(self.density, self.spacing, source, t, **kwargs)\n",
242253
" img.append(partial)\n",
243254
" img = torch.cat(img, dim=1)\n",
244255
" else:\n",
245-
" img = siddon_raycast(source, target, self.density, self.spacing)\n",
256+
" img = self.renderer(self.density, self.spacing, source, target, **kwargs)\n",
246257
" return self.reshape_transform(img, batch_size=len(pose))"
247258
]
248259
},

0 commit comments

Comments
 (0)