diff --git a/src/ensemble/ensemble_problems.jl b/src/ensemble/ensemble_problems.jl index 50479a473..102dc5a7b 100644 --- a/src/ensemble/ensemble_problems.jl +++ b/src/ensemble/ensemble_problems.jl @@ -28,12 +28,25 @@ EnsembleProblem(prob::AbstractSciMLProblem; `repeat` is the iteration of the repeat. At first, it is `1`, but if `rerun` was true this will be `2`, `3`, etc. counting the number of times problem `i` has been repeated. - - `reduction`: This function determines how to reduce the data in each batch. - Defaults to appending the `data` into `u`, initialised via `u_data`, from - the batches. `I` is a range of indices giving the trajectories corresponding - to the batches. The second part of the output determines whether the simulation - has converged. If `true`, the simulation will exit early. By default, this is - always `false`. + + - `reduction`: This function is used to aggregate the results in each simulation batch. + By default, it appends the `data` from the batch to `u`, which is initialized via `u_data`. + The `I` is a range of indices corresponding to the trajectories for the current batch. + ### Arguments: + - `u`: The solution from the current ensemble run. This is the accumulated data that gets + updated in each batch. + - `data`: The results from the current batch of simulations. This is typically some data + (e.g., variable values, time steps) that is merged with `u`. + - `I`: A range of indices corresponding to the simulations in the current batch. This provides + the trajectory indices for the batch. + + ### Returns: + - `(new_data, has_converged)`: A tuple where: + - `new_data`: The updated accumulated data, typically the result of appending `data` to `u`. + - `has_converged`: A boolean indicating whether the simulation has converged and should terminate early. + If `true`, the simulation will stop early. If `false`, the simulation will continue. By default, this is + `false`, meaning the simulation will not stop early. + - `u_init`: The initial form of the object that gets updated in-place inside the `reduction` function. - `safetycopy`: Determines whether a safety `deepcopy` is called on the `prob` @@ -81,6 +94,21 @@ output_func(sol, i) = (sol[end, 2], false) Thus, the ensemble simulation would return as its data an array which is the end value of the 2nd dependent variable for each of the runs. """ + +""" +$(TYPEDEF) + +Defines a structure to manage an ensemble (batch) of problems. +Each field controls how the ensemble behaves during simulation. + +## Arguments + - `prob`: The original base problem to replicate or modify. + - `prob_func`: A function that defines how to generate each subproblem. + - `output_func`: A function to post-process each individual simulation result. + - `reduction`: A function to combine results from all simulations. + - `u_init`: The initial container used to accumulate the results. + - `safetycopy`: Whether to copy the problem when creating subproblems (to avoid unintended modifications). +""" struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem prob::T prob_func::T2 @@ -90,12 +118,36 @@ struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem safetycopy::Bool end +""" +Returns the same problem without modification. +""" DEFAULT_PROB_FUNC(prob, i, repeat) = prob + +""" +Returns the solution as-is, along with `false` indicating no rerun. +""" DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false) + +""" +Appends new data to the accumulated data and returns `false` to indicate no early termination. +""" DEFAULT_REDUCTION(u, data, I) = append!(u, data), false + +""" +Selects the i-th problem from a vector of problems. +""" DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i] + +""" +$(TYPEDEF) + +Constructor for deprecated usage where a vector of problems is passed directly. + +!!! warning + This constructor is deprecated. Use the standard ensemble syntax with `prob_func` instead. +""" function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...) - Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel + Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \ Ensembles Simulations Interface page for more details", :EnsembleProblem) invoke(EnsembleProblem, Tuple{Any}, @@ -103,6 +155,21 @@ function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs... prob_func = DEFAULT_VECTOR_PROB_FUNC, kwargs...) end + +""" +$(TYPEDEF) + +Main constructor for `EnsembleProblem`. + +## Keyword Arguments + +- `prob`: The base problem. +- `prob_func`: Function to modify the base problem per trajectory. +- `output_func`: Function to extract output from a solution. +- `reduction`: Function to aggregate results. +- `u_init`: Initial value for aggregation. +- `safetycopy`: Whether to deepcopy the problem before modifying. +""" function EnsembleProblem(prob; prob_func = DEFAULT_PROB_FUNC, output_func = DEFAULT_OUTPUT_FUNC, @@ -116,6 +183,11 @@ function EnsembleProblem(prob; EnsembleProblem(prob, _prob_func, _output_func, _reduction, _u_init, safetycopy) end +""" +$(TYPEDEF) + +Alternate constructor that uses only keyword arguments. +""" function EnsembleProblem(; prob, prob_func = DEFAULT_PROB_FUNC, output_func = DEFAULT_OUTPUT_FUNC, @@ -125,28 +197,62 @@ function EnsembleProblem(; prob, EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy) end -#since NonlinearProblem might want to use this dispatch as well +""" +$(TYPEDEF) + +Constructor that is used for NOnlinearProblem. + +!!! warning + This dispatch is deprecated. See the Parallel Ensembles Simulations Interface page. +""" function SciMLBase.EnsembleProblem( prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T} - Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel - Ensembles Simulations Interface page for more details", :EnsebleProblem) + Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \ + Ensembles Simulations Interface page for more details", :EnsembleProblem) prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i]) return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...) end +""" +$(TYPEDEF) + +Defines a weighted version of an `EnsembleProblem`, where different simulations contribute unequally. + +## Arguments + +- `ensembleprob`: The base ensemble problem. +- `weights`: A vector of weights corresponding to each simulation. +""" struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <: AbstractEnsembleProblem ensembleprob::T1 weights::T2 end + +""" +Returns a list of all accessible properties, including those from the inner ensemble and `:weights`. +""" function Base.propertynames(e::WeightedEnsembleProblem) (Base.propertynames(getfield(e, :ensembleprob))..., :weights) end + +""" +Accesses properties of a `WeightedEnsembleProblem`. + +Returns `weights` or delegates to the underlying ensemble. +""" function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol) f === :weights && return getfield(e, :weights) f === :ensembleprob && return getfield(e, :ensembleprob) return getproperty(getfield(e, :ensembleprob), f) end + +""" +$(TYPEDEF) + +Constructor for `WeightedEnsembleProblem`. Ensures weights sum to 1 and matches problem count. + +""" function WeightedEnsembleProblem(args...; weights, kwargs...) # TODO: allow skipping checks? @assert sum(weights) ≈ 1 @@ -154,3 +260,4 @@ function WeightedEnsembleProblem(args...; weights, kwargs...) @assert length(ep.prob) == length(weights) WeightedEnsembleProblem(ep, weights) end +