Skip to content

Commit deb4314

Browse files
Merge pull request #1008 from Zf98ai/ensemble-doc-fix
Ensemble doc fix
2 parents cc59a15 + b2d1ca8 commit deb4314

File tree

1 file changed

+117
-10
lines changed

1 file changed

+117
-10
lines changed

src/ensemble/ensemble_problems.jl

Lines changed: 117 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,25 @@ EnsembleProblem(prob::AbstractSciMLProblem;
2828
`repeat` is the iteration of the repeat. At first, it is `1`, but if
2929
`rerun` was true this will be `2`, `3`, etc. counting the number of times
3030
problem `i` has been repeated.
31-
- `reduction`: This function determines how to reduce the data in each batch.
32-
Defaults to appending the `data` into `u`, initialised via `u_data`, from
33-
the batches. `I` is a range of indices giving the trajectories corresponding
34-
to the batches. The second part of the output determines whether the simulation
35-
has converged. If `true`, the simulation will exit early. By default, this is
36-
always `false`.
31+
32+
- `reduction`: This function is used to aggregate the results in each simulation batch.
33+
By default, it appends the `data` from the batch to `u`, which is initialized via `u_data`.
34+
The `I` is a range of indices corresponding to the trajectories for the current batch.
35+
### Arguments:
36+
- `u`: The solution from the current ensemble run. This is the accumulated data that gets
37+
updated in each batch.
38+
- `data`: The results from the current batch of simulations. This is typically some data
39+
(e.g., variable values, time steps) that is merged with `u`.
40+
- `I`: A range of indices corresponding to the simulations in the current batch. This provides
41+
the trajectory indices for the batch.
42+
43+
### Returns:
44+
- `(new_data, has_converged)`: A tuple where:
45+
- `new_data`: The updated accumulated data, typically the result of appending `data` to `u`.
46+
- `has_converged`: A boolean indicating whether the simulation has converged and should terminate early.
47+
If `true`, the simulation will stop early. If `false`, the simulation will continue. By default, this is
48+
`false`, meaning the simulation will not stop early.
49+
3750
- `u_init`: The initial form of the object that gets updated in-place inside the
3851
`reduction` function.
3952
- `safetycopy`: Determines whether a safety `deepcopy` is called on the `prob`
@@ -81,6 +94,21 @@ output_func(sol, i) = (sol[end, 2], false)
8194
Thus, the ensemble simulation would return as its data an array which is the
8295
end value of the 2nd dependent variable for each of the runs.
8396
"""
97+
98+
"""
99+
$(TYPEDEF)
100+
101+
Defines a structure to manage an ensemble (batch) of problems.
102+
Each field controls how the ensemble behaves during simulation.
103+
104+
## Arguments
105+
- `prob`: The original base problem to replicate or modify.
106+
- `prob_func`: A function that defines how to generate each subproblem.
107+
- `output_func`: A function to post-process each individual simulation result.
108+
- `reduction`: A function to combine results from all simulations.
109+
- `u_init`: The initial container used to accumulate the results.
110+
- `safetycopy`: Whether to copy the problem when creating subproblems (to avoid unintended modifications).
111+
"""
84112
struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
85113
prob::T
86114
prob_func::T2
@@ -90,19 +118,58 @@ struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
90118
safetycopy::Bool
91119
end
92120

121+
"""
122+
Returns the same problem without modification.
123+
"""
93124
DEFAULT_PROB_FUNC(prob, i, repeat) = prob
125+
126+
"""
127+
Returns the solution as-is, along with `false` indicating no rerun.
128+
"""
94129
DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
130+
131+
"""
132+
Appends new data to the accumulated data and returns `false` to indicate no early termination.
133+
"""
95134
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
135+
136+
"""
137+
Selects the i-th problem from a vector of problems.
138+
"""
96139
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
140+
141+
"""
142+
$(TYPEDEF)
143+
144+
Constructor for deprecated usage where a vector of problems is passed directly.
145+
146+
!!! warning
147+
This constructor is deprecated. Use the standard ensemble syntax with `prob_func` instead.
148+
"""
97149
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
98-
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
150+
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \
99151
Ensembles Simulations Interface page for more details", :EnsembleProblem)
100152
invoke(EnsembleProblem,
101153
Tuple{Any},
102154
prob;
103155
prob_func = DEFAULT_VECTOR_PROB_FUNC,
104156
kwargs...)
105157
end
158+
159+
"""
160+
$(TYPEDEF)
161+
162+
Main constructor for `EnsembleProblem`.
163+
164+
## Keyword Arguments
165+
166+
- `prob`: The base problem.
167+
- `prob_func`: Function to modify the base problem per trajectory.
168+
- `output_func`: Function to extract output from a solution.
169+
- `reduction`: Function to aggregate results.
170+
- `u_init`: Initial value for aggregation.
171+
- `safetycopy`: Whether to deepcopy the problem before modifying.
172+
"""
106173
function EnsembleProblem(prob;
107174
prob_func = DEFAULT_PROB_FUNC,
108175
output_func = DEFAULT_OUTPUT_FUNC,
@@ -116,6 +183,11 @@ function EnsembleProblem(prob;
116183
EnsembleProblem(prob, _prob_func, _output_func, _reduction, _u_init, safetycopy)
117184
end
118185

186+
"""
187+
$(TYPEDEF)
188+
189+
Alternate constructor that uses only keyword arguments.
190+
"""
119191
function EnsembleProblem(; prob,
120192
prob_func = DEFAULT_PROB_FUNC,
121193
output_func = DEFAULT_OUTPUT_FUNC,
@@ -125,32 +197,67 @@ function EnsembleProblem(; prob,
125197
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
126198
end
127199

128-
#since NonlinearProblem might want to use this dispatch as well
200+
"""
201+
$(TYPEDEF)
202+
203+
Constructor that is used for NOnlinearProblem.
204+
205+
!!! warning
206+
This dispatch is deprecated. See the Parallel Ensembles Simulations Interface page.
207+
"""
129208
function SciMLBase.EnsembleProblem(
130209
prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
131-
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
132-
Ensembles Simulations Interface page for more details", :EnsebleProblem)
210+
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \
211+
Ensembles Simulations Interface page for more details", :EnsembleProblem)
133212
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
134213
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
135214
end
136215

216+
"""
217+
$(TYPEDEF)
218+
219+
Defines a weighted version of an `EnsembleProblem`, where different simulations contribute unequally.
220+
221+
## Arguments
222+
223+
- `ensembleprob`: The base ensemble problem.
224+
- `weights`: A vector of weights corresponding to each simulation.
225+
"""
137226
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
138227
AbstractEnsembleProblem
139228
ensembleprob::T1
140229
weights::T2
141230
end
231+
232+
"""
233+
Returns a list of all accessible properties, including those from the inner ensemble and `:weights`.
234+
"""
142235
function Base.propertynames(e::WeightedEnsembleProblem)
143236
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
144237
end
238+
239+
"""
240+
Accesses properties of a `WeightedEnsembleProblem`.
241+
242+
Returns `weights` or delegates to the underlying ensemble.
243+
"""
145244
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
146245
f === :weights && return getfield(e, :weights)
147246
f === :ensembleprob && return getfield(e, :ensembleprob)
148247
return getproperty(getfield(e, :ensembleprob), f)
149248
end
249+
250+
"""
251+
$(TYPEDEF)
252+
253+
Constructor for `WeightedEnsembleProblem`. Ensures weights sum to 1 and matches problem count.
254+
255+
"""
150256
function WeightedEnsembleProblem(args...; weights, kwargs...)
151257
# TODO: allow skipping checks?
152258
@assert sum(weights) 1
153259
ep = EnsembleProblem(args...; kwargs...)
154260
@assert length(ep.prob) == length(weights)
155261
WeightedEnsembleProblem(ep, weights)
156262
end
263+

0 commit comments

Comments
 (0)