From f32a32d275854b11b7c7cf6cc6769d8f1dfd3a43 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Tue, 8 Jul 2025 15:40:51 +0200 Subject: [PATCH] Extend multishoot to multidimensional NeuralODE --- src/multiple_shooting.jl | 27 ++-- test/multiple_shoot_tests.jl | 305 +++++++++++++++++++---------------- 2 files changed, 177 insertions(+), 155 deletions(-) diff --git a/src/multiple_shooting.jl b/src/multiple_shooting.jl index ab2def1b29..fa20b32eaf 100644 --- a/src/multiple_shooting.jl +++ b/src/multiple_shooting.jl @@ -36,7 +36,8 @@ Arguments: function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F, continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C} - datasize = size(ode_data, 2) + datasize = size(ode_data, ndims(ode_data)) + griddims = ntuple(_ -> Colon(), ndims(ode_data) - 1) if group_size < 2 || group_size > datasize throw(DomainError(group_size, "group_size can't be < 2 or > number of data points")) @@ -48,7 +49,7 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F, # Multiple shooting predictions sols = [solve( remake(prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]), - u0 = ode_data[:, first(rg)]), + u0 = ode_data[griddims..., first(rg)]), solver; saveat = tsteps[rg], kwargs...) 
for rg in ranges] @@ -61,15 +62,15 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F, # Calculate multiple shooting loss loss = 0 for (i, rg) in enumerate(ranges) - u = ode_data[:, rg] - û = group_predictions[i] + u = ode_data[griddims..., rg] + û = group_predictions[i][griddims..., :] loss += loss_function(u, û) if i > 1 # Ensure continuity between last state in previous prediction # and current initial condition in ode_data loss += continuity_term * - continuity_loss(group_predictions[i - 1][:, end], u[:, 1]) + continuity_loss(group_predictions[i - 1][griddims..., end], u[griddims..., 1]) end end @@ -121,16 +122,18 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem, ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F, continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm, group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C} - datasize = size(ode_data, 2) + ntraj = size(ode_data, ndims(ode_data)) + datasize = size(ode_data, ndims(ode_data)-1) + griddims = ntuple(_ -> Colon(), ndims(ode_data) - 2) prob = ensembleprob.prob if group_size < 2 || group_size > datasize throw(DomainError(group_size, "group_size can't be < 2 or > number of data points")) end - @assert ndims(ode_data)==3 "ode_data must have three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)" - @assert size(ode_data, 2) == length(tsteps) - @assert size(ode_data, 3) == kwargs[:trajectories] + @assert ndims(ode_data)>=3 "ode_data must have at least three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)" + @assert datasize == length(tsteps) + @assert ntraj == kwargs[:trajectories] # Get ranges that partition data to groups of size group_size ranges = group_ranges(datasize, group_size) @@ -140,7 +143,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem, rg -> begin newprob = remake(prob; p = p, tspan = (tsteps[first(rg)], 
tsteps[last(rg)])) function prob_func(prob, i, repeat) - remake(prob; u0 = ode_data[:, first(rg), i]) + remake(prob; u0 = ode_data[griddims..., first(rg), i]) end newensembleprob = EnsembleProblem( newprob, prob_func, ensembleprob.output_func, ensembleprob.reduction, @@ -158,7 +161,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem, loss = 0 for (i, rg) in enumerate(ranges) û = group_predictions[i] - u = ode_data[:, rg, :] # trajectories are at dims 3 + u = ode_data[griddims..., rg, :] # trajectories are along the last dimension # just summing up losses for all trajectories # but other alternatives might be considered @@ -168,7 +171,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem, # Ensure continuity between last state in previous prediction # and current initial condition in ode_data loss += continuity_term * - continuity_loss(group_predictions[i - 1][:, end, :], u[:, 1, :]) + continuity_loss(group_predictions[i - 1][griddims..., end, :], u[griddims..., 1, :]) end end diff --git a/test/multiple_shoot_tests.jl b/test/multiple_shoot_tests.jl index 1fb5fa5b8e..057f5ab474 100644 --- a/test/multiple_shoot_tests.jl +++ b/test/multiple_shoot_tests.jl @@ -12,148 +12,167 @@ @test_throws DomainError group_ranges(10, 1) @test_throws DomainError group_ranges(10, 11) - ## Define initial conditions and time steps - datasize = 30 - u0 = Float32[2.0, 0.0] - tspan = (0.0f0, 5.0f0) - tsteps = range(tspan[1], tspan[2]; length = datasize) - - # Get the data - function trueODEfunc(du, u, p, t) - true_A = [-0.1 2.0; -2.0 -0.1] - du .= ((u .^ 3)'true_A)' + # Test configurations + test_configs = [ + ( + name = "Vector Test Config", + u0 = Float32[2.0, 0.0], + ode_func = (du, u, p, t) -> (du .= ((u .^ 3)'*[-0.1 2.0; -2.0 -0.1])'), + nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2)), + u0s_ensemble = [Float32[2.0, 0.0], Float32[3.0, 1.0]] + ), + ( + name = "Multi-D Test Config", + u0 = Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], + ode_func
= (du, u, p, t) -> (du .= ((u .^ 3).*[-0.01 0.02; -0.02 -0.01; 0.01 -0.05])), + nn = Chain(x -> x .^ 3, Dense(3 => 3, tanh)), + u0s_ensemble = [Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]] + ) + ] + + for config in test_configs + @info "Running tests for: $(config.name)" + + ## Define initial conditions and time steps + datasize = 30 + u0 = config.u0 + tspan = (0.0f0, 5.0f0) + tsteps = range(tspan[1], tspan[2]; length = datasize) + + # Get the data + trueODEfunc = config.ode_func + prob_trueode = ODEProblem(trueODEfunc, u0, tspan) + ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) + + # Define the Neural Network + nn = config.nn + p_init, st = Lux.setup(rng, nn) + p_init = ComponentArray(p_init) + + neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) + prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init) + + predict_single_shooting(p) = Array(first(neuralode(u0, p, st))) + + # Define loss function + loss_function(data, pred) = sum(abs2, data - pred) + + ## Evaluate Single Shooting + function loss_single_shooting(p) + pred = predict_single_shooting(p) + l = loss_function(ode_data, pred) + return l + end + + adtype = Optimization.AutoZygote() + optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype) + optprob = Optimization.OptimizationProblem(optf, p_init) + res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300) + + loss_ss = loss_single_shooting(res_single_shooting.minimizer) + @info "Single shooting loss: $(loss_ss)" + + ## Test Multiple Shooting + group_size = 3 + continuity_term = 200 + + function loss_multiple_shooting(p) + return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), + group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs + end + + adtype = Optimization.AutoZygote() + optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype) + optprob = 
Optimization.OptimizationProblem(optf, p_init) + res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300) + + # Calculate single shooting loss with parameter from multiple_shoot training + loss_ms = loss_single_shooting(res_ms.minimizer) + println("Multiple shooting loss: $(loss_ms)") + @test loss_ms < 10loss_ss + + # Test with custom loss function + group_size = 4 + continuity_term = 50 + + function continuity_loss_abs2(û_end, u_0) + return sum(abs2, û_end - u_0) # using abs2 instead of default abs + end + + function loss_multiple_shooting_abs2(p) + return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, + continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1] + end + + adtype = Optimization.AutoZygote() + optf = Optimization.OptimizationFunction( + (p, _) -> loss_multiple_shooting_abs2(p), adtype) + optprob = Optimization.OptimizationProblem(optf, p_init) + res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300) + + loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer) + println("Multiple shooting loss with abs2: $(loss_ms_abs2)") + @test loss_ms_abs2 < loss_ss + + ## Test different SensitivityAlgorithm (default is InterpolatingAdjoint) + function loss_multiple_shooting_fd(p) + return multiple_shoot( + p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2, + Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1] + end + + adtype = Optimization.AutoZygote() + optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype) + optprob = Optimization.OptimizationProblem(optf, p_init) + res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300) + + # Calculate single shooting loss with parameter from multiple_shoot training + loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer) + println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)") + @test loss_ms_fd < 10loss_ss + + # Integration return codes `!= :Success` should return 
infinite loss. + # In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`. + loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function, + Tsit5(), datasize; maxiters = 1, verbose = false)[1] + @test loss_fail == Inf + + ## Test for DomainErrors + @test_throws DomainError multiple_shoot( + p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1) + @test_throws DomainError multiple_shoot( + p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1) + + ## Ensembles + u0s = config.u0s_ensemble + function prob_func(prob, i, repeat) + remake(prob; u0 = u0s[i]) + end + ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func) + ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func) + ensemble_alg = EnsembleThreads() + trajectories = 2 + ode_data_ensemble = Array(solve( + ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps)) + + group_size = 3 + continuity_term = 200 + function loss_multiple_shooting_ens(p) + return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, + loss_function, Tsit5(), group_size; continuity_term, + trajectories, abstol = 1e-8, reltol = 1e-6)[1] + end + + adtype = Optimization.AutoZygote() + optf = Optimization.OptimizationFunction( + (p, _) -> loss_multiple_shooting_ens(p), adtype) + optprob = Optimization.OptimizationProblem(optf, p_init) + res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300) + + loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer) + + println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)") + + @test loss_ms_ensembles < 10loss_ss end - prob_trueode = ODEProblem(trueODEfunc, u0, tspan) - ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) - - # Define the Neural Network - nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2)) - p_init, st = Lux.setup(rng, nn) - p_init = ComponentArray(p_init) - - 
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) - prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init) - - predict_single_shooting(p) = Array(first(neuralode(u0, p, st))) - - # Define loss function - loss_function(data, pred) = sum(abs2, data - pred) - - ## Evaluate Single Shooting - function loss_single_shooting(p) - pred = predict_single_shooting(p) - l = loss_function(ode_data, pred) - return l - end - - adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype) - optprob = Optimization.OptimizationProblem(optf, p_init) - res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300) - - loss_ss = loss_single_shooting(res_single_shooting.minimizer) - @info "Single shooting loss: $(loss_ss)" - - ## Test Multiple Shooting - group_size = 3 - continuity_term = 200 - - function loss_multiple_shooting(p) - return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, Tsit5(), - group_size; continuity_term, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs - end - - adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype) - optprob = Optimization.OptimizationProblem(optf, p_init) - res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300) - - # Calculate single shooting loss with parameter from multiple_shoot training - loss_ms = loss_single_shooting(res_ms.minimizer) - println("Multiple shooting loss: $(loss_ms)") - @test loss_ms < 10loss_ss - - # Test with custom loss function - group_size = 4 - continuity_term = 50 - - function continuity_loss_abs2(û_end, u_0) - return sum(abs2, û_end - u_0) # using abs2 instead of default abs - end - - function loss_multiple_shooting_abs2(p) - return multiple_shoot(p, ode_data, tsteps, prob_node, loss_function, - continuity_loss_abs2, Tsit5(), group_size; continuity_term)[1] - end - - adtype = Optimization.AutoZygote() - optf = 
Optimization.OptimizationFunction( - (p, _) -> loss_multiple_shooting_abs2(p), adtype) - optprob = Optimization.OptimizationProblem(optf, p_init) - res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300) - - loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer) - println("Multiple shooting loss with abs2: $(loss_ms_abs2)") - @test loss_ms_abs2 < loss_ss - - ## Test different SensitivityAlgorithm (default is InterpolatingAdjoint) - function loss_multiple_shooting_fd(p) - return multiple_shoot( - p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2, - Tsit5(), group_size; continuity_term, sensealg = ForwardDiffSensitivity())[1] - end - - adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype) - optprob = Optimization.OptimizationProblem(optf, p_init) - res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300) - - # Calculate single shooting loss with parameter from multiple_shoot training - loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer) - println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)") - @test loss_ms_fd < 10loss_ss - - # Integration return codes `!= :Success` should return infinite loss. - # In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`. 
- loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function, - Tsit5(), datasize; maxiters = 1, verbose = false)[1] - @test loss_fail == Inf - - ## Test for DomainErrors - @test_throws DomainError multiple_shoot( - p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1) - @test_throws DomainError multiple_shoot( - p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1) - - ## Ensembles - u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]] - function prob_func(prob, i, repeat) - remake(prob; u0 = u0s[i]) - end - ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func) - ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func) - ensemble_alg = EnsembleThreads() - trajectories = 2 - ode_data_ensemble = Array(solve( - ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps)) - - group_size = 3 - continuity_term = 200 - function loss_multiple_shooting_ens(p) - return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, - loss_function, Tsit5(), group_size; continuity_term, - trajectories, abstol = 1e-8, reltol = 1e-6)[1] # test solver kwargs - end - - adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction( - (p, _) -> loss_multiple_shooting_ens(p), adtype) - optprob = Optimization.OptimizationProblem(optf, p_init) - res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300) - - loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer) - - println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)") - - @test loss_ms_ensembles < 10loss_ss end