Skip to content

Commit e230877

Browse files
authored
Merge pull request #160 from eigenvivek/eps
Expose InstanceNorm eps argument
2 parents 9500f4d + 44ac35d commit e230877

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

diffdrr/metrics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
class NormalizedCrossCorrelation2d(torch.nn.Module):
1414
"""Compute Normalized Cross Correlation between two batches of images."""
1515

16-
def __init__(self, patch_size=None):
16+
def __init__(self, patch_size=None, eps=1e-5):
1717
super().__init__()
18-
self.norm = torch.nn.InstanceNorm2d(num_features=1)
18+
self.norm = torch.nn.InstanceNorm2d(num_features=1, eps=eps)
1919
self.patch_size = patch_size
2020

2121
def forward(self, x1, x2):
@@ -33,9 +33,9 @@ def forward(self, x1, x2):
3333
class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):
3434
"""Compute Normalized Cross Correlation between two batches of images at multiple scales."""
3535

36-
def __init__(self, patch_sizes=[None], patch_weights=[1.0]):
36+
def __init__(self, patch_sizes=[None], patch_weights=[1.0], eps=1e-5):
3737
super().__init__()
38-
self.norm = torch.nn.InstanceNorm2d(num_features=1)
38+
self.norm = torch.nn.InstanceNorm2d(num_features=1, eps=eps)
3939

4040
assert len(patch_sizes) == len(patch_weights), "Each scale must have a weight"
4141
self.nccs = [
@@ -61,8 +61,8 @@ def to_patches(x, patch_size):
6161
class GradientNormalizedCrossCorrelation2d(NormalizedCrossCorrelation2d):
6262
"""Compute Normalized Cross Correlation between the image gradients of two batches of images."""
6363

64-
def __init__(self, patch_size=None, sigma=1.0):
65-
super().__init__(patch_size)
64+
def __init__(self, patch_size=None, sigma=1.0, **kwargs):
65+
super().__init__(patch_size, **kwargs)
6666
self.sobel = Sobel(sigma)
6767

6868
def forward(self, x1, x2):

notebooks/api/05_metrics.ipynb

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@
5858
"class NormalizedCrossCorrelation2d(torch.nn.Module):\n",
5959
" \"\"\"Compute Normalized Cross Correlation between two batches of images.\"\"\"\n",
6060
"\n",
61-
" def __init__(self, patch_size=None):\n",
61+
" def __init__(self, patch_size=None, eps=1e-5):\n",
6262
" super().__init__()\n",
63-
" self.norm = torch.nn.InstanceNorm2d(num_features=1)\n",
63+
" self.norm = torch.nn.InstanceNorm2d(num_features=1, eps=eps)\n",
6464
" self.patch_size = patch_size\n",
6565
"\n",
6666
" def forward(self, x1, x2):\n",
@@ -78,9 +78,9 @@
7878
"class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):\n",
7979
" \"\"\"Compute Normalized Cross Correlation between two batches of images at multiple scales.\"\"\"\n",
8080
"\n",
81-
" def __init__(self, patch_sizes=[None], patch_weights=[1.0]):\n",
81+
" def __init__(self, patch_sizes=[None], patch_weights=[1.0], eps=1e-5):\n",
8282
" super().__init__()\n",
83-
" self.norm = torch.nn.InstanceNorm2d(num_features=1)\n",
83+
" self.norm = torch.nn.InstanceNorm2d(num_features=1, eps=eps)\n",
8484
" \n",
8585
" assert len(patch_sizes) == len(patch_weights), \"Each scale must have a weight\"\n",
8686
" self.nccs = [NormalizedCrossCorrelation2d(patch_size) for patch_size in patch_sizes]\n",
@@ -120,8 +120,8 @@
120120
"class GradientNormalizedCrossCorrelation2d(NormalizedCrossCorrelation2d):\n",
121121
" \"\"\"Compute Normalized Cross Correlation between the image gradients of two batches of images.\"\"\"\n",
122122
"\n",
123-
" def __init__(self, patch_size=None, sigma=1.0):\n",
124-
" super().__init__(patch_size)\n",
123+
" def __init__(self, patch_size=None, sigma=1.0, **kwargs):\n",
124+
" super().__init__(patch_size, **kwargs)\n",
125125
" self.sobel = Sobel(sigma)\n",
126126
"\n",
127127
" def forward(self, x1, x2):\n",
@@ -172,7 +172,7 @@
172172
{
173173
"data": {
174174
"text/plain": [
175-
"tensor([-0.0190, 0.0077, 0.0057, 0.0140, -0.0191, 0.0089, -0.0021, 0.0083])"
175+
"tensor([ 0.0220, -0.0021, 0.0235, 0.0114, -0.0079, -0.0089, 0.0051, -0.0099])"
176176
]
177177
},
178178
"execution_count": null,
@@ -187,6 +187,9 @@
187187
"ncc = NormalizedCrossCorrelation2d()\n",
188188
"ncc(x1, x2)\n",
189189
"\n",
190+
"ncc = NormalizedCrossCorrelation2d(eps=1e-1)\n",
191+
"ncc(x1, x2)\n",
192+
"\n",
190193
"ncc = NormalizedCrossCorrelation2d(patch_size=9)\n",
191194
"ncc(x1, x2)\n",
192195
"\n",

0 commit comments

Comments
 (0)