From b107f0f09d9735ddc9478ba292e4613aea86b70f Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Tue, 10 Feb 2026 08:05:30 -0500 Subject: [PATCH] Add FlexUnits.jl extension for ODE solver compatibility Add DiffEqBaseFlexUnitsExt extension module that enables FlexUnits.jl quantities to work with DiffEqBase's adaptive ODE solvers. This mirrors the existing DiffEqBaseUnitfulExt pattern, providing: - value/unitfulvalue methods for FlexUnits.Quantity types - ODE_DEFAULT_NORM for scalar and array Quantity types - UNITLESS_ABS2 for unit-stripped squared absolute values - _rate_prototype for computing du/dt prototypes with time units Co-Authored-By: Chris Rackauckas Co-Authored-By: Claude Opus 4.6 --- Project.toml | 6 ++++- ext/DiffEqBaseFlexUnitsExt.jl | 45 +++++++++++++++++++++++++++++++++++ test/downstream/Project.toml | 1 + test/downstream/flexunits.jl | 9 +++++++ test/runtests.jl | 1 + 5 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 ext/DiffEqBaseFlexUnitsExt.jl create mode 100644 test/downstream/flexunits.jl diff --git a/Project.toml b/Project.toml index 68409bd23..697226431 100644 --- a/Project.toml +++ b/Project.toml @@ -35,6 +35,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" @@ -51,6 +52,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" DiffEqBaseCUDAExt = "CUDA" DiffEqBaseChainRulesCoreExt = "ChainRulesCore" DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"] +DiffEqBaseFlexUnitsExt = "FlexUnits" DiffEqBaseForwardDiffExt = ["ForwardDiff"] DiffEqBaseGTPSAExt = "GTPSA" DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated" @@ -77,6 +79,7 @@ Enzyme = "0.13.100" FastBroadcast = "0.3.5" FastClosures = "0.3.2" FastPower = "1.1" +FlexUnits = "0.4" ForwardDiff = "0.10, 1" FunctionWrappers = "1.0" FunctionWrappersWrappers = "0.1" @@ -117,6 +120,7 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" @@ -134,4 +138,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Distributed", "Measurements", "Unitful", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "Aqua"] +test = ["Distributed", "Measurements", "Unitful", "FlexUnits", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "Aqua"] diff --git a/ext/DiffEqBaseFlexUnitsExt.jl b/ext/DiffEqBaseFlexUnitsExt.jl new file mode 100644 index 000000000..d14b4ca96 --- /dev/null +++ b/ext/DiffEqBaseFlexUnitsExt.jl @@ -0,0 +1,45 @@ +module DiffEqBaseFlexUnitsExt + +using DiffEqBase +import SciMLBase: unitfulvalue, value +using FlexUnits + +# Support adaptive errors should be errorless for exponentiation +value(::Type{Quantity{T, U}}) where {T, U} = T +value(x::Quantity{T, U}) where {T, U} = dstrip(x) + +unitfulvalue(::Type{T}) where {T <: Quantity} = T +unitfulvalue(x::Quantity) = x + +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::AbstractArray{ + <:Quantity, + N, + }, + t + ) where {N} + return sqrt( + sum( + x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(t)) + ) / length(u) + ) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::Array{<:Quantity, N}, + t + ) where {N} + return sqrt( + sum( + x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(t)) + ) / length(u) + ) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::Quantity, t) = abs(value(u)) +@inline function DiffEqBase.UNITLESS_ABS2(x::Quantity) + return real(abs2(dstrip(x))) +end + +DiffEqBase._rate_prototype(u, t::Quantity, onet) = u / unit(t) +end diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 31e65c28d..19e6e65c3 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -6,6 +6,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" diff --git a/test/downstream/flexunits.jl b/test/downstream/flexunits.jl new file mode 100644 index 000000000..0634f82b4 --- /dev/null +++ b/test/downstream/flexunits.jl @@ -0,0 +1,9 @@ +using FlexUnits, FlexUnits.UnitRegistry, OrdinaryDiffEq, Test + +f(du, u, p, t) = du .= 3 * u"1/s" * u +prob = ODEProblem(f, [2.0u"m"], (0.0u"s", 1.0u"s")) +intg = init(prob, Tsit5(), dt = 0.01u"s") +@test_nowarn step!(intg, 0.02u"s", true) + +@test SciMLBase.unitfulvalue(1.0u"1/s") == 1.0u"1/s" +@test SciMLBase.value(1.0u"1/s") isa Real diff --git a/test/runtests.jl b/test/runtests.jl index a44c55424..21b689b41 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,7 @@ end @time @safetestset "Null DE Handling" include("downstream/null_de.jl") @time @safetestset "StaticArrays + AD" include("downstream/static_arrays_ad.jl") @time @safetestset "Unitful" include("downstream/unitful.jl") + @time @safetestset "FlexUnits" include("downstream/flexunits.jl") @time @safetestset "Dual Detection Solution" include("downstream/dual_detection_solution.jl") @time @safetestset "Null Parameters" include("downstream/null_params_test.jl") @time @safetestset "Ensemble Simulations" include("downstream/ensemble.jl")