Skip to content

Commit 4f0d4ff

Browse files
authored
Merge pull request #139 from eigenvivek/diffrend
Differentiable rendering and large DRRs
2 parents a4fa970 + d1a39d6 commit 4f0d4ff

File tree

16 files changed

+421
-346
lines changed

16 files changed

+421
-346
lines changed

README.md

Lines changed: 32 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,39 @@
11
DiffDRR
22
================
33

4-
<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
5-
64
> Auto-differentiable DRR synthesis and optimization in PyTorch
75
86
[![CI](https://github.com/eigenvivek/DiffDRR/actions/workflows/test.yaml/badge.svg)](https://github.com/eigenvivek/DiffDRR/actions/workflows/test.yaml)
9-
[![Paper
10-
shield](https://img.shields.io/badge/arXiv-2208.12737-red.svg)](https://arxiv.org/abs/2208.12737)
11-
[![License:
12-
MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
13-
[![Downloads](https://static.pepy.tech/personalized-badge/diffdrr?period=month&units=international_system&left_color=grey&right_color=blue&left_text=downloads.month)](https://pepy.tech/project/diffdrr)
7+
[![Paper shield](https://img.shields.io/badge/arXiv-2208.12737-red.svg)](https://arxiv.org/abs/2208.12737)
8+
[![License: MIT](https://img.shields.io/badge/License-Apache2.0-blue.svg)](LICENSE)
9+
[![Downloads](https://static.pepy.tech/personalized-badge/diffdrr?period=total&units=none&left_color=grey&right_color=blue&left_text=downloads)](https://pepy.tech/project/diffdrr)
1410
[![Docs](https://img.shields.io/badge/docs-passing-brightgreen.svg)](https://vivekg.dev/DiffDRR/)
15-
[![Code style:
16-
black](https://img.shields.io/badge/Code%20style-black-black.svg)](https://github.com/psf/black)
11+
[![Code style: black](https://img.shields.io/badge/Code%20style-black-black.svg)](https://github.com/psf/black)
1712

18-
`DiffDRR` is a PyTorch-based digitally reconstructed radiograph (DRR)
19-
generator that provides
13+
`DiffDRR` is a PyTorch-based digitally reconstructed radiograph (DRR) generator that provides
2014

21-
1. Auto-differentiable DRR syntheisis
22-
2. GPU-accelerated rendering
23-
3. A pure Python implementation
15+
1. Auto-differentiable DRR syntheisis
16+
2. GPU-accelerated rendering
17+
3. A pure Python implementation
2418

25-
Most importantly, `DiffDRR` implements DRR synthesis as a PyTorch
26-
module, making it interoperable in deep learning pipelines.
19+
Most importantly, `DiffDRR` implements DRR synthesis as a PyTorch module, making it interoperable in deep learning pipelines.
2720

2821
- [Installation Guide](#installation-guide)
2922
- [Usage](#usage)
30-
- [Example: Rigid 2D-to-3D
31-
registration](#application-6-dof-slice-to-volume-registration)
23+
- [Example: Rigid 2D-to-3D registration](#application-6-dof-slice-to-volume-registration)
3224
- [How does `DiffDRR` work?](#how-does-diffdrr-work)
3325
- [Citing `DiffDRR`](#citing-diffdrr)
3426

3527
## Installation Guide
3628

3729
To install `DiffDRR` from PyPI:
38-
39-
``` zsh
30+
```zsh
4031
pip install diffdrr
4132
```
4233

4334
## Usage
4435

45-
The following minimal example specifies the geometry of the projectional
46-
radiograph imaging system and traces rays through a CT volume:
36+
The following minimal example specifies the geometry of the projectional radiograph imaging system and traces rays through a CT volume:
4737

4838
``` python
4939
import matplotlib.pyplot as plt
@@ -57,43 +47,34 @@ from diffdrr.visualization import plot_drr
5747
volume, spacing = load_example_ct()
5848

5949
# Initialize the DRR module for generating synthetic X-rays
50+
device = "cuda" if torch.cuda.is_available() else "cpu"
6051
drr = DRR(
61-
volume, # The CT volume as a numpy array
62-
spacing, # Voxel dimensions of the CT
63-
sdr=300.0, # Source-to-detector radius (half of the source-to-detector distance)
64-
height=200, # Height of the DRR (if width is not seperately provided, the generated image is square)
65-
delx=4.0, # Pixel spacing (in mm)
66-
batch_size=1, # How many batches of parameters will be passed = number of DRRs generated each forward pass
67-
).to("cuda" if torch.cuda.is_available() else "cpu")
68-
69-
# Rotations and translations determine the viewing angle
70-
# They must have the same batch_size as was passed to the DRR constructor
71-
# Rotations are (yaw pitch roll)
72-
# Translations are (bx by bz)
73-
rotations = torch.tensor([[torch.pi, 0.0, torch.pi / 2]])
74-
translations = torch.tensor(volume.shape) * torch.tensor(spacing) / 2
75-
translations = translations.unsqueeze(0)
76-
77-
# Generate the DRR
78-
drr.move_carm(rotations, translations)
79-
with torch.no_grad():
80-
img = drr() # Only keep the graph if optimizing DRRs
81-
ax = plot_drr(img)
52+
volume, # The CT volume as a numpy array
53+
spacing, # Voxel dimensions of the CT
54+
sdr=300.0, # Source-to-detector radius (half of the source-to-detector distance)
55+
height=200, # Height of the DRR (if width is not seperately provided, the generated image is square)
56+
delx=4.0, # Pixel spacing (in mm)
57+
).to(device)
58+
59+
# Set the camera pose with rotations (yaw, pitch, roll) and translations (x, y, z)
60+
rotations = torch.tensor([[torch.pi, 0.0, torch.pi / 2]], device=device)
61+
bx, by, bz = torch.tensor(volume.shape) * torch.tensor(spacing) / 2
62+
translations = torch.tensor([[bx, by, bz]], device=device)
63+
64+
# Make the DRR
65+
img = drr(rotations, translations)
66+
plot_drr(img, ticks=False)
8267
plt.show()
8368
```
8469

85-
![](index_files/figure-commonmark/cell-2-output-1.png)
70+
![](notebooks/index_files/figure-commonmark/cell-2-output-1.png)
8671

8772
On a single NVIDIA RTX 2080 Ti GPU, producing such an image takes
8873

89-
``` python
90-
%timeit drr()
91-
```
92-
93-
34.9 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
74+
34.9 ms ± 32.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
9475

9576
The full example is available at
96-
[`tutorials/introduction.ipynb`](tutorials/introduction.ipynb).
77+
[`introduction.ipynb`](https://vivekg.dev/DiffDRR/tutorials/introduction.html).
9778

9879
## Application: 6-DoF Slice-to-Volume Registration
9980

@@ -111,7 +92,7 @@ optimization runs like this:
11192
![](https://cdn.githubraw.com/eigenvivek/DiffDRR/7a6a44aeab58d19cc7a4afabfc5aabab3a494974/experiments/registration/results/momentum_dampen/gifs/converged/649.gif)
11293

11394
The full example is available at
114-
[`tutorials/optimizers.ipynb`](tutorials/optimizers.ipynb).
95+
[`optimizers.ipynb`](https://vivekg.dev/DiffDRR/tutorials/optimizers.html).
11596

11697
## How does `DiffDRR` work?
11798

diffdrr/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.2"
1+
__version__ = "0.3.3"

diffdrr/_modidx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
'diffdrr.drr': { 'diffdrr.drr.DRR': ('api/drr.html#drr', 'diffdrr/drr.py'),
1919
'diffdrr.drr.DRR.__init__': ('api/drr.html#drr.__init__', 'diffdrr/drr.py'),
2020
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
21-
'diffdrr.drr.DRR.move_carm': ('api/drr.html#drr.move_carm', 'diffdrr/drr.py'),
22-
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py')},
21+
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'),
22+
'diffdrr.drr.Registration': ('api/drr.html#registration', 'diffdrr/drr.py'),
23+
'diffdrr.drr.Registration.__init__': ('api/drr.html#registration.__init__', 'diffdrr/drr.py'),
24+
'diffdrr.drr.Registration.forward': ('api/drr.html#registration.forward', 'diffdrr/drr.py')},
2325
'diffdrr.metrics': { 'diffdrr.metrics.XCorr2': ('api/metrics.html#xcorr2', 'diffdrr/metrics.py'),
2426
'diffdrr.metrics.XCorr2.__init__': ('api/metrics.html#xcorr2.__init__', 'diffdrr/metrics.py'),
2527
'diffdrr.metrics.XCorr2.forward': ('api/metrics.html#xcorr2.forward', 'diffdrr/metrics.py')},

diffdrr/drr.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,18 @@
66
import numpy as np
77
import torch
88
import torch.nn as nn
9-
109
from fastcore.basics import patch
1110

12-
from .siddon import siddon_raycast
1311
from .detector import Detector
12+
from .siddon import siddon_raycast
1413
from .utils import reshape_subsampled_drr
1514

1615
# %% auto 0
17-
__all__ = ['DRR']
16+
__all__ = ['DRR', 'Registration']
1817

1918
# %% ../notebooks/api/00_drr.ipynb 5
2019
class DRR(nn.Module):
21-
"""Torch module that computes differentiable digitally reconstructed radiographs."""
20+
"""PyTorch module that computes differentiable digitally reconstructed radiographs."""
2221

2322
def __init__(
2423
self,
@@ -31,16 +30,13 @@ def __init__(
3130
| None = None, # Width of the rendered DRR (if not provided, set to `height`)
3231
dely: float | None = None, # Y-axis pixel size (if not provided, set to `delx`)
3332
p_subsample: float | None = None, # Proportion of pixels to randomly subsample
34-
reshape: bool = True, # Return DRR with shape (b, h, w)
33+
reshape: bool = True, # Return DRR with shape (b, 1, h, w)
3534
convention: str = "diffdrr", # Either `diffdrr` or `deepdrr`, order of basis matrix multiplication
36-
batch_size: int = 1, # Number of DRRs to generate per forward pass
35+
patch_size: int
36+
| None = None, # Render patches of the DRR in series (useful for large DRRs)
3737
):
3838
super().__init__()
3939

40-
params = torch.empty(batch_size, 6)
41-
self.rotations = nn.Parameter(params[..., :3])
42-
self.translations = nn.Parameter(params[..., 3:])
43-
4440
# Initialize the X-ray detector
4541
width = height if width is None else width
4642
dely = delx if dely is None else dely
@@ -60,9 +56,9 @@ def __init__(
6056
self.register_buffer("spacing", torch.tensor(spacing))
6157
self.register_buffer("volume", torch.tensor(volume).flip([0]))
6258
self.reshape = reshape
63-
64-
# Dummy tensor for device and dtype
65-
self.register_buffer("dummy", torch.tensor([0.0]))
59+
self.patch_size = patch_size
60+
if self.patch_size is not None:
61+
self.n_patches = (height * width) // (self.patch_size**2)
6662

6763
def reshape_transform(self, img, batch_size):
6864
if self.reshape:
@@ -74,18 +70,41 @@ def reshape_transform(self, img, batch_size):
7470

7571
# %% ../notebooks/api/00_drr.ipynb 7
7672
@patch
77-
def move_carm(self: DRR, rotations: torch.Tensor, translations: torch.Tensor):
78-
state_dict = self.state_dict()
79-
state_dict["rotations"].copy_(rotations)
80-
state_dict["translations"].copy_(translations)
81-
82-
# %% ../notebooks/api/00_drr.ipynb 8
83-
@patch
84-
def forward(self: DRR):
73+
def forward(self: DRR, rotations: torch.Tensor, translations: torch.Tensor):
8574
"""Generate DRR with rotations and translations parameters."""
75+
assert len(rotations) == len(translations)
76+
batch_size = len(rotations)
8677
source, target = self.detector.make_xrays(
87-
rotations=self.rotations,
88-
translations=self.translations,
78+
rotations=rotations,
79+
translations=translations,
8980
)
90-
img = siddon_raycast(source, target, self.volume, self.spacing)
91-
return self.reshape_transform(img, batch_size=len(self.rotations))
81+
82+
if self.patch_size is not None:
83+
n_points = target.shape[1] // self.n_patches
84+
img = []
85+
for idx in range(self.n_patches):
86+
t = target[:, idx * n_points : (idx + 1) * n_points]
87+
partial = siddon_raycast(source, t, self.volume, self.spacing)
88+
img.append(partial)
89+
img = torch.cat(img, dim=1)
90+
else:
91+
img = siddon_raycast(source, target, self.volume, self.spacing)
92+
return self.reshape_transform(img, batch_size=batch_size)
93+
94+
# %% ../notebooks/api/00_drr.ipynb 9
95+
class Registration(nn.Module):
96+
"""Perform automatic 2D-to-3D registration using differentiable rendering."""
97+
98+
def __init__(
99+
self,
100+
drr: DRR,
101+
rotations: torch.Tensor,
102+
translations: torch.Tensor,
103+
):
104+
super().__init__()
105+
self.drr = drr
106+
self.rotations = nn.Parameter(rotations)
107+
self.translations = nn.Parameter(translations)
108+
109+
def forward(self):
110+
return self.drr(self.rotations, self.translations)

diffdrr/visualization.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def animate(
6060
drr: DRR,
6161
ground_truth: torch.Tensor | None = None,
6262
verbose: bool = True,
63+
device="cpu",
6364
**kwargs, # To pass to imageio.v3.imwrite
6465
):
6566
"""Animate the optimization of a DRR."""
@@ -96,10 +97,15 @@ def make_fig(ground_truth):
9697
for idx, row in itr:
9798
fig, ax_opt = make_fig() if ground_truth is None else make_fig(ground_truth)
9899
params = row[["theta", "phi", "gamma", "bx", "by", "bz"]].values
99-
rotations = torch.tensor(row[["theta", "phi", "gamma"]].values)
100-
translations = torch.tensor(row[["bx", "by", "bz"]].values)
101-
drr.move_carm(rotations, translations)
102-
itr = drr().detach()
100+
rotations = (
101+
torch.tensor(row[["theta", "phi", "gamma"]].values)
102+
.unsqueeze(0)
103+
.to(device)
104+
)
105+
translations = (
106+
torch.tensor(row[["bx", "by", "bz"]].values).unsqueeze(0).to(device)
107+
)
108+
itr = drr(rotations, translations)
103109
_ = plot_drr(itr, axs=ax_opt)
104110
ax_opt.set(xlabel="Moving DRR")
105111
fig.savefig(f"{tmpdir}/{idx}.png")

notebooks/_quarto.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@ website:
1111
twitter-card: true
1212
open-graph: true
1313
repo-actions: [issue]
14+
favicon: favicon.png
1415
navbar:
1516
background: primary
1617
search: true
18+
right:
19+
- icon: github
20+
href: "https://github.com/eigenvivek/DiffDRR"
1721
sidebar:
1822
style: floating
1923

0 commit comments

Comments
 (0)