
Commit 4b38156

First commit
1 parent 4c2b7ed commit 4b38156

5 files changed: 58 additions, 57 deletions

proxtorch/operators/graphnet.py

Lines changed: 3 additions & 3 deletions

@@ -15,9 +15,9 @@ def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
 
     def _smooth(self, x: torch.Tensor) -> torch.Tensor:
         # The last channel is for the l1 norm
-        grad = self.gradient(x)[:-1]
-        # norm of the gradient
-        norm = torch.norm(grad) ** 2
+        grad = self.gradient(x)[:-1] / (1 - self.l1_ratio)
+        # sum of squares of the gradients
+        norm = torch.sum(grad ** 2)
         return 0.5 * norm * self.alpha * (1 - self.l1_ratio)
 
     def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:
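In words: _smooth now rescales the gradient channels by 1 / (1 - l1_ratio) before summing their squares, so the returned value works out to 0.5 * alpha * ||grad x||^2 / (1 - l1_ratio) instead of the previous 0.5 * alpha * (1 - l1_ratio) * ||grad x||^2. A minimal standalone sketch of the same arithmetic on a 2D tensor follows; the forward-difference gradient and the channel stacking are assumptions standing in for the library's self.gradient, which per the comment also appends an extra channel for the l1 term (dropped here by [:-1]).

import torch

def graphnet_smooth_sketch(x: torch.Tensor, alpha: float, l1_ratio: float) -> torch.Tensor:
    # Hypothetical stand-in for self.gradient(x)[:-1]: forward differences
    # along each axis of a 2D tensor, stacked into channels.
    dx = torch.diff(x, dim=0, append=x[-1:, :])
    dy = torch.diff(x, dim=1, append=x[:, -1:])
    grad = torch.stack([dx, dy]) / (1 - l1_ratio)  # rescaling introduced in this commit
    norm = torch.sum(grad ** 2)                    # sum of squared finite differences
    return 0.5 * norm * alpha * (1 - l1_ratio)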

proxtorch/operators/tv_2d.py

Lines changed: 1 addition & 1 deletion

@@ -11,4 +11,4 @@ def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None
             max_iter (int, optional): Maximum iterations for the iterative algorithm. Defaults to 50.
             tol (float, optional): Tolerance level for early stopping. Defaults to 1e-2.
         """
-        super().__init__(alpha, max_iter, tol, l1_ratio=0.0)
+        super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol)

proxtorch/operators/tv_3d.py

Lines changed: 1 addition & 1 deletion

@@ -11,4 +11,4 @@ def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None
             max_iter (int, optional): Maximum iterations for the iterative algorithm. Defaults to 50.
             tol (float, optional): Tolerance level for early stopping. Defaults to 1e-2.
         """
-        super().__init__(alpha, max_iter, tol, l1_ratio=0.0)
+        super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol)
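The tv_2d.py and tv_3d.py changes are the same one-liner: the parent constructor is now called with keyword arguments instead of positionally. That matters because of the parameter reordering in tvl1_3d.py below, where l1_ratio moves ahead of max_iter; with the old positional call, max_iter and tol would bind to the wrong parameters and the explicit l1_ratio keyword would then collide. A minimal sketch of the failure mode, under the assumption that the TV classes inherit from an operator with the reordered signature:

# Hypothetical parent/child pair mirroring the signatures in this commit.
class Parent:
    def __init__(self, alpha: float, l1_ratio: float = 0.0,
                 max_iter: int = 200, tol: float = 1e-4) -> None:
        self.alpha, self.l1_ratio = alpha, l1_ratio
        self.max_iter, self.tol = max_iter, tol

class ChildOld(Parent):
    def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None:
        # Old call: max_iter lands in l1_ratio's slot, then the keyword collides.
        super().__init__(alpha, max_iter, tol, l1_ratio=0.0)

class ChildNew(Parent):
    def __init__(self, alpha: float, max_iter: int = 200, tol: float = 1e-4) -> None:
        # New call: keywords bind correctly regardless of parameter order.
        super().__init__(alpha, l1_ratio=0.0, max_iter=max_iter, tol=tol)

ChildNew(0.1)    # fine
# ChildOld(0.1)  # TypeError: got multiple values for argument 'l1_ratio'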

proxtorch/operators/tvl1_3d.py

Lines changed: 1 addition & 1 deletion

@@ -41,9 +41,9 @@ class TVL1_3D(ProxOperator):
     def __init__(
         self,
         alpha: float,
+        l1_ratio=0.0,
         max_iter: int = 200,
         tol: float = 1e-4,
-        l1_ratio=0.0,
     ) -> None:
         """
         Initialize the 3D Total Variation proximal operator.
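With l1_ratio now the second parameter, positional construction reads (alpha, l1_ratio, ...) and the solver settings follow. A short usage sketch, assuming TVL1_3D is exported from proxtorch.operators alongside the GraphNet classes (the import path is an assumption):

from proxtorch.operators import TVL1_3D  # assumed export path

tv_only = TVL1_3D(alpha=0.1)                  # l1_ratio defaults to 0.0
tv_l1 = TVL1_3D(0.1, 0.5)                     # positional now means (alpha, l1_ratio)
tuned = TVL1_3D(alpha=0.1, l1_ratio=0.5, max_iter=500, tol=1e-5)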

test/test_graphnet.py

Lines changed: 52 additions & 51 deletions

@@ -3,57 +3,58 @@
 from proxtorch.operators import GraphNet3D, GraphNet2D
 
 
-#
-# def test_converges_to_sparse_smooth():
-#     torch.manual_seed(0)
-#
-#     # generate a spatially sparse signal
-#     x_true = torch.zeros(10, 10)
-#     x_true[3:7, 3:7] = torch.ones(4, 4)
-#     x_true = x_true.flatten()
-#
-#     # generate a random matrix
-#     A = torch.rand(100, 100)
-#
-#     # generate measurements
-#     y = A @ x_true
-#
-#     # define the proximal operator
-#     alpha = 10
-#     l1_ratio = 0.0  # 0.5
-#     prox = GraphNet2D(alpha, l1_ratio)
-#
-#     # define the objective function
-#     def objective(x):
-#         return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2
-#
-#     # define the step size
-#     tau = 1 / torch.norm(A.t() @ A)
-#
-#     # initialize the solution
-#     x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True))
-#
-#     # optimizer
-#     optimizer = torch.optim.SGD([x], lr=tau)
-#
-#     # optimization loop
-#     for i in range(1000):
-#         optimizer.zero_grad()
-#         obj = objective(x) + prox(x)
-#         obj.backward()
-#         optimizer.step()
-#         x.data = prox.prox(x.data, tau)
-#
-#     # check that the result is smooth
-#     plt.imshow(x.data.detach().numpy())
-#     plt.show()
-#
-#     # compare with x_true
-#     difference = torch.norm(x.data.flatten() - x_true)
-#     assert difference < 1e-3
-#
-#     # check that the result is sparse
-#     assert torch.sum(x.data == 0) > 50
+def test_converges_to_sparse_smooth():
+    import matplotlib.pyplot as plt
+    torch.manual_seed(0)
+
+    # generate a spatially sparse signal
+    x_true = torch.zeros(10, 10)
+    x_true[3:7, 3:7] = torch.ones(4, 4)
+    x_true = x_true.flatten()
+
+    # generate a random matrix
+    A = torch.rand(100, 100)
+
+    # generate measurements
+    y = A @ x_true
+
+    # define the proximal operator
+    alpha = 1000
+    l1_ratio = 0.0  # 0.5
+    prox = GraphNet2D(alpha, l1_ratio)
+
+    # define the objective function
+    def objective(x):
+        return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2
+
+    # define the step size
+    tau = 0.1 / torch.norm(A.t() @ A)
+
+    # initialize the solution
+    x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True))
+
+    # optimizer
+    optimizer = torch.optim.SGD([x], lr=tau, nesterov=True, momentum=0.9)
+
+    # optimization loop
+    for i in range(20000):
+        optimizer.zero_grad()
+        p = prox(x)
+        obj = objective(x) + prox(x)
+        obj.backward()
+        optimizer.step()
+        x.data = prox.prox(x.data, tau)
+
+    # check that the result is smooth
+    plt.imshow(x.data.detach().numpy())
+    plt.show()
+
+    # compare with x_true
+    difference = torch.norm(x.data.flatten() - x_true)
+    assert difference < 1e-3
+
+    # check that the result is sparse
+    assert torch.sum(x.data == 0) > 50
 
 
 def test_graph_net_3d_prox():
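The re-enabled test uses the usual proximal-gradient split: autograd differentiates the data term plus the penalty value returned by prox(x), the optimizer takes the gradient step, and prox.prox then applies the non-smooth part to the updated iterate. A compact single-step sketch of that pattern, with objective and prox as placeholders for any differentiable loss and any ProxTorch-style operator exposing __call__ and .prox(x, tau):

import torch

def prox_grad_step(x: torch.Tensor, objective, prox, tau: float) -> torch.Tensor:
    # One iteration of the loop used in the test above, written without an optimizer.
    x = x.detach().clone().requires_grad_(True)
    loss = objective(x) + prox(x)      # differentiable terms only
    loss.backward()
    with torch.no_grad():
        x_new = x - tau * x.grad       # gradient step
    return prox.prox(x_new, tau)       # proximal step for the non-smooth part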
