Ensemble doc fix #1008

Merged
merged 4 commits into from
Apr 29, 2025
127 changes: 117 additions & 10 deletions src/ensemble/ensemble_problems.jl
@@ -28,12 +28,25 @@ EnsembleProblem(prob::AbstractSciMLProblem;
`repeat` is the iteration of the repeat. At first, it is `1`, but if
`rerun` was true this will be `2`, `3`, etc. counting the number of times
problem `i` has been repeated.
- `reduction`: This function determines how to reduce the data in each batch.
Defaults to appending the `data` into `u`, initialised via `u_data`, from
the batches. `I` is a range of indices giving the trajectories corresponding
to the batches. The second part of the output determines whether the simulation
has converged. If `true`, the simulation will exit early. By default, this is
always `false`.

- `reduction`: This function aggregates the results of each simulation batch.
By default, it appends the `data` from the batch to `u`, which is initialized via `u_init`.
`I` is a range of indices giving the trajectories in the current batch.
### Arguments:
- `u`: The solution from the current ensemble run. This is the accumulated data that gets
updated in each batch.
- `data`: The results from the current batch of simulations. This is typically some data
(e.g., variable values, time steps) that is merged with `u`.
- `I`: A range of indices corresponding to the simulations in the current batch. This provides
the trajectory indices for the batch.

### Returns:
- `(new_data, has_converged)`: A tuple where:
  - `new_data`: The updated accumulated data, typically the result of appending `data` to `u`.
  - `has_converged`: A boolean indicating whether the ensemble has converged. If `true`, the
    simulation stops early; if `false`, it continues. The default reduction always returns `false`.

- `u_init`: The initial form of the object that gets updated in-place inside the
`reduction` function.
- `safetycopy`: Determines whether a safety `deepcopy` is called on the `prob`
@@ -81,6 +94,21 @@ output_func(sol, i) = (sol[end, 2], false)
Thus, the ensemble simulation would return as its data an array which is the
end value of the 2nd dependent variable for each of the runs.
"""

"""
$(TYPEDEF)

Defines a structure to manage an ensemble (batch) of problems.
Each field controls how the ensemble behaves during simulation.

## Arguments
- `prob`: The original base problem to replicate or modify.
- `prob_func`: A function that defines how to generate each subproblem.
- `output_func`: A function to post-process each individual simulation result.
- `reduction`: A function to combine results from all simulations.
- `u_init`: The initial container used to accumulate the results.
- `safetycopy`: Whether to copy the problem when creating subproblems (to avoid unintended modifications).
"""
struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
prob::T
prob_func::T2
@@ -90,19 +118,58 @@ struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
safetycopy::Bool
end

"""
Returns the same problem without modification.
"""
DEFAULT_PROB_FUNC(prob, i, repeat) = prob

"""
Returns the solution as-is, along with `false` indicating no rerun.
"""
DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)

"""
Appends new data to the accumulated data and returns `false` to indicate no early termination.
"""
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
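
Review note: the `(u, data, I)` contract documented above may be easier to follow with a custom reduction next to the default one. The following is a minimal sketch using only Base Julia. `sum_reduction`, its stopping rule, and the `run_batches` driver are hypothetical names introduced for illustration; the driver only imitates a batch loop and is not the SciMLBase implementation.

```julia
# Hypothetical reduction with the (u, data, I) signature described above:
# it keeps a running sum of the batch results and requests early
# termination once trajectory 6 has been processed.
function sum_reduction(u, data, I)
    new_u = u + sum(data)          # fold this batch into the accumulator
    has_converged = last(I) >= 6   # hypothetical stopping rule
    return new_u, has_converged
end

# Toy driver standing in for a solver's batch loop (an assumption made for
# this sketch): it feeds `batch_size` results at a time to the reduction.
function run_batches(reduction, u_init, results; batch_size = 2)
    u = u_init
    for start in 1:batch_size:length(results)
        I = start:min(start + batch_size - 1, length(results))
        u, has_converged = reduction(u, results[I], I)
        has_converged && break     # stop as soon as the reduction says so
    end
    return u
end

results = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

# Early exit after the batch covering trajectories 5:6.
total = run_batches(sum_reduction, 0.0, results)

# Same shape as the default: append every batch and never stop early.
collected = run_batches((u, data, I) -> (append!(u, data), false), Float64[], results)
```

With the append-style reduction every trajectory is kept, whereas `sum_reduction` stops after the third batch, so the last two results are never folded in.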

"""
Selects the i-th problem from a vector of problems.
"""
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]

"""
$(TYPEDEF)

Constructor for deprecated usage where a vector of problems is passed directly.

!!! warning
    This constructor is deprecated. Use the standard ensemble syntax with `prob_func` instead.
"""
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \
Ensembles Simulations Interface page for more details", :EnsembleProblem)
invoke(EnsembleProblem,
Tuple{Any},
prob;
prob_func = DEFAULT_VECTOR_PROB_FUNC,
kwargs...)
end

"""
$(TYPEDEF)

Main constructor for `EnsembleProblem`.

## Arguments

- `prob`: The base problem.
- `prob_func`: Function to modify the base problem per trajectory.
- `output_func`: Function to extract output from a solution.
- `reduction`: Function to aggregate results.
- `u_init`: Initial value for aggregation.
- `safetycopy`: Whether to deepcopy the problem before modifying.
"""
function EnsembleProblem(prob;
prob_func = DEFAULT_PROB_FUNC,
output_func = DEFAULT_OUTPUT_FUNC,
@@ -116,6 +183,11 @@ function EnsembleProblem(prob;
EnsembleProblem(prob, _prob_func, _output_func, _reduction, _u_init, safetycopy)
end

"""
$(TYPEDEF)

Alternate constructor that uses only keyword arguments.
"""
function EnsembleProblem(; prob,
prob_func = DEFAULT_PROB_FUNC,
output_func = DEFAULT_OUTPUT_FUNC,
@@ -125,32 +197,67 @@ function EnsembleProblem(; prob,
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
end

#since NonlinearProblem might want to use this dispatch as well
"""
$(TYPEDEF)

Constructor that takes a base problem and a vector of initial conditions `u0s`.
Kept because `NonlinearProblem` may use this dispatch as well.

!!! warning
    This dispatch is deprecated. See the Parallel Ensembles Simulations Interface page.
"""
function SciMLBase.EnsembleProblem(
prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
Ensembles Simulations Interface page for more details", :EnsebleProblem)
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \
Ensembles Simulations Interface page for more details", :EnsembleProblem)
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
end

"""
$(TYPEDEF)

Defines a weighted version of an `EnsembleProblem`, where different simulations contribute unequally.

## Arguments

- `ensembleprob`: The base ensemble problem.
- `weights`: A vector of weights corresponding to each simulation.
"""
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
weights::T2
end

"""
Returns a list of all accessible properties, including those from the inner ensemble and `:weights`.
"""
function Base.propertynames(e::WeightedEnsembleProblem)
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
end

"""
Accesses properties of a `WeightedEnsembleProblem`.

Returns `weights` or delegates to the underlying ensemble.
"""
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
end
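
Review note: the forwarding pattern used by `propertynames` and `getproperty` above can be tried in isolation. This is a minimal sketch with toy types; `Inner` and `Wrapper` are hypothetical stand-ins for the ensemble types and are not part of SciMLBase.

```julia
# Toy types (hypothetical names) illustrating property forwarding.
struct Inner
    prob::String
    safetycopy::Bool
end

struct Wrapper
    inner::Inner
    weights::Vector{Float64}
end

# Own fields are read with `getfield` so that the overloaded `getproperty`
# never calls itself; anything else is delegated to the wrapped object.
function Base.getproperty(w::Wrapper, f::Symbol)
    f === :weights && return getfield(w, :weights)
    f === :inner && return getfield(w, :inner)
    return getproperty(getfield(w, :inner), f)
end

# Advertise the inner object's properties plus `:weights`.
Base.propertynames(w::Wrapper) = (propertynames(getfield(w, :inner))..., :weights)

w = Wrapper(Inner("base problem", true), [0.25, 0.75])
```

As in the diff, the wrapped object stays reachable by name (`w.inner` here) even though `propertynames` does not list it.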

"""
$(TYPEDEF)

Constructor for `WeightedEnsembleProblem`. Asserts that `weights` sums to approximately 1 and
that there is exactly one weight per problem.
"""
function WeightedEnsembleProblem(args...; weights, kwargs...)
# TODO: allow skipping checks?
@assert sum(weights) ≈ 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
end
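
Review note: regarding the `TODO` above, the two checks could also be expressed as explicit argument validation. The sketch below is only an illustration of that idea; `check_weights` is a hypothetical helper and not an existing SciMLBase function.

```julia
# Hypothetical helper mirroring the two assertions above, but throwing
# ArgumentError instead of relying on @assert.
function check_weights(problems::AbstractVector, weights::AbstractVector)
    # `≈` tolerates floating-point rounding in the sum.
    sum(weights) ≈ 1 || throw(ArgumentError("weights must sum to 1"))
    length(problems) == length(weights) ||
        throw(ArgumentError("need exactly one weight per problem"))
    return true
end

ok = check_weights(["p1", "p2", "p3"], [0.2, 0.3, 0.5])

# A length mismatch is reported as an ArgumentError.
mismatch_caught = try
    check_weights(["p1", "p2"], [0.2, 0.3, 0.5])
    false
catch err
    err isa ArgumentError
end
```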
