Skip to content

Commit 01e557c

Browse files
committed
Improve argument unzipping
1 parent 1e930f5 commit 01e557c

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

diffdrr/renderers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,12 @@ def forward(self, volume, spacing, source, target):
4343
# %% ../notebooks/api/01_renderers.ipynb 8
4444
def _get_alphas(source, target, spacing, dims, eps):
4545
# Get the CT sizing and spacing parameters
46-
dx, dy, dz = spacing
47-
nx, ny, nz = dims
46+
alphax = torch.arange(dims[0]).to(source) * spacing[0]
47+
alphay = torch.arange(dims[1]).to(source) * spacing[1]
48+
alphaz = torch.arange(dims[2]).to(source) * spacing[2]
4849

4950
# Get the alpha at each plane intersection
5051
sx, sy, sz = source[..., 0], source[..., 1], source[..., 2]
51-
alphax = torch.arange(nx).to(source) * dx
52-
alphay = torch.arange(ny).to(source) * dy
53-
alphaz = torch.arange(nz).to(source) * dz
5452
alphax = alphax.expand(len(source), 1, -1) - sx.unsqueeze(-1)
5553
alphay = alphay.expand(len(source), 1, -1) - sy.unsqueeze(-1)
5654
alphaz = alphaz.expand(len(source), 1, -1) - sz.unsqueeze(-1)
@@ -67,9 +65,11 @@ def _get_alphas(source, target, spacing, dims, eps):
6765
alphas[~good_idxs] = torch.nan
6866

6967
# Sort the alphas by ray, putting nans at the end of the list
70-
# Drop indices where alphas for all rays are nan
7168
alphas = torch.sort(alphas, dim=-1).values
69+
70+
# Drop indices where alphas for all rays are nan
7271
alphas = alphas[..., ~alphas.isnan().all(dim=0).all(dim=0)]
72+
7373
return alphas
7474

7575

@@ -110,7 +110,7 @@ def _get_index(alpha, source, target, spacing, dims, maxidx, eps):
110110
idxs[idxs > maxidx] = maxidx
111111
return idxs
112112

113-
# %% ../notebooks/api/01_renderers.ipynb 10
113+
# %% ../notebooks/api/01_renderers.ipynb 12
114114
from torch.nn.functional import grid_sample
115115

116116

notebooks/api/01_renderers.ipynb

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,12 @@
151151
"#| export\n",
152152
"def _get_alphas(source, target, spacing, dims, eps):\n",
153153
" # Get the CT sizing and spacing parameters\n",
154-
" dx, dy, dz = spacing\n",
155-
" nx, ny, nz = dims\n",
154+
" alphax = torch.arange(dims[0]).to(source) * spacing[0]\n",
155+
" alphay = torch.arange(dims[1]).to(source) * spacing[1]\n",
156+
" alphaz = torch.arange(dims[2]).to(source) * spacing[2]\n",
156157
"\n",
157158
" # Get the alpha at each plane intersection\n",
158159
" sx, sy, sz = source[..., 0], source[..., 1], source[..., 2]\n",
159-
" alphax = torch.arange(nx).to(source) * dx\n",
160-
" alphay = torch.arange(ny).to(source) * dy\n",
161-
" alphaz = torch.arange(nz).to(source) * dz\n",
162160
" alphax = alphax.expand(len(source), 1, -1) - sx.unsqueeze(-1)\n",
163161
" alphay = alphay.expand(len(source), 1, -1) - sy.unsqueeze(-1)\n",
164162
" alphaz = alphaz.expand(len(source), 1, -1) - sz.unsqueeze(-1)\n",
@@ -175,9 +173,11 @@
175173
" alphas[~good_idxs] = torch.nan\n",
176174
"\n",
177175
" # Sort the alphas by ray, putting nans at the end of the list\n",
178-
" # Drop indices where alphas for all rays are nan\n",
179176
" alphas = torch.sort(alphas, dim=-1).values\n",
177+
" \n",
178+
" # Drop indices where alphas for all rays are nan\n",
180179
" alphas = alphas[..., ~alphas.isnan().all(dim=0).all(dim=0)]\n",
180+
"\n",
181181
" return alphas\n",
182182
"\n",
183183
"\n",
@@ -219,6 +219,20 @@
219219
" return idxs"
220220
]
221221
},
222+
{
223+
"cell_type": "markdown",
224+
"metadata": {},
225+
"source": [
226+
"## Compilier-friendly Siddon"
227+
]
228+
},
229+
{
230+
"cell_type": "code",
231+
"execution_count": null,
232+
"metadata": {},
233+
"outputs": [],
234+
"source": []
235+
},
222236
{
223237
"cell_type": "markdown",
224238
"metadata": {},
@@ -238,19 +252,7 @@
238252
"cell_type": "code",
239253
"execution_count": null,
240254
"metadata": {},
241-
"outputs": [
242-
{
243-
"ename": "NameError",
244-
"evalue": "name 'torch' is not defined",
245-
"output_type": "error",
246-
"traceback": [
247-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
248-
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
249-
"Cell \u001b[0;32mIn[1], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| export\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunctional\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m grid_sample\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mTrilinear\u001b[39;00m(\u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule):\n\u001b[1;32m 6\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124;03m Instead of computing the exact line integral over the voxel grid (i.e., Siddon's method), we can sample colors at points along the each ray using trilinear interpolation.\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124;03m \u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;124;03m where $\\mathbf V[\\cdot]$ is the trilinear interpolation function and $M$ is the number of points sampled per ray.\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 17\u001b[0m near\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 20\u001b[0m eps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-8\u001b[39m,\n\u001b[1;32m 21\u001b[0m ):\n",
250-
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
251-
]
252-
}
253-
],
255+
"outputs": [],
254256
"source": [
255257
"#| export\n",
256258
"from torch.nn.functional import grid_sample\n",

0 commit comments

Comments
 (0)