Skip to content

Commit 86aa145

Browse files
Merge pull request #883 from AayushSabharwal/as/scc-remake
feat: add `remake` for `SCCNonlinearProblem`
2 parents 55c8717 + 844ebfd commit 86aa145

File tree

2 files changed

+111
-2
lines changed

2 files changed

+111
-2
lines changed

src/remake.jl

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,52 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
508508
end
509509
end
510510

511+
"""
512+
remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
513+
parameters_alias = prob.parameters_alias, sys = missing, explicitfuns! = missing)
514+
515+
Remake the given `SCCNonlinearProblem`. `u0` is the state vector for the entire problem,
516+
which will be chunked appropriately and used to `remake` the individual subproblems. `p`
517+
is the parameter object for `prob`. If `parameters_alias`, the same parameter object will be
518+
used to `remake` the individual subproblems. Otherwise if `p !== missing`, this function will
519+
error and require that `probs` be specified. `probs` is the collection of subproblems. Even if
520+
`probs` is explicitly specified, the value of `u0` provided to `remake` will be used to
521+
override the values in `probs`. `sys` is the index provider for the full system.
522+
"""
523+
function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
524+
parameters_alias = prob.parameters_alias, sys = missing,
525+
interpret_symbolicmap = true, use_defaults = false, explicitfuns! = missing)
526+
if p !== missing && !parameters_alias && probs === missing
527+
throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each."))
528+
end
529+
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults,
530+
indp = sys === missing ? prob.full_index_provider : sys)
531+
if probs === missing
532+
probs = prob.probs
533+
end
534+
offset = 0
535+
if u0 !== missing || p !== missing && parameters_alias
536+
probs = map(probs) do subprob
537+
subprob = if parameters_alias
538+
remake(subprob;
539+
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))],
540+
p = newp)
541+
else
542+
remake(subprob;
543+
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))])
544+
end
545+
offset += length(state_values(subprob))
546+
return subprob
547+
end
548+
end
549+
if sys === missing
550+
sys = prob.full_index_provider
551+
end
552+
return SCCNonlinearProblem{
553+
typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}(
554+
probs, explicitfuns!, sys, newp, parameters_alias)
555+
end
556+
511557
function varmap_has_var(varmap, var)
512558
haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var))
513559
end
@@ -737,11 +783,12 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
737783
end
738784

739785
function updated_u0_p(
740-
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
786+
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
787+
use_defaults = false, indp = has_sys(prob.f) ? prob.f.sys : nothing)
741788
if u0 === missing && p === missing
742789
return state_values(prob), parameter_values(prob)
743790
end
744-
if !has_sys(prob.f)
791+
if indp === nothing
745792
if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair
746793
throw(ArgumentError("This problem does not support symbolic maps with " *
747794
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *

test/downstream/modelingtoolkit_remake.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,24 @@ discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0])
7373
push!(syss, discsys)
7474
push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps))
7575

76+
# TODO: Rewrite this example when the MTK codegen is merged
77+
@named sys1 = NonlinearSystem(
78+
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
79+
sys1 = complete(sys1)
80+
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
81+
sys2 = complete(sys2)
82+
@named fullsys = NonlinearSystem(
83+
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
84+
[x, y, z], [σ, β, ρ])
85+
fullsys = complete(fullsys)
86+
87+
prob1 = NonlinearProblem(sys1, u0, p)
88+
prob2 = NonlinearProblem(sys2, u0, prob1.p)
89+
sccprob = SCCNonlinearProblem(
90+
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
91+
push!(syss, fullsys)
92+
push!(probs, sccprob)
93+
7694
for (sys, prob) in zip(syss, probs)
7795
@test parameter_values(prob) isa ModelingToolkit.MTKParameters
7896

@@ -274,3 +292,47 @@ end
274292
@test_throws SciMLBase.CyclicDependencyError remake(
275293
prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3])
276294
end
295+
296+
@testset "SCCNonlinearProblem" begin
297+
@named sys1 = NonlinearSystem(
298+
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
299+
sys1 = complete(sys1)
300+
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
301+
sys2 = complete(sys2)
302+
@named fullsys = NonlinearSystem(
303+
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
304+
[x, y, z], [σ, β, ρ])
305+
fullsys = complete(fullsys)
306+
307+
u0 = [x => 1.0,
308+
y => 0.0,
309+
z => 0.0]
310+
311+
p ==> 28.0,
312+
ρ => 10.0,
313+
β => 8 / 3]
314+
315+
prob1 = NonlinearProblem(sys1, u0, p)
316+
prob2 = NonlinearProblem(sys2, u0, prob1.p)
317+
sccprob = SCCNonlinearProblem(
318+
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
319+
320+
sccprob2 = remake(sccprob; u0 = 2ones(3))
321+
@test state_values(sccprob2) 2ones(3)
322+
@test sccprob2.probs[1].u0 2ones(2)
323+
@test sccprob2.probs[2].u0 2ones(1)
324+
325+
sccprob3 = remake(sccprob; p ==> 2.0])
326+
@test sccprob3.parameter_object === sccprob3.probs[1].p
327+
@test sccprob3.parameter_object === sccprob3.probs[2].p
328+
329+
@test_throws ["parameters_alias", "SCCNonlinearProblem"] remake(
330+
sccprob; parameters_alias = false, p ==> 2.0])
331+
332+
newp = remake_buffer(sys1, prob1.p, [σ], [3.0])
333+
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
334+
probs = [remake(prob1; p ==> 3.0]), prob2])
335+
@test !sccprob4.parameters_alias
336+
@test sccprob4.parameter_object !== sccprob4.probs[1].p
337+
@test sccprob4.parameter_object !== sccprob4.probs[2].p
338+
end

0 commit comments

Comments
 (0)