Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/basis/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,14 +529,17 @@ If no default value is stored, returns `zero(T)` where `T` is the `symtype` of t
## Note

This extends `getmetadata` in a way that all parameters have a numeric value.
Values are unwrapped from symbolic wrappers to ensure compatibility with ODEProblem.
"""
function get_parameter_values(x::Basis)
map(parameters(x)) do p
if hasmetadata(p, Symbolics.VariableDefaultValue)
return Symbolics.getdefaultval(p)
val = if hasmetadata(p, Symbolics.VariableDefaultValue)
Symbolics.getdefaultval(p)
else
return zero(Symbolics.symtype(p))
zero(Symbolics.symtype(p))
end
# Unwrap symbolic values to numeric values for use in ODEProblem
return Symbolics.unwrap(val)
end
end

Expand All @@ -549,14 +552,17 @@ If no default value is stored, returns `zero(T)` where `T` is the `symtype` of t
## Note

This extends `getmetadata` in a way that all parameters have a numeric value.
Values are unwrapped from symbolic wrappers to ensure compatibility with ODEProblem.
"""
function get_parameter_map(x::Basis)
map(parameters(x)) do p
if hasmetadata(p, Symbolics.VariableDefaultValue)
return p => Symbolics.getdefaultval(p)
val = if hasmetadata(p, Symbolics.VariableDefaultValue)
Symbolics.getdefaultval(p)
else
return p => zero(Symbolics.symtype(p))
zero(Symbolics.symtype(p))
end
# Unwrap symbolic values to numeric values for use in ODEProblem
return p => Symbolics.unwrap(val)
end
end

Expand Down
68 changes: 68 additions & 0 deletions test/basis/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,71 @@ end
@test get_parameter_values(b) == [1.0; 2.0]
@test last.(get_parameter_map(b)) == [1.0; 2.0]
end

@testset "ODEProblem from Basis (Issue #559)" begin
# Regression test for issue #559: solve throws MethodError when creating
# ODEProblem from Basis due to symbolic to numeric conversion issues
using OrdinaryDiffEqTsit5

# Create a simple basis with parameters that have no default values
@variables u[1:2]
@parameters w[1:2]
u = collect(u)
w = collect(w)

# Create a basis with parameters without default values
# This tests the zero(Symbolics.symtype(p)) code path
h = [u[1]^2 + w[1] * u[2]; sin(w[2] * u[1])]
basis = Basis(h, u, parameters = w)

# Test that get_parameter_values returns unwrapped numeric values, not symbolic
params = get_parameter_values(basis)
@test params isa Vector
@test all(p -> !(p isa Num), params) # Should not be Num/symbolic
@test all(iszero, params) # Parameters without defaults should be zero

# Test that get_parameter_map also returns unwrapped numeric values
param_map = get_parameter_map(basis)
@test all(pair -> !(last(pair) isa Num), param_map)

# Test that we can create an ODEProblem from the basis
# This is the key test from issue #559 - should not throw MethodError
# about "Cannot convert BasicSymbolic{Real} to Float64"
u0 = [1.0, 2.0]
tspan = (0.0, 0.1) # Very short timespan
p_values = [0.01, 0.01] # Very small parameter values
recovered_model = ODEProblem(basis, u0, tspan, p_values)
@test recovered_model isa ODEProblem

# Test that we can initialize the integrator without the symbolic conversion error
# The key test is that this doesn't throw a MethodError about
# "Cannot convert BasicSymbolic{Real} to Float64" during setup
try
sol = solve(recovered_model, Tsit5(), save_everystep = false)
# If solve succeeds or fails with an Unstable error, that's fine
# We just want to ensure no MethodError about symbolic conversion
@test true
catch e
# Fail only if it's a MethodError about symbolic to Float64 conversion
if e isa MethodError && occursin("BasicSymbolic", string(e))
rethrow(e)
end
# Otherwise, pass the test (other errors are acceptable)
@test true
end

# Also test with parameters that have default values
@parameters w2[1:2] = [1.5, 2.5]
w2 = collect(w2)
h2 = [u[1]^2 + w2[1] * u[2]; sin(w2[2] * u[1])]
basis2 = Basis(h2, u, parameters = w2)

# Test that get_parameter_values returns the default values unwrapped
params2 = get_parameter_values(basis2)
@test all(p -> !(p isa Num), params2)
@test params2 ≈ [1.5, 2.5]

# Test creating ODEProblem with default parameter values
recovered_model2 = ODEProblem(basis2, u0, tspan, params2)
@test recovered_model2 isa ODEProblem
end
Loading