|
58 | 58 | "class NormalizedCrossCorrelation2d(torch.nn.Module):\n",
|
59 | 59 | " \"\"\"Compute Normalized Cross Correlation between two batches of images.\"\"\"\n",
|
60 | 60 | "\n",
|
61 |
| - " def __init__(self, patch_size=None):\n", |
| 61 | + " def __init__(self, patch_size=None, eps=1e-5):\n", |
62 | 62 | " super().__init__()\n",
|
63 |
| - " self.norm = torch.nn.InstanceNorm2d(num_features=1)\n", |
| 63 | + " self.norm = torch.nn.InstanceNorm2d(num_features=1, eps=eps)\n", |
64 | 64 | " self.patch_size = patch_size\n",
|
65 | 65 | "\n",
|
66 | 66 | " def forward(self, x1, x2):\n",
|
|
78 | 78 | "class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):\n",
|
79 | 79 | " \"\"\"Compute Normalized Cross Correlation between two batches of images at multiple scales.\"\"\"\n",
|
80 | 80 | "\n",
|
81 |
| - " def __init__(self, patch_sizes=[None], patch_weights=[1.0]):\n", |
| 81 | + " def __init__(self, patch_sizes=[None], patch_weights=[1.0], eps=1e-5):\n", |
82 | 82 | " super().__init__()\n",
|
83 |
| - " self.norm = torch.nn.InstanceNorm2d(num_features=1)\n", |
| 83 | + " self.norm = torch.nn.InstanceNorm2d(num_features=1, eps=eps)\n", |
84 | 84 | " \n",
|
85 | 85 | " assert len(patch_sizes) == len(patch_weights), \"Each scale must have a weight\"\n",
|
86 | 86 | " self.nccs = [NormalizedCrossCorrelation2d(patch_size) for patch_size in patch_sizes]\n",
|
|
120 | 120 | "class GradientNormalizedCrossCorrelation2d(NormalizedCrossCorrelation2d):\n",
|
121 | 121 | " \"\"\"Compute Normalized Cross Correlation between the image gradients of two batches of images.\"\"\"\n",
|
122 | 122 | "\n",
|
123 |
| - " def __init__(self, patch_size=None, sigma=1.0):\n", |
124 |
| - " super().__init__(patch_size)\n", |
| 123 | + " def __init__(self, patch_size=None, sigma=1.0, **kwargs):\n", |
| 124 | + " super().__init__(patch_size, **kwargs)\n", |
125 | 125 | " self.sobel = Sobel(sigma)\n",
|
126 | 126 | "\n",
|
127 | 127 | " def forward(self, x1, x2):\n",
|
|
172 | 172 | {
|
173 | 173 | "data": {
|
174 | 174 | "text/plain": [
|
175 |
| - "tensor([-0.0190, 0.0077, 0.0057, 0.0140, -0.0191, 0.0089, -0.0021, 0.0083])" |
| 175 | + "tensor([ 0.0220, -0.0021, 0.0235, 0.0114, -0.0079, -0.0089, 0.0051, -0.0099])" |
176 | 176 | ]
|
177 | 177 | },
|
178 | 178 | "execution_count": null,
|
|
187 | 187 | "ncc = NormalizedCrossCorrelation2d()\n",
|
188 | 188 | "ncc(x1, x2)\n",
|
189 | 189 | "\n",
|
| 190 | + "ncc = NormalizedCrossCorrelation2d(eps=1e-1)\n", |
| 191 | + "ncc(x1, x2)\n", |
| 192 | + "\n", |
190 | 193 | "ncc = NormalizedCrossCorrelation2d(patch_size=9)\n",
|
191 | 194 | "ncc(x1, x2)\n",
|
192 | 195 | "\n",
|
|
0 commit comments