@@ -179,24 +179,65 @@ def prox(self, x: torch.Tensor, lr: float) -> torch.Tensor:
179
179
180
180
Notes
181
181
-----
182
- The principle of total variation denoising is explained in
182
+ Total variation denoising aims to minimize the total variation of the image,
183
+ which can be roughly described as the integral of the norm of the image gradient.
184
+ As a result, it produces "cartoon-like" images, i.e., piecewise-constant images.
185
+ For more details, refer to:
183
186
http://en.wikipedia.org/wiki/Total_variation_denoising
184
187
185
- The principle of total variation denoising is to minimize the
186
- total variation of the image, which can be roughly described as
187
- the integral of the norm of the image gradient. Total variation
188
- denoising tends to produce "cartoon-like" images, that is,
189
- piecewise-constant images.
190
-
191
188
This function implements the FISTA (Fast Iterative Shrinkage
192
189
Thresholding Algorithm) algorithm of Beck et Teboulle, adapted to
193
190
total variation denoising in "Fast gradient-based algorithms for
194
191
constrained total variation image denoising and deblurring problems"
195
192
(2009).
196
193
197
- For details on implementing the bound constraints, read the aforementioned
198
- Beck and Teboulle paper.
194
+ For more on bound constraints implementation, see the aforementioned Beck and Teboulle paper.
199
195
"""
196
+ fista = True
197
+ weight = self .alpha * lr
198
+ input_shape = x .shape
199
+ input_img_norm = torch .norm (x ) ** 2
200
+ lipschitz_constant = 1.1 * (4 * 3 )
201
+ negated_output = - x
202
+ grad_aux = torch .zeros_like (self .gradient (x ))
203
+ grad_im = torch .zeros_like (grad_aux )
204
+ t = 1.0
205
+ i = 0
206
+ dgap = torch .tensor (float ("inf" )).to (x .device )
207
+ while i < self .max_iter :
208
+ # tv_prev = self.tv_from_grad(self.gradient(output))
209
+ grad_tmp = self .gradient (negated_output )
210
+ grad_tmp *= 1.0 / (lipschitz_constant * weight )
211
+ grad_aux += grad_tmp
212
+ grad_tmp = self ._projector_on_tvl1_dual (grad_aux )
213
+
214
+ # Careful, in the next few lines, grad_tmp and grad_aux are a
215
+ # view on the same array, as _projector_on_tvl1_dual returns a view
216
+ # on the input array
217
+ t_new = 0.5 * (1 + sqrt (1 + 4 * t ** 2 ))
218
+ t_factor = (t - 1 ) / t_new
219
+ if fista :
220
+ # fista
221
+ grad_aux = (1 + t_factor ) * grad_tmp - t_factor * grad_im
222
+ else :
223
+ # ista
224
+ grad_aux = grad_tmp
225
+ grad_im = grad_tmp
226
+ t = t_new
227
+ gap = weight * self .divergence (grad_aux )
228
+ negated_output = gap - x
229
+ if i % 4 == 0 :
230
+ old_dgap = dgap
231
+ dgap = self ._dual_gap_prox_tvl1 (
232
+ input_img_norm , - negated_output , gap , weight , l1_ratio = self .l1_ratio
233
+ )
234
+ if dgap < 5.0e-5 :
235
+ break
236
+ if old_dgap < dgap :
237
+ fista = False
238
+ i += 1
239
+ output = x - weight * self .divergence (grad_im )
240
+ return output .reshape (input_shape )
200
241
201
242
def _nonsmooth (self , x : torch .Tensor ) -> torch .Tensor :
202
243
"""
0 commit comments