Skip to content

Commit 23bbbbc

Browse files
Re-enable BackTracking (#1761)
`BackTracking` as relaxation is now enabled again, with a thin wrapper to reject it when the residual gets worse. Upstream issue: SciML/OrdinaryDiffEq.jl#2442
1 parent b80a79a commit 23bbbbc

File tree

10 files changed

+133
-42
lines changed

10 files changed

+133
-42
lines changed

Manifest.toml

+3-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.4"
44
manifest_format = "2.0"
5-
project_hash = "c2cb085c326f61a96abd1a295e6fa775c585beba"
5+
project_hash = "a410a350a7b0c63bc6696029509aa68c14023275"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "6778bcc27496dae5723ff37ee30af451db8b35fe"
@@ -1070,7 +1070,7 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
10701070
version = "1.2.0"
10711071

10721072
[[deps.NonlinearSolve]]
1073-
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "SymbolicIndexingInterface", "TimerOutputs"]
1073+
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "SymbolicIndexingInterface"]
10741074
git-tree-sha1 = "3adb1e5945b5a6b1eaee754077f25ccc402edd7f"
10751075
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
10761076
version = "3.13.1"
@@ -1302,7 +1302,7 @@ uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
13021302
version = "3.5.18"
13031303

13041304
[[deps.Ribasim]]
1305-
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
1305+
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LineSearches", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
13061306
path = "core"
13071307
uuid = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
13081308
version = "2024.10.0"
@@ -1669,12 +1669,6 @@ weakdeps = ["RecipesBase"]
16691669
[deps.TimeZones.extensions]
16701670
TimeZonesRecipesBaseExt = "RecipesBase"
16711671

1672-
[[deps.TimerOutputs]]
1673-
deps = ["ExprTools", "Printf"]
1674-
git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531"
1675-
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
1676-
version = "0.5.24"
1677-
16781672
[[deps.TranscodingStreams]]
16791673
git-tree-sha1 = "d73336d81cafdc277ff45558bb7eaa2b04a8e472"
16801674
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"

core/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
2424
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
2525
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
2626
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
27+
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
2728
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2829
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2930
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
@@ -70,6 +71,7 @@ IOCapture = "0.2"
7071
IterTools = "1.4"
7172
JuMP = "1.15"
7273
Legolas = "0.5"
74+
LineSearches = "7"
7375
LinearSolve = "2.24"
7476
Logging = "<0.0.1, 1"
7577
LoggingExtras = "1"

core/ext/RibasimMakieExt.jl

+13-21
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
module RibasimMakieExt
22
using DataFrames: DataFrame
3-
using Makie: Figure, Axis, lines!, axislegend
3+
using Makie: Figure, Axis, scatterlines!, axislegend
44
using Ribasim: Ribasim, Model
55

66
function Ribasim.plot_basin_data!(model::Model, ax::Axis, column::Symbol)
77
basin_data = DataFrame(Ribasim.basin_table(model))
88
for node_id in unique(basin_data.node_id)
99
group = filter(:node_id => ==(node_id), basin_data)
10-
lines!(ax, group.time, getproperty(group, column); label = "Basin #$node_id")
10+
scatterlines!(ax, group.time, getproperty(group, column); label = "Basin #$node_id")
1111
end
1212

1313
axislegend(ax)
@@ -23,31 +23,23 @@ function Ribasim.plot_basin_data(model::Model)
2323
f
2424
end
2525

26-
function Ribasim.plot_flow!(
27-
model::Model,
28-
ax::Axis,
29-
edge_id::Int32;
30-
skip_conservative_out = false,
31-
)
26+
function Ribasim.plot_flow!(model::Model, ax::Axis, edge_metadata::Ribasim.EdgeMetadata)
3227
flow_data = DataFrame(Ribasim.flow_table(model))
33-
flow_data = filter(:edge_id => ==(edge_id), flow_data)
34-
first_row = first(flow_data)
35-
# Skip outflows of conservative nodes because these are the same as the inflows
36-
if skip_conservative_out &&
37-
Ribasim.NodeType.T(first_row.from_node_type) in Ribasim.conservative_nodetypes
38-
return nothing
39-
end
40-
label = "$(first_row.from_node_type) #$(first_row.from_node_id)$(first_row.to_node_type) #$(first_row.to_node_id)"
41-
lines!(ax, flow_data.time, flow_data.flow_rate; label)
28+
flow_data = filter(:edge_id => ==(edge_metadata.id), flow_data)
29+
label = "$(edge_metadata.edge[1])$(edge_metadata.edge[2])"
30+
scatterlines!(ax, flow_data.time, flow_data.flow_rate; label)
4231
return nothing
4332
end
4433

45-
function Ribasim.plot_flow(model::Model)
34+
function Ribasim.plot_flow(model::Model; skip_conservative_out = true)
4635
f = Figure()
4736
ax = Axis(f[1, 1]; xlabel = "time", ylabel = "flow rate [m³s⁻¹]")
48-
edge_ids = unique(Ribasim.flow_table(model).edge_id)
49-
for edge_id in edge_ids
50-
Ribasim.plot_flow!(model, ax, edge_id; skip_conservative_out = true)
37+
for edge_metadata in values(model.integrator.p.graph.edge_data)
38+
if skip_conservative_out &&
39+
edge_metadata.edge[1].type in Ribasim.conservative_nodetypes
40+
continue
41+
end
42+
Ribasim.plot_flow!(model, ax, edge_metadata)
5143
end
5244
axislegend(ax)
5345
f

core/src/Ribasim.jl

+11-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,15 @@ For more granular access, see:
1515
module Ribasim
1616

1717
# Algorithms for solving ODEs.
18-
using OrdinaryDiffEq: OrdinaryDiffEq, OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, get_du
18+
using OrdinaryDiffEq:
19+
OrdinaryDiffEq,
20+
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm,
21+
get_du,
22+
AbstractNLSolver,
23+
relax!,
24+
_compute_rhs!,
25+
calculate_residuals!
26+
using LineSearches: BackTracking
1927

2028
# Interface for defining and solving the ODE problem of the physical layer.
2129
using SciMLBase:
@@ -31,7 +39,8 @@ using SciMLBase:
3139
ODEProblem,
3240
ODESolution,
3341
VectorContinuousCallback,
34-
get_proposed_dt
42+
get_proposed_dt,
43+
DEIntegrator
3544

3645
# Automatically detecting the sparsity pattern of the Jacobian of water_balance!
3746
# through operator overloading

core/src/config.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ const algorithms = Dict{String, Type}(
230230
)
231231

232232
"Create an OrdinaryDiffEqAlgorithm from solver config"
233-
function algorithm(solver::Solver)::OrdinaryDiffEqAlgorithm
233+
function algorithm(solver::Solver; u0 = [])::OrdinaryDiffEqAlgorithm
234234
algotype = get(algorithms, solver.algorithm, nothing)
235235
if algotype === nothing
236236
options = join(keys(algorithms), ", ")
@@ -239,7 +239,9 @@ function algorithm(solver::Solver)::OrdinaryDiffEqAlgorithm
239239
end
240240
kwargs = Dict{Symbol, Any}()
241241
if algotype <: OrdinaryDiffEqNewtonAdaptiveAlgorithm
242-
kwargs[:nlsolve] = NLNewton(; relax = 0.1)
242+
kwargs[:nlsolve] = NLNewton(;
243+
relax = Ribasim.MonitoredBackTracking(; z_tmp = copy(u0), dz_tmp = copy(u0)),
244+
)
243245
end
244246
# not all algorithms support this keyword
245247
kwargs[:autodiff] = solver.autodiff

core/src/model.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ function Model(config_path::AbstractString)::Model
3737
end
3838

3939
function Model(config::Config)::Model
40-
alg = algorithm(config.solver)
4140
db_path = input_path(config, config.database)
4241
if !isfile(db_path)
4342
@error "Database file not found" db_path
@@ -109,6 +108,9 @@ function Model(config::Config)::Model
109108
u0 = ComponentVector{Float64}(; storage, integral)
110109
du0 = zero(u0)
111110

111+
# The Solver algorithm
112+
alg = algorithm(config.solver; u0)
113+
112114
# Synchronize level with storage
113115
set_current_basin_properties!(parameters.basin, u0, du0)
114116

core/src/read.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,8 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
572572
error("Invalid Basin / profile table.")
573573
end
574574

575-
level_to_area = LinearInterpolation.(area, level; extrapolate = true)
575+
level_to_area =
576+
LinearInterpolation.(area, level; extrapolate = true, cache_parameters = true)
576577
storage_to_level = invert_integral.(level_to_area)
577578

578579
t_end = seconds_since(config.endtime, config.starttime)
@@ -921,6 +922,7 @@ function user_demand_static!(
921922
fill(first_row.return_factor, 2),
922923
return_factor_old.t;
923924
extrapolate = true,
925+
cache_parameters = true,
924926
)
925927
min_level[user_demand_idx] = first_row.min_level
926928

@@ -1026,8 +1028,10 @@ function UserDemand(db::DB, config::Config, graph::MetaGraph)::UserDemand
10261028
]
10271029
demand_from_timeseries = fill(false, n_user)
10281030
allocated = fill(Inf, n_user, n_priority)
1029-
return_factor =
1030-
[LinearInterpolation(zeros(2), trivial_timespan) for i in eachindex(node_ids)]
1031+
return_factor = [
1032+
LinearInterpolation(zeros(2), trivial_timespan; cache_parameters = true) for
1033+
i in eachindex(node_ids)
1034+
]
10311035
min_level = zeros(n_user)
10321036

10331037
# Process static table

core/src/solve.jl

+11
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,20 @@ function water_balance!(
5151
# Formulate du (controlled by PidControl)
5252
formulate_du_pid_controlled!(du, graph, pid_control)
5353

54+
# https://github.com/Deltares/Ribasim/issues/1705#issuecomment-2283293974
55+
stop_declining_negative_storage!(du, u)
56+
5457
return nothing
5558
end
5659

60+
function stop_declining_negative_storage!(du, u)
61+
for (i, s) in enumerate(u.storage)
62+
if s < 0
63+
du.storage[i] = max(du.storage[i], 0.0)
64+
end
65+
end
66+
end
67+
5768
function formulate_continuous_control!(du, p, t)::Nothing
5869
(; compound_variable, target_ref, func) = p.continuous_control
5970

core/src/util.jl

+79-3
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,18 @@ end
5353
Compute the area and level of a basin given its storage.
5454
"""
5555
function get_area_and_level(basin::Basin, state_idx::Int, storage::T)::Tuple{T, T} where {T}
56-
level = basin.storage_to_level[state_idx](max(storage, 0.0))
57-
area = basin.level_to_area[state_idx](level)
58-
56+
storage_to_level = basin.storage_to_level[state_idx]
57+
level_to_area = basin.level_to_area[state_idx]
58+
if storage >= 0
59+
level = storage_to_level(storage)
60+
else
61+
# Negative storage is not feasible and this yields a level
62+
# below the basin bottom, but this does yield usable gradients
63+
# for the non-linear solver
64+
bottom = first(level_to_area.t)
65+
level = bottom + derivative(storage_to_level, 0.0) * storage
66+
end
67+
area = level_to_area(level)
5968
return area, level
6069
end
6170

@@ -887,3 +896,70 @@ end
887896
(A::AbstractInterpolation)(t::GradientTracer) = t
888897
reduction_factor(x::GradientTracer, threshold::Real) = x
889898
relaxed_root(x::GradientTracer, threshold::Real) = x
899+
get_area_and_level(basin::Basin, state_idx::Int, storage::GradientTracer) = storage, storage
900+
stop_declining_negative_storage!(du, u::ComponentVector{<:GradientTracer}) = nothing
901+
902+
@kwdef struct MonitoredBackTracking{B, V}
903+
linesearch::B = BackTracking()
904+
dz_tmp::V = []
905+
z_tmp::V = []
906+
end
907+
908+
"""
909+
Compute the residual of the non-linear solver, i.e. a measure of the
910+
error in the solution to the implicit equation defined by the solver algorithm
911+
"""
912+
function residual(z, integrator, nlsolver, f)
913+
(; uprev, t, p, dt, opts, isdae) = integrator
914+
(; tmp, ztmp, γ, α, cache, method) = nlsolver
915+
(; ustep, atmp, tstep, k, invγdt, tstep, k, invγdt) = cache
916+
if isdae
917+
_uprev = get_dae_uprev(integrator, uprev)
918+
b, ustep2 =
919+
_compute_rhs!(tmp, ztmp, ustep, α, tstep, k, invγdt, p, _uprev, f::TF, z)
920+
else
921+
b, ustep2 =
922+
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f, z)
923+
end
924+
calculate_residuals!(
925+
atmp,
926+
b,
927+
uprev,
928+
ustep2,
929+
opts.abstol,
930+
opts.reltol,
931+
opts.internalnorm,
932+
t,
933+
)
934+
ndz = opts.internalnorm(atmp, t)
935+
return ndz
936+
end
937+
938+
"""
939+
MonitoredBackTracing is a thin wrapper of BackTracking, making sure that
940+
the BackTracking relaxation is rejected if it results in a residual increase
941+
"""
942+
function OrdinaryDiffEq.relax!(
943+
dz,
944+
nlsolver::AbstractNLSolver,
945+
integrator::DEIntegrator,
946+
f,
947+
linesearch::MonitoredBackTracking,
948+
)
949+
(; linesearch, dz_tmp, z_tmp) = linesearch
950+
951+
# Store step before relaxation
952+
@. dz_tmp = dz
953+
954+
# Apply relaxation and measure the residual change
955+
@. z_tmp = nlsolver.z + dz
956+
resid_before = residual(z_tmp, integrator, nlsolver, f)
957+
relax!(dz, nlsolver, integrator, f, linesearch)
958+
@. z_tmp = nlsolver.z + dz
959+
resid_after = residual(z_tmp, integrator, nlsolver, f)
960+
961+
# If the residual increased due to the relaxation, reject it
962+
if resid_after > resid_before
963+
@. dz = dz_tmp
964+
end
965+
end

core/test/main_test.jl

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
@show backtrace
2525
end
2626
@test occursin("version in the TOML config file does not match", output)
27-
@test occursin("Info: Convergence bottlenecks in descending order of severity:", output)
2827
end
2928

3029
@testitem "main error logging" begin

0 commit comments

Comments
 (0)