diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 8db6e45a0d..cecd6c0889 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -320,7 +320,8 @@ struct SecondOrderBVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: problem_type::PT kwargs::K - @add_kwonly function SecondOrderBVProblem{iip}(f::DynamicalBVPFunction{iip, TP}, u0, tspan, + @add_kwonly function SecondOrderBVProblem{iip}( + f::DynamicalBVPFunction{iip, TP}, u0, tspan, p = NullParameters(); problem_type = nothing, nlls = nothing, kwargs...) where {iip, TP} _u0 = prepare_initial_state(u0) @@ -331,16 +332,19 @@ struct SecondOrderBVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs) end - function SecondOrderBVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} + function SecondOrderBVProblem{iip}( + f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} SecondOrderBVProblem(DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) end end function SecondOrderBVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 5) - return SecondOrderBVProblem{iip}(DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) + return SecondOrderBVProblem{iip}( + DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) end -function SecondOrderBVProblem(f::DynamicalBVPFunction, u0, tspan, p = NullParameters(); kwargs...) +function SecondOrderBVProblem( + f::DynamicalBVPFunction, u0, tspan, p = NullParameters(); kwargs...) return SecondOrderBVProblem{isinplace(f)}(f, u0, tspan, p; kwargs...) -end \ No newline at end of file +end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 83e6b8e3b1..becff1bf94 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1789,7 +1789,8 @@ and more. For all cases, `u` is the state and `p` are the parameters. ```julia OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); grad = nothing, hess = nothing, hv = nothing, - cons = nothing, cons_j = nothing, cons_h = nothing, + cons = nothing, cons_j = nothing, cons_jvp = nothing, + cons_vjp = nothing, cons_h = nothing, hess_prototype = nothing, cons_jac_prototype = nothing, cons_hess_prototype = nothing, @@ -1827,6 +1828,8 @@ function described in [Callback Functions](https://docs.sciml.ai/Optimization/st bounds passed as `lcons` and `ucons` to [`OptimizationProblem`](@ref), in case of equality constraints `lcons` and `ucons` should be passed equal values. - `cons_j(J,x,p)` or `J=cons_j(x,p)`: the Jacobian of the constraints. +- `cons_jvp(Jv,v,x,p)` or `Jv=cons_jvp(v,x,p)`: the Jacobian-vector product of the constraints. +- `cons_vjp(Jv,v,x,p)` or `Jv=cons_vjp(v,x,p)`: the Jacobian-vector product of the constraints. - `cons_h(H,x,p)` or `H=cons_h(x,p)`: the Hessian of the constraints, provided as an array of Hessians with `res[i]` being the Hessian with respect to the `i`th output on `cons`. - `hess_prototype`: a prototype matrix matching the type that matches the Hessian. For example, @@ -1892,7 +1895,7 @@ For more details on this argument, see the ODEFunction documentation. The fields of the OptimizationFunction type directly match the names of the inputs. """ -struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CH, HP, CJP, CHP, O, +struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O, EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV} <: AbstractOptimizationFunction{iip} f::F @@ -1902,6 +1905,8 @@ struct OptimizationFunction{iip, AD, F, G, H, HV, C, CJ, CH, HP, CJP, CHP, O, hv::HV cons::C cons_j::CJ + cons_jvp::CJV + cons_vjp::CVJ cons_h::CH hess_prototype::HP cons_jac_prototype::CJP @@ -2137,7 +2142,8 @@ For more details on this argument, see the ODEFunction documentation. The fields of the DynamicalBVPFunction type directly match the names of the inputs. """ -struct DynamicalBVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, +struct DynamicalBVPFunction{ + iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV, SYS} <: AbstractBVPFunction{iip, twopoint} f::F @@ -3758,7 +3764,8 @@ OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; k function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); grad = nothing, hess = nothing, hv = nothing, - cons = nothing, cons_j = nothing, cons_h = nothing, + cons = nothing, cons_j = nothing, cons_jvp = nothing, + cons_vjp = nothing, cons_h = nothing, hess_prototype = nothing, cons_jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, @@ -3780,7 +3787,8 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); sys = sys_or_symbolcache(sys, syms, paramsyms) OptimizationFunction{iip, typeof(adtype), typeof(f), typeof(grad), typeof(hess), typeof(hv), - typeof(cons), typeof(cons_j), typeof(cons_h), + typeof(cons), typeof(cons_j), typeof(cons_jvp), + typeof(cons_vjp), typeof(cons_h), typeof(hess_prototype), typeof(cons_jac_prototype), typeof(cons_hess_prototype), typeof(observed), @@ -3789,7 +3797,8 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); typeof(cons_jac_colorvec), typeof(cons_hess_colorvec), typeof(lag_hess_colorvec) }(f, adtype, grad, hess, - hv, cons, cons_j, cons_h, + hv, cons, cons_j, cons_jvp, + cons_vjp, cons_h, hess_prototype, cons_jac_prototype, cons_hess_prototype, observed, expr, cons_expr, sys, lag_h, lag_hess_prototype, hess_colorvec, cons_jac_colorvec, @@ -3987,7 +3996,6 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc; colorvec = __has_colorvec(f) ? f.colorvec : nothing, bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint} - if mass_matrix === I && f isa Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -4095,7 +4103,7 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc; _f = prepare_function(f) sys = something(sys, SymbolCache(syms, paramsyms, indepsym)) - + if specialize === NoSpecialize DynamicalBVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, @@ -4127,11 +4135,11 @@ function DynamicalBVPFunction{iip}(f, bc; twopoint::Union{Val, Bool} = Val(false end DynamicalBVPFunction{iip}(f::DynamicalBVPFunction, bc; kwargs...) where {iip} = f function DynamicalBVPFunction(f, bc; twopoint::Union{Val, Bool} = Val(false), kwargs...) - DynamicalBVPFunction{isinplace(f, 5), FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...) + DynamicalBVPFunction{isinplace(f, 5), FullSpecialize, _unwrap_val(twopoint)}( + f, bc; kwargs...) end DynamicalBVPFunction(f::DynamicalBVPFunction; kwargs...) = f - function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize} _f = prepare_function(f) IntegralFunction{iip, specialize, typeof(_f), typeof(integrand_prototype)}(_f, diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index c9773b93e6..3437bf1068 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -706,24 +706,35 @@ DynamicalBVPFunction(dbfiip, dbciip, jac = dbjac, bcjac = dbcjac) DynamicalBVPFunction(dbfoop, dbcoop, jac = dbjac, bcjac = dbcjac) dbWfact(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact(du, u, p, gamma, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact(ddu, du, u, p, gamma, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact_t(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact_t = dbWfact_t) dbWfact_t(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact_t = dbWfact_t) dbWfact_t(du, u, p, gamma, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, @@ -738,18 +749,25 @@ DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t) Wfact_t = dbWfact_t) dbtgrad(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, tgrad = dbtgrad) dbtgrad(du, u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, tgrad = dbtgrad) dbtgrad(ddu, du, u, p, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, tgrad = dbtgrad) dbparamjac(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, paramjac = dbparamjac) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, paramjac = dbparamjac) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, paramjac = dbparamjac) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, paramjac = dbparamjac) dbparamjac(du, u, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, @@ -764,25 +782,35 @@ DynamicalBVPFunction(dbfiip, dbciip, paramjac = dbparamjac) paramjac = dbparamjac) dbjvp(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, jvp = dbjvp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, jvp = dbjvp) dbjvp(du, u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, jvp = dbjvp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, jvp = dbjvp) dbjvp(ddu, du, u, v, p, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, jvp = dbjvp) dbvjp(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, vjp = dbvjp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, vjp = dbvjp) dbvjp(du, u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, vjp = dbvjp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, vjp = dbvjp) dbvjp(ddu, du, u, v, p, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, vjp = dbvjp) # IntegralFunction