Skip to content

Commit f8693ed

Browse files
First commit
1 parent 8727128 commit f8693ed

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

proxtorch/operators/tvl1_3d.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def gradient(self, x):
7676
gradients[d, ...] = F.pad(
7777
torch.diff(x, dim=d, n=1), pad=get_padding_tuple(d, x.dim())
7878
)
79-
gradients[:-1] *= 1.0 - self.l1_ratio
79+
gradients[:-1] *= (1.0 - self.l1_ratio)
8080
gradients[-1] = self.l1_ratio * x
8181
return gradients
8282

@@ -139,7 +139,7 @@ def _projector_on_tvl1_dual(self, grad):
139139

140140
return grad
141141

142-
def _dual_gap_prox_tvl1(self, input_img_norm, new, gap, weight, l1_ratio=1.0):
142+
def _dual_gap_prox_tvl1(self, input_img_norm, new, gap, weight):
143143
"""
144144
Compute the dual gap of total variation denoising.
145145
@@ -148,7 +148,6 @@ def _dual_gap_prox_tvl1(self, input_img_norm, new, gap, weight, l1_ratio=1.0):
148148
new (torch.Tensor): Updated tensor.
149149
gap (torch.Tensor): Gap tensor.
150150
weight (float): Regularization strength.
151-
l1_ratio (float, optional): The L1 ratio. Defaults to 1.0.
152151
153152
Returns:
154153
float: Dual gap value.
@@ -229,7 +228,7 @@ def prox(self, x: torch.Tensor, lr: float) -> torch.Tensor:
229228
if i % 4 == 0:
230229
old_dgap = dgap
231230
dgap = self._dual_gap_prox_tvl1(
232-
input_img_norm, -negated_output, gap, weight, l1_ratio=self.l1_ratio
231+
input_img_norm, -negated_output, gap, weight
233232
)
234233
if dgap < self.tol:
235234
break

0 commit comments

Comments
 (0)