Skip to content

Commit dae2c41

Browse files
Merge pull request #884 from AayushSabharwal/as/scc-refactor
refactor: change `SCCNonlinearProblem` fields
2 parents 86aa145 + 1dcbd1f commit dae2c41

File tree

4 files changed

+30
-30
lines changed

4 files changed

+30
-30
lines changed

src/problems/nonlinear_problems.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -462,28 +462,30 @@ Note that this example aliases the parameters together for a memory-reduced repr
462462
* `probs`: the collection of problems to solve
463463
* `explictfuns!`: the explicit functions for mutating the parameter set
464464
"""
465-
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
465+
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <:
466466
AbstractNonlinearProblem{uType, iip}
467467
probs::P
468468
explicitfuns!::E
469-
full_index_provider::I
470-
parameter_object::Par
469+
# NonlinearFunction with `f = Returns(nothing)`
470+
f::F
471+
p::Par
471472
parameters_alias::Bool
472473

473-
function SCCNonlinearProblem{P, E, I, Par}(
474-
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
474+
function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par,
475+
alias::Bool) where {P, E, F <: NonlinearFunction, Par}
475476
u0 = mapreduce(
476477
state_values, vcat, probs; init = similar(state_values(first(probs)), 0))
477478
uType = typeof(u0)
478-
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
479+
new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias)
479480
end
480481
end
481482

482-
function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
483-
parameter_object = nothing, parameters_alias = false)
483+
function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing,
484+
parameters_alias = false; kwargs...)
485+
f = NonlinearFunction{false}(Returns(nothing); kwargs...)
484486
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
485-
typeof(full_index_provider), typeof(parameter_object)}(
486-
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
487+
typeof(f), typeof(parameter_object)}(
488+
probs, explicitfuns!, f, parameter_object, parameters_alias)
487489
end
488490

489491
function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
@@ -496,10 +498,10 @@ function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
496498
end
497499

498500
function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
499-
prob.full_index_provider
501+
prob.f
500502
end
501503
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
502-
prob.parameter_object
504+
prob.p
503505
end
504506
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
505507
mapreduce(
@@ -516,8 +518,8 @@ function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, id
516518
end
517519

518520
function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
519-
if prob.parameter_object !== nothing
520-
set_parameter!(prob.parameter_object, val, idx)
521+
if prob.p !== nothing
522+
set_parameter!(prob.p, val, idx)
521523
prob.parameters_alias && return
522524
end
523525
for scc in prob.probs

src/remake.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,7 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
526526
if p !== missing && !parameters_alias && probs === missing
527527
throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each."))
528528
end
529-
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults,
530-
indp = sys === missing ? prob.full_index_provider : sys)
529+
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
531530
if probs === missing
532531
probs = prob.probs
533532
end
@@ -547,11 +546,10 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
547546
end
548547
end
549548
if sys === missing
550-
sys = prob.full_index_provider
549+
sys = prob.f.sys
551550
end
552-
return SCCNonlinearProblem{
553-
typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}(
554-
probs, explicitfuns!, sys, newp, parameters_alias)
551+
return SCCNonlinearProblem(
552+
probs, explicitfuns!, newp, parameters_alias; sys)
555553
end
556554

557555
function varmap_has_var(varmap, var)
@@ -784,11 +782,11 @@ end
784782

785783
function updated_u0_p(
786784
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
787-
use_defaults = false, indp = has_sys(prob.f) ? prob.f.sys : nothing)
785+
use_defaults = false)
788786
if u0 === missing && p === missing
789787
return state_values(prob), parameter_values(prob)
790788
end
791-
if indp === nothing
789+
if prob.f.sys === nothing
792790
if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair
793791
throw(ArgumentError("This problem does not support symbolic maps with " *
794792
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *

test/downstream/modelingtoolkit_remake.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ fullsys = complete(fullsys)
8787
prob1 = NonlinearProblem(sys1, u0, p)
8888
prob2 = NonlinearProblem(sys2, u0, prob1.p)
8989
sccprob = SCCNonlinearProblem(
90-
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
90+
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)
9191
push!(syss, fullsys)
9292
push!(probs, sccprob)
9393

@@ -315,16 +315,16 @@ end
315315
prob1 = NonlinearProblem(sys1, u0, p)
316316
prob2 = NonlinearProblem(sys2, u0, prob1.p)
317317
sccprob = SCCNonlinearProblem(
318-
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
318+
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)
319319

320320
sccprob2 = remake(sccprob; u0 = 2ones(3))
321321
@test state_values(sccprob2) 2ones(3)
322322
@test sccprob2.probs[1].u0 2ones(2)
323323
@test sccprob2.probs[2].u0 2ones(1)
324324

325325
sccprob3 = remake(sccprob; p ==> 2.0])
326-
@test sccprob3.parameter_object === sccprob3.probs[1].p
327-
@test sccprob3.parameter_object === sccprob3.probs[2].p
326+
@test sccprob3.p === sccprob3.probs[1].p
327+
@test sccprob3.p === sccprob3.probs[2].p
328328

329329
@test_throws ["parameters_alias", "SCCNonlinearProblem"] remake(
330330
sccprob; parameters_alias = false, p ==> 2.0])
@@ -333,6 +333,6 @@ end
333333
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
334334
probs = [remake(prob1; p ==> 3.0]), prob2])
335335
@test !sccprob4.parameters_alias
336-
@test sccprob4.parameter_object !== sccprob4.probs[1].p
337-
@test sccprob4.parameter_object !== sccprob4.probs[2].p
336+
@test sccprob4.p !== sccprob4.probs[1].p
337+
@test sccprob4.p !== sccprob4.probs[2].p
338338
end

test/downstream/problem_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ prob = SteadyStateProblem(osys, u0, ps)
367367
prob = NonlinearProblem(model, [])
368368
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
369369
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
370-
model, copy(cache))
370+
copy(cache); sys = model)
371371

372372
for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
373373
@test prob[sym] sccprob[sym]
@@ -384,7 +384,7 @@ prob = SteadyStateProblem(osys, u0, ps)
384384
end
385385
sccprob.ps[p] = 2.5
386386
@test sccprob.ps[p] 2.5
387-
@test sccprob.parameter_object[1] 2.5
387+
@test sccprob.p[1] 2.5
388388
for scc in sccprob.probs
389389
@test parameter_values(scc)[1] 2.5
390390
end

0 commit comments

Comments
 (0)