Skip to content

Commit 9500f4d

Browse files
committed
Add multiscale variant of NCC
1 parent cdf7260 commit 9500f4d

File tree

3 files changed

+50
-3
lines changed

3 files changed

+50
-3
lines changed

diffdrr/_modidx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
'diffdrr/metrics.py'),
3333
'diffdrr.metrics.GradientNormalizedCrossCorrelation2d.forward': ( 'api/metrics.html#gradientnormalizedcrosscorrelation2d.forward',
3434
'diffdrr/metrics.py'),
35+
'diffdrr.metrics.MultiscaleNormalizedCrossCorrelation2d': ( 'api/metrics.html#multiscalenormalizedcrosscorrelation2d',
36+
'diffdrr/metrics.py'),
37+
'diffdrr.metrics.MultiscaleNormalizedCrossCorrelation2d.__init__': ( 'api/metrics.html#multiscalenormalizedcrosscorrelation2d.__init__',
38+
'diffdrr/metrics.py'),
39+
'diffdrr.metrics.MultiscaleNormalizedCrossCorrelation2d.forward': ( 'api/metrics.html#multiscalenormalizedcrosscorrelation2d.forward',
40+
'diffdrr/metrics.py'),
3541
'diffdrr.metrics.NormalizedCrossCorrelation2d': ( 'api/metrics.html#normalizedcrosscorrelation2d',
3642
'diffdrr/metrics.py'),
3743
'diffdrr.metrics.NormalizedCrossCorrelation2d.__init__': ( 'api/metrics.html#normalizedcrosscorrelation2d.__init__',

diffdrr/metrics.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.nn as nn
88

99
# %% auto 0
10-
__all__ = ['NormalizedCrossCorrelation2d', 'GradientNormalizedCrossCorrelation2d']
10+
__all__ = ['NormalizedCrossCorrelation2d', 'MultiscaleNormalizedCrossCorrelation2d', 'GradientNormalizedCrossCorrelation2d']
1111

1212
# %% ../notebooks/api/05_metrics.ipynb 4
1313
class NormalizedCrossCorrelation2d(torch.nn.Module):
@@ -29,6 +29,26 @@ def forward(self, x1, x2):
2929
score /= c * h * w
3030
return score
3131

32+
33+
class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):
34+
"""Compute Normalized Cross Correlation between two batches of images at multiple scales."""
35+
36+
def __init__(self, patch_sizes=[None], patch_weights=[1.0]):
37+
super().__init__()
38+
self.norm = torch.nn.InstanceNorm2d(num_features=1)
39+
40+
assert len(patch_sizes) == len(patch_weights), "Each scale must have a weight"
41+
self.nccs = [
42+
NormalizedCrossCorrelation2d(patch_size) for patch_size in patch_sizes
43+
]
44+
self.patch_weights = patch_weights
45+
46+
def forward(self, x1, x2):
47+
scores = []
48+
for weight, ncc in zip(self.patch_weights, self.nccs):
49+
scores.append(weight * ncc(x1, x2))
50+
return torch.stack(scores, dim=0).sum(dim=0)
51+
3252
# %% ../notebooks/api/05_metrics.ipynb 5
3353
from einops import rearrange
3454

notebooks/api/05_metrics.ipynb

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,25 @@
7272
" x1, x2 = self.norm(x1), self.norm(x2)\n",
7373
" score = torch.einsum(\"b...,b...->b\", x1, x2)\n",
7474
" score /= c * h * w\n",
75-
" return score"
75+
" return score\n",
76+
"\n",
77+
"\n",
78+
"class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):\n",
79+
" \"\"\"Compute Normalized Cross Correlation between two batches of images at multiple scales.\"\"\"\n",
80+
"\n",
81+
" def __init__(self, patch_sizes=[None], patch_weights=[1.0]):\n",
82+
" super().__init__()\n",
83+
" self.norm = torch.nn.InstanceNorm2d(num_features=1)\n",
84+
" \n",
85+
" assert len(patch_sizes) == len(patch_weights), \"Each scale must have a weight\"\n",
86+
" self.nccs = [NormalizedCrossCorrelation2d(patch_size) for patch_size in patch_sizes]\n",
87+
" self.patch_weights = patch_weights\n",
88+
"\n",
89+
" def forward(self, x1, x2):\n",
90+
" scores = []\n",
91+
" for weight, ncc in zip(self.patch_weights, self.nccs):\n",
92+
" scores.append(weight * ncc(x1, x2))\n",
93+
" return torch.stack(scores, dim=0).sum(dim=0)"
7694
]
7795
},
7896
{
@@ -154,7 +172,7 @@
154172
{
155173
"data": {
156174
"text/plain": [
157-
"tensor([ 0.0039, -0.0092, 0.0008, 0.0022, -0.0049, -0.0204, -0.0088, 0.0056])"
175+
"tensor([-0.0190, 0.0077, 0.0057, 0.0140, -0.0191, 0.0089, -0.0021, 0.0083])"
158176
]
159177
},
160178
"execution_count": null,
@@ -172,6 +190,9 @@
172190
"ncc = NormalizedCrossCorrelation2d(patch_size=9)\n",
173191
"ncc(x1, x2)\n",
174192
"\n",
193+
"msncc = MultiscaleNormalizedCrossCorrelation2d(patch_sizes=[9, None], patch_weights=[0.5, 0.5])\n",
194+
"msncc(x1, x2)\n",
195+
"\n",
175196
"gncc = GradientNormalizedCrossCorrelation2d()\n",
176197
"gncc(x1, x2)\n",
177198
"\n",

0 commit comments

Comments
 (0)