diff --git a/cpp/sopt/CMakeLists.txt b/cpp/sopt/CMakeLists.txt
index e782d6ea4..d2aa5980f 100644
--- a/cpp/sopt/CMakeLists.txt
+++ b/cpp/sopt/CMakeLists.txt
@@ -2,6 +2,7 @@ set(headers
   bisection_method.h chained_operators.h credible_region.h
   imaging_padmm.h logging.disabled.h
+  gradient_descent.h
   forward_backward.h imaging_forward_backward.h
   g_proximal.h l1_g_proximal.h
   joint_map.h imaging_primal_dual.h
   primal_dual.h
diff --git a/cpp/sopt/gradient_descent.h b/cpp/sopt/gradient_descent.h
new file mode 100644
index 000000000..1f382773d
--- /dev/null
+++ b/cpp/sopt/gradient_descent.h
@@ -0,0 +1,130 @@
+#ifndef SOPT_GRADIENT_DESCENT_H
+#define SOPT_GRADIENT_DESCENT_H
+
+#include <cmath>
+#include <functional>
+#include "sopt/linear_transform.h"
+#include "sopt/types.h"
+
+namespace sopt::algorithm {
+
+  //! Values indicating how the algorithm ran
+  template <typename SCALAR>
+  struct AlgorithmResults {
+    //! Number of iterations
+    t_uint niters;
+    //! Whether convergence was achieved
+    bool good;
+    //! The residual from the last iteration
+    Vector<SCALAR> residual;
+    Vector<SCALAR> result;
+  };
+
+//! \brief Gradient descent algorithm with Nesterov-type momentum
+//! \details Requires \f$\nabla f\f$ and \f$\nabla g\f$ to be analytically defined:
+//! \f$x_{n+1} = x_n - \alpha \left( \Re\left( \nabla f(x_n, y) \right) + \lambda \nabla g(\mu x_n) \right)\f$
+//! \param f_gradient: Gradient function for f, where f is usually a likelihood. Takes two arguments (x, y).
+//! \param g_gradient: Gradient function for g, where g is usually a prior / regulariser. Takes one argument (x).
+//! \param lambda: Multiplier for the g gradient function
+//! \param Lipschitz_f: Lipschitz constant of function f (used to calculate alpha)
+//! \param Lipschitz_g: Lipschitz constant of function g (used to calculate alpha)
+//! \param mu: Scaling parameter for the vector inside g. Also used to calculate alpha
+template <typename SCALAR>
+class GradientDescent
+{
+ public:
+  using F_Gradient =
+      typename std::function<Vector<SCALAR>(const Vector<SCALAR> &, const Vector<SCALAR> &)>;
+  using G_Gradient = typename std::function<Vector<SCALAR>(const Vector<SCALAR> &)>;
+  using REAL = typename real_type<SCALAR>::type;
+
+  GradientDescent(F_Gradient const &f_gradient,
+                  G_Gradient const &g_gradient,
+                  Vector<SCALAR> const &target,
+                  REAL const threshold,
+                  REAL const Lipschitz_f = 1,
+                  REAL const Lipschitz_g = 1,
+                  REAL const mu = 1,
+                  REAL const lambda = 1)
+      : Phi(linear_transform_identity<SCALAR>()),
+        f_gradient(f_gradient),
+        g_gradient(g_gradient),
+        lambda(lambda),
+        mu(mu),
+        Lipschitz_f(Lipschitz_f),
+        Lipschitz_g(Lipschitz_g),
+        target(target),
+        threshold_delta(threshold)
+  {
+    alpha = 0.98 / (Lipschitz_f + mu * lambda * Lipschitz_g);
+  }
+
+  AlgorithmResults<SCALAR> operator()(Vector<SCALAR> &x)
+  {
+    Vector<SCALAR> z = x;
+    bool converged = false;
+    t_uint iterations = 0;
+    while ((!converged) && (iterations < max_iterations))
+    {
+      iteration_step(x, z);
+
+      converged = is_converged(x);
+
+      ++iterations;
+    }
+
+    if(converged)
+    {
+      // TODO: Log some relevant stuff about the convergence.
+    }
+
+    AlgorithmResults<SCALAR> results;
+    results.good = converged;
+    results.niters = iterations;
+    results.residual = (Phi * x) - target;
+    results.result = z;
+
+    return results;
+  }
+
+ protected:
+  LinearTransform<Vector<SCALAR>> Phi;
+  F_Gradient f_gradient;
+  G_Gradient g_gradient;
+  REAL alpha;
+  REAL lambda = 1;
+  REAL mu = 1;
+  REAL Lipschitz_f = 1;
+  REAL Lipschitz_g = 1;
+  Vector<SCALAR> target;
+  REAL threshold_delta;
+  Vector<SCALAR> delta_x;
+  REAL theta_now = 1;  // Momentum parameter; must start at 1.
+  REAL theta_next;
+  Vector<SCALAR> x_prev;
+  t_uint max_iterations = 200;
+
+  void iteration_step(Vector<SCALAR> &x, Vector<SCALAR> &z)
+  {
+    // Should be able to make this better to avoid copies
+    x_prev = x;
+
+    delta_x = f_gradient(z, target).real();
+    delta_x += lambda * g_gradient(mu * z);
+    delta_x *= alpha;
+
+    theta_next = 0.5 * (1 + std::sqrt(1 + 4 * theta_now * theta_now));
+
+    x = z - delta_x;
+    z = x + ((theta_now - 1) / theta_next) * (x - x_prev);
+    theta_now = theta_next;  // Advance the momentum sequence.
+  }
+
+  bool is_converged(Vector<SCALAR> &x)
+  {
+    return (delta_x.norm() / x.norm()) < threshold_delta;
+  }
+};
+
+} // namespace sopt::algorithm
+
+#endif // SOPT_GRADIENT_DESCENT_H
\ No newline at end of file
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index 61d06db59..d7c84ce69 100644
--- a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -24,6 +24,7 @@ add_catch_test(chained_operators LIBRARIES sopt SEED ${RAND_SEED})
 add_catch_test(conjugate_gradient LIBRARIES sopt SEED ${RAND_SEED})
 add_catch_test(credible_region LIBRARIES sopt SEED ${RAND_SEED})
 add_catch_test(forward_backward LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
+add_catch_test(gradient_descent LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
 add_catch_test(gradient_operator LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
 add_catch_test(inpainting LIBRARIES sopt tools_for_tests SEED ${RAND_SEED})
 add_catch_test(linear_transform LIBRARIES sopt SEED ${RAND_SEED})
diff --git a/cpp/tests/gradient_descent.cc b/cpp/tests/gradient_descent.cc
new file mode 100644
index 000000000..7b47a3004
--- /dev/null
+++ b/cpp/tests/gradient_descent.cc
@@ -0,0 +1,80 @@
+#include <catch2/catch_all.hpp>
+#include "sopt/gradient_descent.h"
+#include <cmath>
+#include <random>
+
+uint constexpr N = 10;
+
+TEST_CASE("Gradient Descent with flat prior", "[GradDescent]")
+{
+  using namespace sopt;
+
+  const Vector<float> target = Vector<float>::Random(N);
+  float const sigma = 0.5;
+  float const gamma = 0.1;
+  uint const max_iterations = 100;
+
+  auto const grad_likelihood = [](const Vector<float> &x, const Vector<float> &y){return (x - y);};
+  auto const grad_prior = [](const Vector<float> &x){return 0 * x;};
+
+  Vector<float> init_guess = Vector<float>::Random(N);
+
+  auto Phi = linear_transform_identity<float>();
+
+  algorithm::GradientDescent<float> gd(grad_likelihood, grad_prior, target, 1e-4);
+
+  algorithm::AlgorithmResults<float> results = gd(init_guess);
+
+  CHECK(results.good);
+  CHECK(results.result.isApprox(target, 0.1));
+}
+
+TEST_CASE("Gradient Descent with smoothness prior", "[GradDescent]")
+{
+  using namespace sopt;
+  std::mt19937_64 rng;
+  std::uniform_real_distribution<float> noise(0, 0.2);
+
+  Vector<float> data(N);
+  for(uint i = 0; i < N; i++)
+  {
+    data(i) = std::sin((M_PI / (N - 1)) * i) + noise(rng);
+  }
+
+  Vector<float> perfect(N);
+  for(uint i = 0; i < N; i++)
+  {
+    perfect(i) = std::sin((M_PI / (N - 1)) * i);
+  }
+
+  float const sigma = 0.5;
+  float const gamma = 0.1;
+  uint const max_iterations = 100;
+
+  auto const grad_likelihood = [](const Vector<float> &x, const Vector<float> &y){return (x - y);};
+  auto const grad_prior = [](const Vector<float> &x)
+  {
+    Vector<float> grad(x.size());
+    grad(0) = x(0);
+    grad(x.size() - 1) = x(x.size() - 1);
+    for(uint i = 1; i < x.size() - 1; i++)
+    {
+      // Push values to be roughly in line with neighbours
+      // Hand-wavy kind of smoothness prior
+      grad(i) = x(i) - 0.5 * (x(i - 1) + x(i + 1));
+    }
+    return grad;
+  };
+
+  Vector<float> init_guess = Vector<float>::Random(N);
+
+  auto Phi = linear_transform_identity<float>();
+
+  algorithm::GradientDescent<float> gd(grad_likelihood, grad_prior, data, 1e-4);
+
+  algorithm::AlgorithmResults<float> results = gd(init_guess);
+
+  CHECK(results.good);
+  CHECK(results.result.isApprox(perfect, 0.1));
+  CHECK((results.result - perfect).squaredNorm() < (data - perfect).squaredNorm());
+}
\ No newline at end of file
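For anyone reviewing by hand, here is a minimal usage sketch distilled from the flat-prior test above. It is illustrative only: the `main()` harness, the scalar type `float`, and the quadratic likelihood f(x, y) = 0.5 * ||x - y||^2 are assumptions made for the demo, not part of the API this diff adds.

// Sketch only: drives sopt::algorithm::GradientDescent as added in this diff.
// Assumes Eigen and the sopt headers are on the include path.
#include <iostream>
#include "sopt/gradient_descent.h"

int main()
{
  using namespace sopt;
  constexpr unsigned N = 10;

  // Illustrative data: recover `target` from a random starting point.
  const Vector<float> target = Vector<float>::Random(N);

  // Gradient of the quadratic likelihood is simply (x - y).
  auto const grad_likelihood = [](const Vector<float> &x, const Vector<float> &y) {
    return Vector<float>(x - y);
  };
  // Flat prior: zero gradient, so it never perturbs the update.
  auto const grad_prior = [](const Vector<float> &x) {
    return Vector<float>(Vector<float>::Zero(x.size()));
  };

  Vector<float> guess = Vector<float>::Random(N);

  // Converges once ||delta_x|| / ||x|| < 1e-4, or gives up after the
  // default 200 iterations.
  algorithm::GradientDescent<float> gd(grad_likelihood, grad_prior, target, 1e-4);
  algorithm::AlgorithmResults<float> results = gd(guess);

  std::cout << "converged: " << std::boolalpha << results.good
            << " after " << results.niters << " iterations\n";
  return results.good ? 0 : 1;
}

One behavioural detail worth flagging for review: operator() mutates its argument, so `guess` ends up holding the plain gradient-step iterate x, while `results.result` carries the momentum-corrected iterate z, and `results.residual` is computed against x rather than z.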