Skip to content

Commit ab34ad8

Browse files
committed
format
1 parent 9b374b2 commit ab34ad8

File tree

3 files changed

+62
-50
lines changed

3 files changed

+62
-50
lines changed

ext/MTKJuMPControlExt.jl

+19-15
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
102102

103103
if is_free_t
104104
(ts_sym, te_sym) = tspan
105-
@variable(model, tf, start = pmap[te_sym])
105+
@variable(model, tf, start=pmap[te_sym])
106106
hasbounds(te_sym) && begin
107107
lo, hi = getbounds(te_sym)
108108
set_lower_bound(tf, lo)
@@ -112,11 +112,11 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
112112
pmap[te_sym] = 1
113113
tspan = (0, 1)
114114
end
115-
116-
@infinite_parameter(model, t in [tspan[1], tspan[2]], num_supports = steps)
117-
@variable(model, U[i = 1:length(states)], Infinite(t), start = u0[i])
115+
116+
@infinite_parameter(model, t in [tspan[1], tspan[2]], num_supports=steps)
117+
@variable(model, U[i = 1:length(states)], Infinite(t), start=u0[i])
118118
c0 = [pmap[c] for c in ctrls]
119-
@variable(model, V[i = 1:length(ctrls)], Infinite(t), start = c0[i])
119+
@variable(model, V[i = 1:length(ctrls)], Infinite(t), start=c0[i])
120120

121121
set_jump_bounds!(model, sys, pmap)
122122
add_jump_cost_function!(model, sys, (tspan[1], tspan[2]), pmap; is_free_t)
@@ -161,7 +161,8 @@ function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free
161161

162162
# Substitute integral
163163
iv = MTK.get_iv(sys)
164-
jcosts = map(c -> Symbolics.substitute(c, MTK.() => Symbolics.Integral(iv in tspan)), jcosts)
164+
jcosts = map(
165+
c -> Symbolics.substitute(c, MTK.() => Symbolics.Integral(iv in tspan)), jcosts)
165166

166167
intmap = Dict()
167168
for int in MTK.collect_applied_operators(jcosts, Symbolics.Integral)
@@ -183,7 +184,8 @@ function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = fals
183184
for u in MTK.get_unknowns(conssys)
184185
x = MTK.operation(u)
185186
t = only(arguments(u))
186-
MTK.symbolic_type(t) === NotSymbolic() && error("Provided specific time constraint in a free final time problem. This is not supported by the JuMP/InfiniteOpt collocation solvers. The offending variable is $u.")
187+
MTK.symbolic_type(t) === NotSymbolic() &&
188+
error("Provided specific time constraint in a free final time problem. This is not supported by the JuMP/InfiniteOpt collocation solvers. The offending variable is $u.")
187189
end
188190
end
189191

@@ -211,13 +213,15 @@ function substitute_jump_vars(model, sys, pmap, exprs)
211213
U = model[:U]
212214
V = model[:V]
213215
# for variables like x(t)
214-
whole_interval_map = Dict([[v => U[i] for (i, v) in enumerate(sts)]; [v => V[i] for (i, v) in enumerate(cts)]])
216+
whole_interval_map = Dict([[v => U[i] for (i, v) in enumerate(sts)];
217+
[v => V[i] for (i, v) in enumerate(cts)]])
215218
exprs = map(c -> Symbolics.substitute(c, whole_interval_map), exprs)
216219

217220
# for variables like x(1.0)
218221
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
219222
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
220-
fixed_t_map = Dict([[x_ops[i] => U[i] for i in 1:length(U)]; [c_ops[i] => V[i] for i in 1:length(V)]])
223+
fixed_t_map = Dict([[x_ops[i] => U[i] for i in 1:length(U)];
224+
[c_ops[i] => V[i] for i in 1:length(V)]])
221225
exprs = map(c -> Symbolics.substitute(c, fixed_t_map), exprs)
222226

223227
exprs = map(c -> Symbolics.substitute(c, Dict(pmap)), exprs)
@@ -236,11 +240,11 @@ function add_infopt_solve_constraints!(model::InfiniteModel, sys, pmap; is_free_
236240

237241
diff_eqs = substitute_jump_vars(model, sys, pmap, diff_equations(sys))
238242
diff_eqs = map(e -> Symbolics.substitute(e, diffsubmap), diff_eqs)
239-
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs == tₛ * diff_eqs[i].rhs)
243+
@constraint(model, D[i = 1:length(diff_eqs)], diff_eqs[i].lhs==tₛ * diff_eqs[i].rhs)
240244

241245
# Algebraic equations
242246
alg_eqs = substitute_jump_vars(model, sys, pmap, alg_equations(sys))
243-
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs == alg_eqs[i].rhs)
247+
@constraint(model, A[i = 1:length(alg_eqs)], alg_eqs[i].lhs==alg_eqs[i].rhs)
244248
end
245249

246250
function add_jump_solve_constraints!(prob, tableau; is_free_t = false)
@@ -283,11 +287,11 @@ function add_jump_solve_constraints!(prob, tableau; is_free_t = false)
283287
for (i, h) in enumerate(c)
284288
ΔU = @view ΔUs[i, :]
285289
Uₙ = U + ΔU * dt
286-
@constraint(model, [j = 1:nᵤ], K[i, j](τ) == tₛ * f(Uₙ, V, p, τ + h * dt)[j],
287-
DomainRestrictions(t => τ + h*dt), base_name="solve_K()")
290+
@constraint(model, [j = 1:nᵤ], K[i, j](τ)==tₛ * f(Uₙ, V, p, τ + h * dt)[j],
291+
DomainRestrictions(t => τ + h * dt), base_name="solve_K()")
288292
end
289-
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n] == U[n](τ + dt),
290-
DomainRestrictions(t => τ), base_name="solve_U()")
293+
@constraint(model, [n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](τ + dt),
294+
DomainRestrictions(t => τ), base_name="solve_U()")
291295
end
292296
end
293297
end

src/systems/optimal_control_interface.jl

+13-9
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function SciMLBase.ControlFunction{iip, specialize}(sys::ODESystem,
3939
inputs = unbound_inputs(sys),
4040
disturbance_inputs = disturbances(sys);
4141
version = nothing, tgrad = false,
42-
jac = false, controljac = false,
42+
jac = false, controljac = false,
4343
p = nothing, t = nothing,
4444
eval_expression = false,
4545
sparse = false, simplify = false,
@@ -52,8 +52,8 @@ function SciMLBase.ControlFunction{iip, specialize}(sys::ODESystem,
5252
initialization_data = nothing,
5353
cse = true,
5454
kwargs...) where {iip, specialize}
55-
56-
(f), _, _ = generate_control_function(sys, inputs, disturbance_inputs; eval_module, cse, kwargs...)
55+
(f), _, _ = generate_control_function(
56+
sys, inputs, disturbance_inputs; eval_module, cse, kwargs...)
5757

5858
if tgrad
5959
tgrad_gen = generate_tgrad(sys, dvs, ps;
@@ -113,7 +113,7 @@ function SciMLBase.ControlFunction{iip, specialize}(sys::ODESystem,
113113
W_prototype = nothing
114114
controljac_prototype = nothing
115115
end
116-
116+
117117
ControlFunction{iip, specialize}(f;
118118
sys = sys,
119119
jac = _jac === nothing ? nothing : _jac,
@@ -170,15 +170,19 @@ function process_tspan(tspan, dt, steps)
170170
is_free_time = false
171171
if isnothing(dt) && isnothing(steps)
172172
error("Must provide either the dt or the number of intervals to the collocation solvers (JuMP, InfiniteOpt, CasADi).")
173-
elseif symbolic_type(tspan[1]) === ScalarSymbolic() || symbolic_type(tspan[2]) === ScalarSymbolic()
174-
isnothing(steps) && error("Free final time problems require specifying the number of steps, rather than dt.")
175-
isnothing(dt) || @warn "Specified dt for free final time problem. This will be ignored; dt will be determined by the number of timesteps."
173+
elseif symbolic_type(tspan[1]) === ScalarSymbolic() ||
174+
symbolic_type(tspan[2]) === ScalarSymbolic()
175+
isnothing(steps) &&
176+
error("Free final time problems require specifying the number of steps, rather than dt.")
177+
isnothing(dt) ||
178+
@warn "Specified dt for free final time problem. This will be ignored; dt will be determined by the number of timesteps."
176179

177180
return steps, true
178181
else
179-
isnothing(steps) || @warn "Specified number of steps for problem with concrete tspan. This will be ignored; number of steps will be determined by dt."
182+
isnothing(steps) ||
183+
@warn "Specified number of steps for problem with concrete tspan. This will be ignored; number of steps will be determined by dt."
180184

181-
return length(tspan[1]:dt:tspan[2]), false
185+
return length(tspan[1]:dt:tspan[2]), false
182186
end
183187
end
184188

test/extensions/jump_control.jl

+30-26
Original file line numberDiff line numberDiff line change
@@ -79,35 +79,37 @@ end
7979
@testset "Linear systems" begin
8080
function is_bangbang(input_sol, lbounds, ubounds, rtol = 1e-4)
8181
bangbang = true
82-
for v in 1:length(input_sol.u[1]) - 1
83-
all(i -> (i[v], bounds[v]; rtol) || (i[v], bounds[u]; rtol), input_sol.u) || (bangbang = false)
82+
for v in 1:(length(input_sol.u[1]) - 1)
83+
all(i -> (i[v], bounds[v]; rtol) || (i[v], bounds[u]; rtol), input_sol.u) ||
84+
(bangbang = false)
8485
end
8586
bangbang
8687
end
8788

8889
# Double integrator
8990
t = M.t_nounits
9091
D = M.D_nounits
91-
@variables x(..) [bounds = (0., 0.25)] v(..)
92-
@variables u(t) [bounds = (-1., 1.), input = true]
92+
@variables x(..) [bounds = (0.0, 0.25)] v(..)
93+
@variables u(t) [bounds = (-1.0, 1.0), input = true]
9394
constr = [v(1.0) ~ 0.0]
9495
cost = [-x(1.0)] # Maximize the final distance.
95-
@named block = ODESystem([D(x(t)) ~ v(t), D(v(t)) ~ u], t; costs = cost, constraints = constr)
96-
block, input_idxs = structural_simplify(block, ([u],[]))
96+
@named block = ODESystem(
97+
[D(x(t)) ~ v(t), D(v(t)) ~ u], t; costs = cost, constraints = constr)
98+
block, input_idxs = structural_simplify(block, ([u], []))
9799

98-
u0map = [x(t) => 0., v(t) => 0.]
99-
tspan = (0., 1.)
100-
parammap = [u => 0.]
100+
u0map = [x(t) => 0.0, v(t) => 0.0]
101+
tspan = (0.0, 1.0)
102+
parammap = [u => 0.0]
101103
jprob = JuMPControlProblem(block, u0map, tspan, parammap; dt = 0.01)
102104
jsol = solve(jprob, Ipopt.Optimizer, :Verner8)
103105
# Linear systems have bang-bang controls
104-
@test is_bangbang(jsol.input_sol, [-1.], [1.])
106+
@test is_bangbang(jsol.input_sol, [-1.0], [1.0])
105107
# Test reached final position.
106108
@test (jsol.sol.u[end][1], 0.25, rtol = 1e-5)
107109

108110
iprob = InfiniteOptControlProblem(block, u0map, tspan, parammap; dt = 0.01)
109111
isol = solve(iprob, Ipopt.Optimizer; silent = true)
110-
@test is_bangbang(isol.input_sol, [-1.], [1.])
112+
@test is_bangbang(isol.input_sol, [-1.0], [1.0])
111113
@test (isol.sol.u[end][1], 0.25, rtol = 1e-5)
112114

113115
###################
@@ -118,21 +120,21 @@ end
118120
@parameters b c μ s ν
119121

120122
tspan = (0, 4)
121-
eqs = [D(w(t)) ~ -μ*w(t) + b*s*α*w(t),
122-
D(q(t)) ~ -ν*q(t) + c*(1 - α)*s*w(t)]
123+
eqs = [D(w(t)) ~ -μ * w(t) + b * s * α * w(t),
124+
D(q(t)) ~ -ν * q(t) + c * (1 - α) * s * w(t)]
123125
costs = [-q(tspan[2])]
124-
126+
125127
@named beesys = ODESystem(eqs, t; costs)
126-
beesys, input_idxs = structural_simplify(beesys, ([α],[]))
128+
beesys, input_idxs = structural_simplify(beesys, ([α], []))
127129
u0map = [w(t) => 40, q(t) => 2]
128130
pmap = [b => 1, c => 1, μ => 1, s => 1, ν => 1, α => 1]
129131

130132
jprob = JuMPControlProblem(beesys, u0map, tspan, pmap, dt = 0.01)
131133
jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
132-
@test is_bangbang(jsol.input_sol, [0.], [1.])
134+
@test is_bangbang(jsol.input_sol, [0.0], [1.0])
133135
iprob = InfiniteOptControlProblem(beesys, u0map, tspan, pmap, dt = 0.01)
134136
isol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
135-
@test is_bangbang(isol.input_sol, [0.], [1.])
137+
@test is_bangbang(isol.input_sol, [0.0], [1.0])
136138
end
137139

138140
@testset "Rocket launch" begin
@@ -144,18 +146,20 @@ end
144146
drag(h, v) = D_c * v^2 * exp(-h_c * (h - h₀) / h₀)
145147
gravity(h) = g₀ * (h₀ / h)
146148

147-
eqs = [D(h(t)) ~ v(t),
148-
D(v(t)) ~ (T(t) - drag(h(t), v(t))) / m(t) - gravity(h(t)),
149-
D(m(t)) ~ -T(t) / c]
149+
eqs = [D(h(t)) ~ v(t),
150+
D(v(t)) ~ (T(t) - drag(h(t), v(t))) / m(t) - gravity(h(t)),
151+
D(m(t)) ~ -T(t) / c]
150152

151-
(ts, te) = (0., 0.2)
153+
(ts, te) = (0.0, 0.2)
152154
costs = [-h(te)]
153155
constraints = [T(te) ~ 0]
154156
@named rocket = ODESystem(eqs, t; costs, constraints)
155157
rocket, input_idxs = structural_simplify(rocket, ([T(t)], []))
156158

157159
u0map = [h(t) => h₀, m(t) => m₀, v(t) => 0]
158-
pmap = [g₀ => 1, m₀ => 1.0, h_c => 500, c => 0.5*√(g₀*h₀), D_c => 0.5 * 620 * m₀/g₀, Tₘ => 3.5*g₀*m₀, T(t) => 0., h₀ => 1, m_c => 0.6]
160+
pmap = [
161+
g₀ => 1, m₀ => 1.0, h_c => 500, c => 0.5 * (g₀ * h₀), D_c => 0.5 * 620 * m₀ / g₀,
162+
Tₘ => 3.5 * g₀ * m₀, T(t) => 0.0, h₀ => 1, m_c => 0.6]
159163
jprob = JuMPControlProblem(rocket, u0map, (ts, te), pmap; dt = 0.005, cse = false)
160164
jsol = solve(jprob, Ipopt.Optimizer, :RadauIA3)
161165
@test jsol.sol.u[end][1] > 1.012
@@ -165,17 +169,17 @@ end
165169
t = M.t_nounits
166170
D = M.D_nounits
167171

168-
@variables x(..) u(..) [input = true, bounds = (0,1)]
172+
@variables x(..) u(..) [input = true, bounds = (0, 1)]
169173
@parameters tf
170-
eqs = [D(x(t)) ~ -2 + 0.5*u(t)]
174+
eqs = [D(x(t)) ~ -2 + 0.5 * u(t)]
171175
# Integral cost function
172-
costs = [-(x(t)-u(t)), -x(tf)]
176+
costs = [-(x(t) - u(t)), -x(tf)]
173177
consolidate(u) = u[1] + u[2]
174178
@named rocket = ODESystem(eqs, t; costs, consolidate)
175179
rocket, input_idxs = structural_simplify(rocket, ([u(t)], []))
176180

177181
u0map = [x(t) => 17.5]
178-
pmap = [u(t) => 0., tf => 8]
182+
pmap = [u(t) => 0.0, tf => 8]
179183
jprob = JuMPControlProblem(rocket, u0map, (0, tf), pmap; steps = 201)
180184
jsol = solve(jprob, Ipopt.Optimizer, :Tsitouras5)
181185
@test isapprox(jsol.sol.t[end], 10.0, rtol = 1e-3)

0 commit comments

Comments
 (0)