|
3 | 3 | from proxtorch.operators import GraphNet3D, GraphNet2D
|
4 | 4 |
|
5 | 5 |
|
6 |
| -# |
7 |
| -# def test_converges_to_sparse_smooth(): |
8 |
| -# torch.manual_seed(0) |
9 |
| -# |
10 |
| -# # generate a spatially sparse signal |
11 |
| -# x_true = torch.zeros(10, 10) |
12 |
| -# x_true[3:7, 3:7] = torch.ones(4, 4) |
13 |
| -# x_true = x_true.flatten() |
14 |
| -# |
15 |
| -# # generate a random matrix |
16 |
| -# A = torch.rand(100, 100) |
17 |
| -# |
18 |
| -# # generate measurements |
19 |
| -# y = A @ x_true |
20 |
| -# |
21 |
| -# # define the proximal operator |
22 |
| -# alpha = 10 |
23 |
| -# l1_ratio = 0.0 # 0.5 |
24 |
| -# prox = GraphNet2D(alpha, l1_ratio) |
25 |
| -# |
26 |
| -# # define the objective function |
27 |
| -# def objective(x): |
28 |
| -# return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2 |
29 |
| -# |
30 |
| -# # define the step size |
31 |
| -# tau = 1 / torch.norm(A.t() @ A) |
32 |
| -# |
33 |
| -# # initialize the solution |
34 |
| -# x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True)) |
35 |
| -# |
36 |
| -# # optimizer |
37 |
| -# optimizer = torch.optim.SGD([x], lr=tau) |
38 |
| -# |
39 |
| -# # optimization loop |
40 |
| -# for i in range(1000): |
41 |
| -# optimizer.zero_grad() |
42 |
| -# obj = objective(x) + prox(x) |
43 |
| -# obj.backward() |
44 |
| -# optimizer.step() |
45 |
| -# x.data = prox.prox(x.data, tau) |
46 |
| -# |
47 |
| -# # check that the result is smooth |
48 |
| -# plt.imshow(x.data.detach().numpy()) |
49 |
| -# plt.show() |
50 |
| -# |
51 |
| -# # compare with x_true |
52 |
| -# difference = torch.norm(x.data.flatten() - x_true) |
53 |
| -# assert difference < 1e-3 |
54 |
| -# |
55 |
| -# # check that the result is sparse |
56 |
| -# assert torch.sum(x.data == 0) > 50 |
| 6 | +def test_converges_to_sparse_smooth(): |
| 7 | + import matplotlib.pyplot as plt |
| 8 | + torch.manual_seed(0) |
| 9 | + |
| 10 | + # generate a spatially sparse signal |
| 11 | + x_true = torch.zeros(10, 10) |
| 12 | + x_true[3:7, 3:7] = torch.ones(4, 4) |
| 13 | + x_true = x_true.flatten() |
| 14 | + |
| 15 | + # generate a random matrix |
| 16 | + A = torch.rand(100, 100) |
| 17 | + |
| 18 | + # generate measurements |
| 19 | + y = A @ x_true |
| 20 | + |
| 21 | + # define the proximal operator |
| 22 | + alpha = 1000 |
| 23 | + l1_ratio = 0.0 # 0.5 |
| 24 | + prox = GraphNet2D(alpha, l1_ratio) |
| 25 | + |
| 26 | + # define the objective function |
| 27 | + def objective(x): |
| 28 | + return 0.5 * torch.norm(A @ x.reshape(-1) - y) ** 2 |
| 29 | + |
| 30 | + # define the step size |
| 31 | + tau = 0.1 / torch.norm(A.t() @ A) |
| 32 | + |
| 33 | + # initialize the solution |
| 34 | + x = torch.nn.Parameter(torch.rand(10, 10, requires_grad=True)) |
| 35 | + |
| 36 | + # optimizer |
| 37 | + optimizer = torch.optim.SGD([x], lr=tau, nesterov=True, momentum=0.9) |
| 38 | + |
| 39 | + # optimization loop |
| 40 | + for i in range(20000): |
| 41 | + optimizer.zero_grad() |
| 42 | + p=prox(x) |
| 43 | + obj = objective(x) + prox(x) |
| 44 | + obj.backward() |
| 45 | + optimizer.step() |
| 46 | + x.data = prox.prox(x.data, tau) |
| 47 | + |
| 48 | + # check that the result is smooth |
| 49 | + plt.imshow(x.data.detach().numpy()) |
| 50 | + plt.show() |
| 51 | + |
| 52 | + # compare with x_true |
| 53 | + difference = torch.norm(x.data.flatten() - x_true) |
| 54 | + assert difference < 1e-3 |
| 55 | + |
| 56 | + # check that the result is sparse |
| 57 | + assert torch.sum(x.data == 0) > 50 |
57 | 58 |
|
58 | 59 |
|
59 | 60 | def test_graph_net_3d_prox():
|
|
0 commit comments