Skip to content

Commit fa41254

Browse files
First commit
1 parent 154b6cb commit fa41254

File tree

1 file changed

+50
-9
lines changed

1 file changed

+50
-9
lines changed

proxtorch/operators/tvl1_3d.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -179,24 +179,65 @@ def prox(self, x: torch.Tensor, lr: float) -> torch.Tensor:
179179
180180
Notes
181181
-----
182-
The principle of total variation denoising is explained in
182+
Total variation denoising aims to minimize the total variation of the image,
183+
which can be roughly described as the integral of the norm of the image gradient.
184+
As a result, it produces "cartoon-like" images, i.e., piecewise-constant images.
185+
For more details, refer to:
183186
http://en.wikipedia.org/wiki/Total_variation_denoising
184187
185-
The principle of total variation denoising is to minimize the
186-
total variation of the image, which can be roughly described as
187-
the integral of the norm of the image gradient. Total variation
188-
denoising tends to produce "cartoon-like" images, that is,
189-
piecewise-constant images.
190-
191188
This function implements the FISTA (Fast Iterative Shrinkage
192189
Thresholding Algorithm) algorithm of Beck et Teboulle, adapted to
193190
total variation denoising in "Fast gradient-based algorithms for
194191
constrained total variation image denoising and deblurring problems"
195192
(2009).
196193
197-
For details on implementing the bound constraints, read the aforementioned
198-
Beck and Teboulle paper.
194+
For more on bound constraints implementation, see the aforementioned Beck and Teboulle paper.
199195
"""
196+
fista = True
197+
weight = self.alpha * lr
198+
input_shape = x.shape
199+
input_img_norm = torch.norm(x) ** 2
200+
lipschitz_constant = 1.1 * (4 * 3)
201+
negated_output = -x
202+
grad_aux = torch.zeros_like(self.gradient(x))
203+
grad_im = torch.zeros_like(grad_aux)
204+
t = 1.0
205+
i = 0
206+
dgap = torch.tensor(float("inf")).to(x.device)
207+
while i < self.max_iter:
208+
# tv_prev = self.tv_from_grad(self.gradient(output))
209+
grad_tmp = self.gradient(negated_output)
210+
grad_tmp *= 1.0 / (lipschitz_constant * weight)
211+
grad_aux += grad_tmp
212+
grad_tmp = self._projector_on_tvl1_dual(grad_aux)
213+
214+
# Careful, in the next few lines, grad_tmp and grad_aux are a
215+
# view on the same array, as _projector_on_tvl1_dual returns a view
216+
# on the input array
217+
t_new = 0.5 * (1 + sqrt(1 + 4 * t**2))
218+
t_factor = (t - 1) / t_new
219+
if fista:
220+
# fista
221+
grad_aux = (1 + t_factor) * grad_tmp - t_factor * grad_im
222+
else:
223+
# ista
224+
grad_aux = grad_tmp
225+
grad_im = grad_tmp
226+
t = t_new
227+
gap = weight * self.divergence(grad_aux)
228+
negated_output = gap - x
229+
if i % 4 == 0:
230+
old_dgap = dgap
231+
dgap = self._dual_gap_prox_tvl1(
232+
input_img_norm, -negated_output, gap, weight, l1_ratio=self.l1_ratio
233+
)
234+
if dgap < 5.0e-5:
235+
break
236+
if old_dgap < dgap:
237+
fista = False
238+
i += 1
239+
output = x - weight * self.divergence(grad_im)
240+
return output.reshape(input_shape)
200241

201242
def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
202243
"""

0 commit comments

Comments
 (0)