12 changes: 12 additions & 0 deletions lib/OptimizationManopt/Project.toml
@@ -0,0 +1,12 @@
name = "OptimizationManopt"
uuid = "e57b7fff-7ee7-4550-b4f0-90e9476e9fb6"
authors = ["Mateusz Baran <mateuszbaran89@gmail.com>"]
version = "0.1.0"

[deps]
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
307 changes: 307 additions & 0 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
@@ -0,0 +1,307 @@
module OptimizationManopt

using Reexport
@reexport using Manopt
using Optimization, ManifoldsBase, ManifoldDiff

"""
abstract type AbstractManoptOptimizer end

A Manopt solver configuration, omitting anything that is specified later in the call to `solve` (such as stopping criteria) and any internal solver state.
"""
abstract type AbstractManoptOptimizer end

function stopping_criterion_to_kwarg(stopping_criterion::Nothing)
return NamedTuple()
end
function stopping_criterion_to_kwarg(stopping_criterion::StoppingCriterion)
return (; stopping_criterion = stopping_criterion)
end

## gradient descent

struct GradientDescentOptimizer{
Teval <: AbstractEvaluationType,
TM <: AbstractManifold,
TLS <: Linesearch
} <: AbstractManoptOptimizer
M::TM
stepsize::TLS
end

function GradientDescentOptimizer(M::AbstractManifold;
eval::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
stepsize::Stepsize = ArmijoLinesearch(M))
GradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M, stepsize)
end

function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
loss,
gradF,
x0,
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
opts = gradient_descent(opt.M,
loss,
gradF,
x0;
return_state = true,
evaluation = Teval(),
stepsize = opt.stepsize,
sckwarg...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
:who_knows
end

## Nelder-Mead

struct NelderMeadOptimizer{
TM <: AbstractManifold,
} <: AbstractManoptOptimizer
M::TM
end

function call_manopt_optimizer(opt::NelderMeadOptimizer,
loss,
gradF,
x0,
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion})
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)

opts = NelderMead(opt.M,
loss;
return_state = true,
sckwarg...)
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
:who_knows
end

## conjugate gradient descent

struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType,
TM <: AbstractManifold, TLS <: Stepsize} <:
AbstractManoptOptimizer
M::TM
stepsize::TLS
end

function ConjugateGradientDescentOptimizer(M::AbstractManifold;
eval::AbstractEvaluationType = InplaceEvaluation(),
stepsize::Stepsize = ArmijoLinesearch(M))
ConjugateGradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M,
stepsize)
end

function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
loss,
gradF,
x0,
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
opts = conjugate_gradient_descent(opt.M,
loss,
gradF,
x0;
return_state = true,
evaluation = Teval(),
stepsize = opt.stepsize,
sckwarg...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
:who_knows
end

## particle swarm

struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
Tinvretr <: AbstractInverseRetractionMethod,
Tvt <: AbstractVectorTransportMethod} <:
AbstractManoptOptimizer
M::TM
retraction_method::Tretr
inverse_retraction_method::Tinvretr
vector_transport_method::Tvt
population_size::Int
end

function ParticleSwarmOptimizer(M::AbstractManifold;
eval::AbstractEvaluationType = InplaceEvaluation(),
population_size::Int = 100,
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M))
ParticleSwarmOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
typeof(inverse_retraction_method),
typeof(vector_transport_method)}(M,
retraction_method,
inverse_retraction_method,
vector_transport_method,
population_size)
end

function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
loss,
gradF,
x0,
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
initial_population = vcat([x0], [rand(opt.M) for _ in 1:(opt.population_size - 1)])
opts = particle_swarm(opt.M,
loss;
x0 = initial_population,
n = opt.population_size,
return_state = true,
retraction_method = opt.retraction_method,
inverse_retraction_method = opt.inverse_retraction_method,
vector_transport_method = opt.vector_transport_method,
sckwarg...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
:who_knows
@mateuszbaran mateuszbaran Apr 3, 2024

You can get the retcode by calling

asc = get_active_stopping_criteria(o.stop)

Then `asc` is a vector of the stopping criteria that triggered. For example, if it is a one-element vector containing a `StopAfterIteration` object, we know the solver hit the iteration limit. If you want a quick "does it look like it converged?" check, you can use

any(Manopt.indicates_convergence, asc)
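
A minimal sketch of how that could be mapped onto a SciML return code, assuming the solver state exposes its stopping criterion as a `stop` field as above (the helper name is hypothetical):

function manopt_retcode(state)
    # collect the stopping criteria that actually fired
    asc = get_active_stopping_criteria(state.stop)
    # any criterion that indicates convergence -> report success
    any(Manopt.indicates_convergence, asc) && return SciMLBase.ReturnCode.Success
    # only non-convergent criteria such as the iteration limit fired
    any(c -> c isa StopAfterIteration, asc) && return SciMLBase.ReturnCode.MaxIters
    return SciMLBase.ReturnCode.Failure
end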

end

## quasi Newton

struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
Tvt <: AbstractVectorTransportMethod, TLS <: Stepsize} <:
AbstractManoptOptimizer
M::TM
retraction_method::Tretr
vector_transport_method::Tvt
stepsize::TLS
end

function QuasiNewtonOptimizer(M::AbstractManifold;
eval::AbstractEvaluationType = InplaceEvaluation(),
retraction_method::AbstractRetractionMethod = default_retraction_method(M),
vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
stepsize = WolfePowellLinesearch(M;
retraction_method = retraction_method,
vector_transport_method = vector_transport_method,
linesearch_stopsize = 1e-12))
QuasiNewtonOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
typeof(vector_transport_method), typeof(stepsize)}(M,
retraction_method,
vector_transport_method,
stepsize)
end

function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
loss,
gradF,
x0,
stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
opts = quasi_Newton(opt.M,
loss,
gradF,
x0;
return_state = true,
evaluation = Teval(),
retraction_method = opt.retraction_method,
vector_transport_method = opt.vector_transport_method,
stepsize = opt.stepsize,
sckwarg...)
# we unwrap DebugOptions here
minimizer = Manopt.get_solver_result(opts)
return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
:who_knows
end

## Optimization.jl stuff

function build_loss(f::OptimizationFunction, prob)
function (::AbstractManifold, θ)
x = f.f(θ)
__x = first(x)
return prob.sense === Optimization.MaxSense ? -__x : __x
end
end

function build_gradF(f::OptimizationFunction{true}, prob, cur)
function g(M::AbstractManifold, G, θ)
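# compute the Euclidean gradient in-place, then convert it to the Riemannian gradient on M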
f.grad(G, θ, cur...)
G .= riemannian_gradient(M, θ, G)
end
function g(M::AbstractManifold, θ)
G = zero(θ)
f.grad(G, θ, cur...)
return riemannian_gradient(M, θ, G)
end
end

# TODO:
# 1) convert tolerances and other stopping criteria
# 2) return convergence information
# 3) add callbacks to Manopt.jl

function SciMLBase.__solve(prob::OptimizationProblem,
opt::AbstractManoptOptimizer,
data = Optimization.DEFAULT_DATA;
callback = (args...) -> (false),
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,


I think you can translate `abstol` into `StopWhenChangeLess`. `reltol` is really hard to define generically, so we don't use it.
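
For instance, a rough sketch of that translation (the helper is hypothetical, and the exact `StopWhenChangeLess` constructor may differ between Manopt versions):

function build_stopping_criterion(maxiters, abstol)
    sc = nothing
    # iteration limit, if requested
    maxiters !== nothing && (sc = StopAfterIteration(maxiters))
    # absolute tolerance on the change between iterates
    if abstol !== nothing
        sc_change = StopWhenChangeLess(abstol)
        sc = sc === nothing ? sc_change : (sc | sc_change)
    end
    return sc
end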

reltol::Union{Number, Nothing} = nothing,
progress = false,
kwargs...)
local x, cur, state

manifold = haskey(prob.kwargs, :manifold) ? prob.kwargs[:manifold] : nothing

if manifold === nothing || manifold != opt.M
throw(ArgumentError("The problem must specify a manifold that matches the optimizer's manifold `$(opt.M)`, e.g. `OptimizationProblem(f, x, p; manifold = SymmetricPositiveDefinite(5))`."))
end

if data !== Optimization.DEFAULT_DATA
maxiters = length(data)
end

cur, state = iterate(data)

stopping_criterion = nothing
if maxiters !== nothing
stopping_criterion = StopAfterIteration(maxiters)
end

maxiters = Optimization._check_and_convert_maxiters(maxiters)
maxtime = Optimization._check_and_convert_maxtime(maxtime)

f = Optimization.instantiate_function(prob.f, prob.u0, prob.f.adtype, prob.p)

_loss = build_loss(f, prob)

gradF = build_gradF(f, prob, cur)

opt_res, opt_ret = call_manopt_optimizer(opt, _loss, gradF, prob.u0, stopping_criterion)

I would change this interface, for example, to

Suggested change
-opt_res, opt_ret = call_manopt_optimizer(opt, _loss, gradF, prob.u0, stopping_criterion)
+opt_res, opt_ret = call_manopt_optimizer(opt, prob.M, _loss, gradF, prob.u0, stopping_criterion)

remove the manifold from all optimisers and be happy :)
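
A rough sketch of what that refactor could look like for one of the solvers above (illustrative only, keeping the current return-value convention):

function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval}, M::AbstractManifold,
        loss, gradF, x0,
        stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    # the manifold now comes from the problem instead of the optimizer struct
    opts = gradient_descent(M, loss, gradF, x0;
        return_state = true, evaluation = Teval(),
        stepsize = opt.stepsize, sckwarg...)
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(M, minimizer), options = opts), :who_knows
end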

Member Author

Oh yeah, good point. Will just remove it.


return SciMLBase.build_solution(SciMLBase.DefaultOptimizationCache(prob.f, prob.p),
opt,
opt_res.minimizer,
prob.sense === Optimization.MaxSense ?
-opt_res.minimum : opt_res.minimum;
original = opt_res.options,
retcode = opt_ret)
end

end # module OptimizationManopt
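
For reference, a minimal usage sketch of the intended workflow (illustrative only; `Sphere` comes from Manifolds.jl, the objective is made up, and the `manifold` keyword is the one read from `prob.kwargs` in `__solve` above):

using Optimization, Manifolds
using OptimizationManopt

# minimize a quadratic form over the unit sphere (a Rayleigh-quotient-style problem)
A = [2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0]
f(x, p) = x' * A * x

M = Sphere(2)
x0 = [1.0, 0.0, 0.0]

optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, x0; manifold = M)

# the optimizer is constructed on the same manifold the problem lives on
sol = solve(prob, GradientDescentOptimizer(M); maxiters = 100)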