@@ -76,7 +76,7 @@ def gradient(self, x):
         gradients[d, ...] = F.pad(
             torch.diff(x, dim=d, n=1), pad=get_padding_tuple(d, x.dim())
         )
-        gradients[:-1] *= 1.0 - self.l1_ratio
+        gradients[:-1] *= (1.0 - self.l1_ratio)
         gradients[-1] = self.l1_ratio * x
         return gradients
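For context, here is a minimal self-contained sketch of the TV-L1 gradient stack that the changed lines compute. The `get_padding_tuple` helper is not shown in this diff, so the version below is a hypothetical reconstruction that right-pads the differenced axis by one zero; the rest mirrors the method body above.

import torch
import torch.nn.functional as F

def get_padding_tuple(dim, ndim):
    # Hypothetical stand-in for the helper used above: F.pad takes
    # (left, right) pairs starting from the LAST dimension, so the pair
    # for `dim` sits at offset 2 * (ndim - 1 - dim). One trailing zero
    # on the differenced axis restores the input shape.
    pad = [0] * (2 * ndim)
    pad[2 * (ndim - 1 - dim) + 1] = 1
    return tuple(pad)

def tvl1_gradient(x, l1_ratio=0.5):
    # One forward-difference field per axis, plus a final slot holding
    # the identity term for the L1 part, as in the diff above.
    gradients = torch.zeros((x.dim() + 1, *x.shape), dtype=x.dtype, device=x.device)
    for d in range(x.dim()):
        gradients[d, ...] = F.pad(
            torch.diff(x, dim=d, n=1), pad=get_padding_tuple(d, x.dim())
        )
    gradients[:-1] *= (1.0 - l1_ratio)  # TV (smoothness) block
    gradients[-1] = l1_ratio * x        # L1 (sparsity) block
    return gradients

g = tvl1_gradient(torch.rand(8, 8), l1_ratio=0.25)  # shape (3, 8, 8)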
@@ -139,7 +139,7 @@ def _projector_on_tvl1_dual(self, grad):

         return grad

-    def _dual_gap_prox_tvl1(self, input_img_norm, new, gap, weight, l1_ratio=1.0):
+    def _dual_gap_prox_tvl1(self, input_img_norm, new, gap, weight):
         """
         Compute the dual gap of total variation denoising.
@@ -148,7 +148,6 @@ def _dual_gap_prox_tvl1(self, input_img_norm, new, gap, weight, l1_ratio=1.0):
             new (torch.Tensor): Updated tensor.
             gap (torch.Tensor): Gap tensor.
             weight (float): Regularization strength.
-            l1_ratio (float, optional): The L1 ratio. Defaults to 1.0.

         Returns:
             float: Dual gap value.
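A brief note on the docstring above, not taken from the diff itself: the dual gap is the standard stopping criterion from convex duality. For any primal iterate $u$ and any dual-feasible $v$, weak duality gives $D(v) \le P(u^*)$, hence

$$0 \le P(u) - P(u^*) \le P(u) - D(v) = \mathrm{gap}(u, v),$$

so the check `dgap < self.tol` in the `prox` hunk below certifies that the denoised image is within `tol` of the optimum without needing to know the optimizer $u^*$ itself.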
@@ -229,7 +228,7 @@ def prox(self, x: torch.Tensor, lr: float) -> torch.Tensor:
             if i % 4 == 0:
                 old_dgap = dgap
                 dgap = self._dual_gap_prox_tvl1(
-                    input_img_norm, -negated_output, gap, weight, l1_ratio=self.l1_ratio
+                    input_img_norm, -negated_output, gap, weight
                 )
                 if dgap < self.tol:
                     break