Skip to content

Commit c58e295

Browse files
authored
Merge pull request #305 from astro-informatics/tk/decomposition-refactor-forward-backward
Refactor ImagingForwardBackward class
2 parents 9bf228a + 83ed89c commit c58e295

File tree

10 files changed

+337
-190
lines changed

10 files changed

+337
-190
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,4 @@ build/
7979
.settings
8080
python/tests/__pycache__
8181
*.h.gch
82+
*~

cpp/examples/forward_backward/inpainting.cc

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ctime>
88

99
#include <sopt/imaging_forward_backward.h>
10+
#include <sopt/l1_g_proximal.h>
1011
#include <sopt/logging.h>
1112
#include <sopt/maths.h>
1213
#include <sopt/relative_variation.h>
@@ -90,21 +91,28 @@ int main(int argc, char const **argv) {
9091
sopt::t_real const gamma = 18;
9192
sopt::t_real const beta = sigma * sigma * 0.5;
9293
SOPT_HIGH_LOG("Creating Foward Backward Functor");
93-
auto const fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
94-
.itermax(500)
95-
.beta(beta) // stepsize
96-
.sigma(sigma) // sigma
97-
.gamma(gamma) // regularisation paramater
98-
.relative_variation(1e-3)
99-
.residual_tolerance(0)
100-
.tight_frame(true)
101-
.l1_proximal_tolerance(1e-4)
102-
.l1_proximal_nu(1)
103-
.l1_proximal_itermax(50)
104-
.l1_proximal_positivity_constraint(true)
105-
.l1_proximal_real_constraint(true)
106-
.Psi(psi)
107-
.Phi(sampling);
94+
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
95+
.itermax(500)
96+
.beta(beta) // stepsize
97+
.sigma(sigma) // sigma
98+
.gamma(gamma) // regularisation paramater
99+
.relative_variation(1e-3)
100+
.residual_tolerance(0)
101+
.tight_frame(true)
102+
.Phi(sampling);
103+
104+
// Create a shared pointer to an instance of the L1GProximal class
105+
// and set its properties
106+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
107+
gp->l1_proximal_tolerance(1e-4)
108+
.l1_proximal_nu(1)
109+
.l1_proximal_itermax(50)
110+
.l1_proximal_positivity_constraint(true)
111+
.l1_proximal_real_constraint(true)
112+
.Psi(psi);
113+
114+
// Once the properties are set, inject it into the ImagingForwardBackward object
115+
fb.g_proximal(gp);
108116

109117
SOPT_HIGH_LOG("Starting Forward Backward");
110118
// Alternatively, forward-backward can be called with a tuple (x, residual) as argument

cpp/examples/forward_backward/inpainting_credible_interval.cc

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <sopt/credible_region.h>
1010
#include <sopt/imaging_forward_backward.h>
11+
#include <sopt/l1_g_proximal.h>
1112
#include <sopt/logging.h>
1213
#include <sopt/maths.h>
1314
#include <sopt/relative_variation.h>
@@ -92,22 +93,29 @@ int main(int argc, char const **argv) {
9293
sopt::t_real const gamma = 18;
9394
sopt::t_real const beta = sigma * sigma;
9495
SOPT_HIGH_LOG("Creating Foward Backward Functor");
95-
auto const fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
96-
.itermax(500)
97-
.beta(beta)
98-
.sigma(sigma)
99-
.gamma(gamma)
100-
.relative_variation(5e-4)
101-
.residual_tolerance(0)
102-
.tight_frame(true)
103-
.l1_proximal_tolerance(1e-4)
104-
.l1_proximal_nu(1)
105-
.l1_proximal_itermax(50)
106-
.l1_proximal_positivity_constraint(true)
107-
.l1_proximal_real_constraint(true)
108-
.Psi(psi)
109-
.Phi(sampling);
96+
auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(y)
97+
.itermax(500)
98+
.beta(beta)
99+
.sigma(sigma)
100+
.gamma(gamma)
101+
.relative_variation(5e-4)
102+
.residual_tolerance(0)
103+
.tight_frame(true)
104+
.Phi(sampling);
110105

106+
// Create a shared pointer to an instance of the L1GProximal class
107+
// and set its properties
108+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
109+
gp->l1_proximal_tolerance(1e-4)
110+
.l1_proximal_nu(1)
111+
.l1_proximal_itermax(50)
112+
.l1_proximal_positivity_constraint(true)
113+
.l1_proximal_real_constraint(true)
114+
.Psi(psi);
115+
116+
// Once the properties are set, inject it into the ImagingForwardBackward object
117+
fb.g_proximal(gp);
118+
111119
SOPT_HIGH_LOG("Starting Forward Backward");
112120
// Alternatively, forward-backward can be called with a tuple (x, residual) as argument
113121
// Here, we default to (Φ^Ty/ν, ΦΦ^Ty/ν - y)

cpp/examples/forward_backward/inpainting_joint_map.cc

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ctime>
88

99
#include <sopt/imaging_forward_backward.h>
10+
#include <sopt/l1_g_proximal.h>
1011
#include <sopt/joint_map.h>
1112
#include <sopt/logging.h>
1213
#include <sopt/maths.h>
@@ -91,21 +92,28 @@ int main(int argc, char const **argv) {
9192
sopt::t_real const gamma = 0;
9293
sopt::t_real const beta = sigma * sigma * 0.5;
9394
SOPT_HIGH_LOG("Creating Foward Backward Functor");
94-
auto const fb = std::make_shared<sopt::algorithm::ImagingForwardBackward<Scalar>>(y);
95+
auto fb = std::make_shared<sopt::algorithm::ImagingForwardBackward<Scalar>>(y);
9596
fb->itermax(500)
96-
.beta(beta) // stepsize
97-
.sigma(sigma) // sigma
98-
.gamma(gamma) // regularisation paramater
99-
.relative_variation(1e-3)
100-
.residual_tolerance(0)
101-
.tight_frame(true)
102-
.l1_proximal_tolerance(1e-5)
103-
.l1_proximal_nu(1)
104-
.l1_proximal_itermax(50)
105-
.l1_proximal_positivity_constraint(true)
106-
.l1_proximal_real_constraint(true)
107-
.Psi(psi)
108-
.Phi(sampling);
97+
.beta(beta) // stepsize
98+
.sigma(sigma) // sigma
99+
.gamma(gamma) // regularisation paramater
100+
.relative_variation(1e-3)
101+
.residual_tolerance(0)
102+
.tight_frame(true)
103+
.Phi(sampling);
104+
105+
// Create a shared pointer to an instance of the L1GProximal class
106+
// and set its properties
107+
auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
108+
gp->l1_proximal_tolerance(1e-4)
109+
.l1_proximal_nu(1)
110+
.l1_proximal_itermax(50)
111+
.l1_proximal_positivity_constraint(true)
112+
.l1_proximal_real_constraint(true)
113+
.Psi(psi);
114+
115+
// Once the properties are set, inject it into the ImagingForwardBackward object
116+
fb->g_proximal(gp);
109117

110118
SOPT_HIGH_LOG("Starting Forward Backward");
111119
// Alternatively, forward-backward can be called with a tuple (x, residual) as argument

cpp/sopt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ set(headers
33
bisection_method.h chained_operators.h credible_region.h
44
imaging_padmm.h logging.disabled.h
55
forward_backward.h imaging_forward_backward.h
6+
g_proximal.h l1_g_proximal.h
67
joint_map.h
78
imaging_primal_dual.h primal_dual.h
89
maths.h proximal.h relative_variation.h sdmm.h

cpp/sopt/g_proximal.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef SOPT_G_PROXIMAL_H
2+
#define SOPT_G_PROXIMAL_H
3+
4+
#include <numeric>
5+
#include <tuple>
6+
#include <utility>
7+
#include "sopt/forward_backward.h"
8+
9+
// Abstract base class providing the interface to the g_proximal function
10+
template <class SCALAR> class GProximal {
11+
12+
typedef sopt::algorithm::ForwardBackward<SCALAR> FB;
13+
typedef typename FB::Real Real;
14+
typedef typename FB::t_Vector t_Vector;
15+
typedef typename FB::t_Proximal t_Proximal;
16+
typedef typename FB::t_LinearTransform t_LinearTransform;
17+
18+
public:
19+
20+
// A function that prints a log message
21+
virtual void log_message() const = 0;
22+
// A function that returns a function for the g_proximal.
23+
// Function must be of type t_Proximal, that is
24+
// void proximal_function(Vector, real, Vector)
25+
virtual t_Proximal proximal_function() const = 0;
26+
// Returns the norm of x
27+
virtual Real proximal_norm(t_Vector const &x) const = 0;
28+
// Transforms input image to a different basis.
29+
// Return linear_transform_identity() if transform not necessary.
30+
virtual const t_LinearTransform &Psi() const = 0;
31+
}; // class GProximal
32+
#endif

0 commit comments

Comments
 (0)