Skip to content

Commit 10c88cc

Browse files
committed
feat: add InfiniteOptControlProblem
1 parent 46ae9e5 commit 10c88cc

File tree

4 files changed

+39
-32
lines changed

4 files changed

+39
-32
lines changed

ext/MTKJuMPControlExt.jl

+25-27
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct JuMPControlProblem{uType, tType, isinplace, P, F, K} <: AbstractOptimalCo
2020
end
2121
end
2222

23-
struct InfiniteOptControlProblem{uType, tType, isinplace, P, F, K} <: SciMLBase.AbstractODEProblem{uType, tType, isinplace}
23+
struct InfiniteOptControlProblem{uType, tType, isinplace, P, F, K} <: AbstractOptimalControlProblem{uType, tType, isinplace}
2424
f::F
2525
u0::uType
2626
tspan::tType
@@ -49,16 +49,10 @@ The constraints are:
4949
- The solver constraints that encode the time-stepping used by the solver
5050
"""
5151
function MTK.JuMPControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for JuMPControlProblem."), guesses = Dict(), kwargs...)
52-
constraintsys = MTK.get_constraintsystem(sys)
53-
if !isnothing(constraintsys)
54-
(length(constraints(constraintsys)) + length(u0map) > length(states)) &&
55-
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
56-
end
57-
5852
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
5953
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
6054
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
61-
model = init_model(sys, tspan[1]:dt:tspan[2], u0map)
55+
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, u0)
6256

6357
JuMPControlProblem(f, u0, tspan, p, model, kwargs...)
6458
end
@@ -74,25 +68,24 @@ Related to `JuMPControlProblem`, but directly adds the differential equations
7468
of the system as derivative constraints, rather than using a solver tableau.
7569
"""
7670
function MTK.InfiniteOptControlProblem(sys::ODESystem, u0map, tspan, pmap; dt = error("dt must be provided for InfiniteOptControlProblem."), guesses = Dict(), kwargs...)
77-
constraintsys = MTK.get_constraintsystem(sys)
78-
if !isnothing(constraintsys)
79-
(length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) &&
80-
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
81-
end
82-
8371
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
8472
f, u0, p = MTK.process_SciMLProblem(ODEFunction, sys, _u0map, pmap;
8573
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
8674

87-
model = init_model(sys, tspan[1]:dt:tspan[2], u0map)
75+
model = init_model(sys, tspan[1]:dt:tspan[2], u0map, u0)
8876
add_infopt_solve_constraints!(model, sys, pmap)
8977
InfiniteOptControlProblem(f, u0, tspan, p, model, kwargs...)
9078
end
9179

92-
function init_model(sys, tsteps, u0map)
80+
function init_model(sys, tsteps, u0map, u0)
81+
constraintsys = MTK.get_constraintsystem(sys)
82+
if !isnothing(constraintsys)
83+
(length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) &&
84+
@warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization."
85+
end
86+
9387
ctrls = controls(sys)
9488
states = unknowns(sys)
95-
9689
model = InfiniteModel()
9790
@infinite_parameter(model, t in [tsteps[1], tsteps[end]], num_supports = length(tsteps))
9891
@variable(model, U[i = 1:length(states)], Infinite(t))
@@ -103,13 +96,14 @@ function init_model(sys, tsteps, u0map)
10396

10497
stidxmap = Dict([v => i for (i, v) in enumerate(states)])
10598
u0_idxs = has_alg_eqs(sys) ? collect(1:length(states)) : [stidxmap[k] for (k, v) in u0map]
106-
add_initial_constraints!(model, u0, u0_idxs, tspan)
99+
add_initial_constraints!(model, u0, u0_idxs, tsteps[1])
100+
return model
107101
end
108102

109103
function add_jump_cost_function!(model, sys)
110104
jcosts = MTK.get_costs(sys)
111105
consolidate = MTK.get_consolidate(sys)
112-
if isnothing(jcosts)
106+
if isnothing(jcosts) || isempty(jcosts)
113107
@objective(model, Min, 0)
114108
return
115109
end
@@ -173,28 +167,32 @@ function add_user_constraints!(model, sys)
173167
end
174168
end
175169

176-
function add_initial_constraints!(model, u0, u0_idxs, tspan)
177-
ts = tspan[1]
170+
function add_initial_constraints!(model, u0, u0_idxs, ts)
178171
U = model[:U]
179172
@constraint(model, initial[i in u0_idxs], U[i](ts) == u0[i])
180173
end
181174

182175
is_explicit(tableau) = tableau isa DiffEqDevTools.ExplicitRKTableau
183176

184177
function add_infopt_solve_constraints!(model, sys, pmap)
185-
iv = get_iv(sys)
178+
iv = MTK.get_iv(sys)
186179
t = model[:t]
187180
U = model[:U]
188181
V = model[:V]
189182

190183
stmap = Dict([v => U[i] for (i, v) in enumerate(unknowns(sys))])
191184
ctrlmap = Dict([v => V[i] for (i, v) in enumerate(controls(sys))])
192-
submap = merge(stmap, ctrlmap, pmap)
185+
submap = merge(stmap, ctrlmap, Dict(pmap))
186+
@show submap
193187

194-
@register_symbolic _D(x) = (x, t)
195188
# Differential equations
196189
diff_eqs = diff_equations(sys)
197-
diff_eqs = map(e -> Symbolics.substitute(e, submap, Differential(iv) => _D), diff_eqs)
190+
D = Differential(iv)
191+
diffsubmap = Dict([D(U[i]) => (U[i], t) for i in 1:length(U)])
192+
for u in unknowns(sys)
193+
diff_eqs = map(e -> Symbolics.substitute(e, submap), diff_eqs)
194+
diff_eqs = map(e -> Symbolics.substitute(e, diffsubmap), diff_eqs)
195+
end
198196
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs == diff_eqs[i].rhs)
199197

200198
# Algebraic equations
@@ -273,13 +271,13 @@ function DiffEqBase.solve(prob::JuMPControlProblem, jump_solver, ode_solver::Sym
273271
delete(model, con)
274272
end
275273
end
274+
unregister(model, :K)
276275
for var in all_variables(model)
277276
if occursin("K", JuMP.name(var))
278-
unregister(model, Symbol(JuMP.name(var)))
279277
delete(model, var)
280278
end
281279
end
282-
add_solve_constraints!(prob, tableau)
280+
add_jump_solve_constraints!(prob, tableau)
283281
_solve(prob, jump_solver, ode_solver)
284282
end
285283

src/ModelingToolkit.jl

+2
Original file line numberDiff line numberDiff line change
@@ -348,5 +348,7 @@ function FMIComponent end
348348

349349
function JuMPControlProblem end
350350
export JuMPControlProblem
351+
function InfiniteOptControlProblem end
352+
export InfiniteOptControlProblem
351353

352354
end # module

src/systems/diffeqs/odesystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
336336

337337
if length(costs) > 1 && isnothing(consolidate)
338338
error("Must specify a consolidation function for the costs vector.")
339-
elseif isnothing(consolidate)
339+
elseif length(costs) == 1 && isnothing(consolidate)
340340
consolidate = u -> u[1]
341341
end
342342

test/extensions/jump_control.jl

+11-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using DiffEqDevTools, DiffEqBase
44
using SimpleDiffEq
55
using OrdinaryDiffEqSDIRK
66
using Ipopt
7+
using BenchmarkTools
78
const M = ModelingToolkit
89

910
@testset "ODE Solution, no cost" begin
@@ -29,11 +30,11 @@ const M = ModelingToolkit
2930
@test jsol.sol.u osol.u
3031

3132
# Implicit method.
32-
jsol2 = solve(jprob, Ipopt.Optimizer, :ImplicitEuler)
33-
osol2 = solve(oprob, ImplicitEuler(), dt = 0.01, adaptive = false)
33+
jsol2 = @btime solve($jprob, Ipopt.Optimizer, :ImplicitEuler) # 63.031 ms, 26.49 MiB
34+
osol2 = @btime solve($oprob, ImplicitEuler(), dt = 0.01, adaptive = false) # 129.375 μs, 61.91 KiB
3435
@test (jsol2.sol.u, osol2.u, rtol = 0.001)
3536
iprob = InfiniteOptControlProblem(sys, u0map, tspan, parammap, dt = 0.01)
36-
isol = solve(iprob, Ipopt.Optimizer, derivative_method = FiniteDifference(Backward()))
37+
isol = @btime solve($iprob, Ipopt.Optimizer, derivative_method = FiniteDifference(Backward())) # 11.540 ms, 4.00 MiB
3738

3839
# With a constraint
3940
u0map = Pair[]
@@ -43,10 +44,16 @@ const M = ModelingToolkit
4344

4445
jprob = JuMPControlProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
4546
@test num_constraints(jprob.model) == 2
46-
jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
47+
jsol = @btime solve($jprob, Ipopt.Optimizer, :Tsitouras5) # 12.190 s, 9.68 GiB
4748
sol = jsol.sol
4849
@test sol(0.6)[1] 3.5
4950
@test sol(0.3)[1] 7.0
51+
52+
iprob = InfiniteOptControlProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
53+
isol = @btime solve($iprob, Ipopt.Optimizer, derivative_method = OrthogonalCollocation(3)) # 48.564 ms, 9.58 MiB
54+
sol = isol.sol
55+
@test sol(0.6)[1] 3.5
56+
@test sol(0.3)[1] 7.0
5057
end
5158

5259
#@testset "Optimal control: bees" begin

0 commit comments

Comments
 (0)