Commit ba5dbf9

Merge pull request #781 from AstitvaAggarwal/Bpinn_pde
BPINN solver Docs(Manual and tutorial)
2 parents 0687aaf + ee5c1df commit ba5dbf9

13 files changed: +272 -52 lines changed

docs/Project.toml

Lines changed: 5 additions & 0 deletions
@@ -1,13 +1,16 @@
 [deps]
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
 IntegralsCubature = "c31f79ba-6e32-46d4-a52f-182a8ac42a54"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
+MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
 NeuralPDE = "315f7962-48a3-4962-8226-d0f33b1235f0"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
@@ -20,6 +23,7 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

 [compat]
+AdvancedHMC = "0.5"
 DiffEqBase = "6.106"
 Documenter = "1"
 DomainSets = "0.6"
@@ -28,6 +32,7 @@ Integrals = "3.3"
 IntegralsCubature = "=0.2.2"
 Lux = "0.4, 0.5"
 ModelingToolkit = "8.33"
+MonteCarloMeasurements = "1"
 NeuralPDE = "5.3"
 Optimization = "3.9"
 OptimizationOptimJL = "0.1"
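
For reference, a sketch of how a docs environment with these new dependencies would typically be instantiated locally; the standard SciML docs layout with `docs/make.jl` as the build entry point is assumed here, not confirmed by this diff:

```julia
# Run from the repository root; paths assume the standard docs/ layout.
using Pkg
Pkg.activate("docs")
Pkg.develop(PackageSpec(path = pwd()))  # build against the local NeuralPDE checkout
Pkg.instantiate()                       # pulls AdvancedHMC, MonteCarloMeasurements, ...
include("docs/make.jl")                 # assumed build entry point
```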

docs/pages.jl

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@ pages = ["index.md",
     #"examples/nnrode_example.md", # currently incorrect
 ],
 "PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
+    "Bayesian PINNs for PDEs" => "tutorials/low_level_2.md",
     "Using GPUs" => "tutorials/gpu.md",
     "Defining Systems of PDEs" => "tutorials/systems.md",
     "Imposing Constraints" => "tutorials/constraints.md",
@@ -21,6 +22,7 @@ pages = ["index.md",
     "examples/nonlinear_hyperbolic.md"],
 "Manual" => Any["manual/ode.md",
     "manual/pinns.md",
+    "manual/bpinns.md",
     "manual/training_strategies.md",
     "manual/adaptive_losses.md",
     "manual/logging.md",

docs/src/manual/bpinns.md

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# `BayesianPINN` Discretizer for PDESystems
+
+Using the Bayesian PINN solvers, we can solve general nonlinear PDEs and ODEs, and simultaneously perform parameter estimation on them.
+
+Note: The BPINN PDE solver also works for ODEs defined via ModelingToolkit (see the [ModelingToolkit.jl PDESystem documentation](https://docs.sciml.ai/ModelingToolkit/stable/systems/PDESystem/)). Nevertheless, the ODE-specific BPINN solver [`BNNODE`](https://docs.sciml.ai/NeuralPDE/dev/manual/ode/#NeuralPDE.BNNODE) also exists and uses `NeuralPDE.ahmc_bayesian_pinn_ode` at a lower level.
+
+## Lower-level Bayesian PINN solver calls for PDEs and ODEs
+
+```@docs
+NeuralPDE.BayesianPINN
+NeuralPDE.ahmc_bayesian_pinn_ode
+NeuralPDE.ahmc_bayesian_pinn_pde
+```
+
+## `symbolic_discretize` for `BayesianPINN` and the lower-level interface
+
+```@docs
+SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.AbstractPINN)
+NeuralPDE.BPINNstats
+NeuralPDE.BPINNsolution
+```
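
Since the manual page above only lists docstring bindings, here is a minimal hedged sketch of the `BNNODE` entry point it references; the problem and keyword values are illustrative (mirroring the `ahmc_bayesian_pinn_ode` docstring later in this diff), not prescribed defaults:

```julia
# Minimal BNNODE sketch; values are examples, see the docstrings in this PR.
using NeuralPDE, Lux, OrdinaryDiffEq

# Simple forced ODE: du/dt = cos(2πt), u(0) = 0
linear = (u, p, t) -> cos(2 * π * t)
prob = ODEProblem(linear, 0.0, (0.0, 2.0))

chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1))

alg = BNNODE(chain;
    draw_samples = 2000,      # MCMC samples; warmup is ~2/3 of these
    l2std = [0.05], phystd = [0.05],
    priorsNNw = (0.0, 3.0))   # Normal prior (mean, std) on NN weights/biases

sol = solve(prob, alg)        # Bayesian (ensemble) solution
```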

docs/src/manual/pinns.md

Lines changed: 2 additions & 2 deletions
@@ -29,10 +29,10 @@ NeuralPDE.Phi
 SciMLBase.discretize(::PDESystem, ::NeuralPDE.PhysicsInformedNN)
 ```

-## `symbolic_discretize` and the lower-level interface
+## `symbolic_discretize` for `PhysicsInformedNN` and the lower-level interface

 ```@docs
-SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.PhysicsInformedNN)
+SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.AbstractPINN)
 NeuralPDE.PINNRepresentation
 NeuralPDE.PINNLossFunctions
 ```
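
The signature widening above (from `::NeuralPDE.PhysicsInformedNN` to `::NeuralPDE.AbstractPINN`) lets both discretizers share one lower-level entry point. A minimal sketch under that reading; the 1-D Poisson setup is illustrative and not from this diff:

```julia
# Illustrative 1-D Poisson setup; the point is the shared ::AbstractPINN method.
using NeuralPDE, Lux, ModelingToolkit
import ModelingToolkit: Interval

@parameters x
@variables u(..)
Dxx = Differential(x)^2
eq = Dxx(u(x)) ~ -sin(pi * x)
bcs = [u(0) ~ 0.0, u(1) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]
@named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])

chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))
discretization = PhysicsInformedNN(chain, GridTraining(0.1))

# A BayesianPINN discretizer now dispatches through this same method.
sym_prob = symbolic_discretize(pde_system, discretization)
```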

docs/src/tutorials/low_level.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Investigating `symbolic_discretize` with the 1-D Burgers' Equation
+# Investigating `symbolic_discretize` with the `PhysicsInformedNN` Discretizer for the 1-D Burgers' Equation

 Let's consider the Burgers' equation:
docs/src/tutorials/low_level_2.md

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+# Using `ahmc_bayesian_pinn_pde` with the `BayesianPINN` Discretizer for the Kuramoto–Sivashinsky equation
+
+Consider the Kuramoto–Sivashinsky equation:
+
+```math
+∂_t u(x, t) + u(x, t) ∂_x u(x, t) + \alpha ∂^2_x u(x, t) + \beta ∂^3_x u(x, t) + \gamma ∂^4_x u(x, t) = 0 \, ,
+```
+
+where $\alpha = \gamma = 1$ and $\beta = 4$. The exact solution is:
+
+```math
+u_e(x, t) = 11 + 15 \tanh \theta - 15 \tanh^2 \theta - 15 \tanh^3 \theta \, ,
+```
+
+where $\theta = t - x/2$, with initial and boundary conditions:
+
+```math
+\begin{align*}
+u( x, 0) &= u_e( x, 0) \, ,\\
+u( 10, t) &= u_e( 10, t) \, ,\\
+u(-10, t) &= u_e(-10, t) \, ,\\
+∂_x u( 10, t) &= ∂_x u_e( 10, t) \, ,\\
+∂_x u(-10, t) &= ∂_x u_e(-10, t) \, .
+\end{align*}
+```
+
+With Bayesian Physics-Informed Neural Networks, here is an example of using the `BayesianPINN` discretization with `ahmc_bayesian_pinn_pde`:
+
+```@example low_level_2
+using NeuralPDE, Flux, Lux, ModelingToolkit, LinearAlgebra, AdvancedHMC
+using Distributions
+import ModelingToolkit: Interval, infimum, supremum
+using Plots, MonteCarloMeasurements
+
+@parameters x, t, α
+@variables u(..)
+Dt = Differential(t)
+Dx = Differential(x)
+Dx2 = Differential(x)^2
+Dx3 = Differential(x)^3
+Dx4 = Differential(x)^4
+
+# α = 1 is the true value; it is estimated below from the noisy dataset
+β = 4
+γ = 1
+eq = Dt(u(x, t)) + u(x, t) * Dx(u(x, t)) + α * Dx2(u(x, t)) + β * Dx3(u(x, t)) + γ * Dx4(u(x, t)) ~ 0
+
+u_analytic(x, t; z = -x / 2 + t) = 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
+du(x, t; z = -x / 2 + t) = 15 / 2 * (tanh(z) + 1) * (3 * tanh(z) - 1) * sech(z)^2
+
+bcs = [u(x, 0) ~ u_analytic(x, 0),
+    u(-10, t) ~ u_analytic(-10, t),
+    u(10, t) ~ u_analytic(10, t),
+    Dx(u(-10, t)) ~ du(-10, t),
+    Dx(u(10, t)) ~ du(10, t)]
+
+# Space and time domains
+domains = [x ∈ Interval(-10.0, 10.0),
+    t ∈ Interval(0.0, 1.0)]
+
+# Discretization
+dx = 0.4;
+dt = 0.2;
+
+# Function to compute the analytical solution at a specific point (x, t)
+function u_analytic_point(x, t)
+    z = -x / 2 + t
+    return 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
+end
+
+# Function to generate the dataset matrix (columns: u, x, t)
+function generate_dataset_matrix(domains, dx, dt)
+    x_values = -10:dx:10
+    t_values = 0.0:dt:1.0
+
+    dataset = []
+
+    for t in t_values
+        for x in x_values
+            u_value = u_analytic_point(x, t)
+            push!(dataset, [u_value, x, t])
+        end
+    end
+
+    return vcat([data' for data in dataset]...)
+end
+
+datasetpde = [generate_dataset_matrix(domains, dx, dt)]
+
+# add multiplicative noise to the dataset
+noisydataset = deepcopy(datasetpde)
+noisydataset[1][:, 1] = noisydataset[1][:, 1] .+ randn(size(noisydataset[1][:, 1])) .* 5 / 100 .*
+                        noisydataset[1][:, 1]
+```
+
+Plotting the dataset; the added noise is set at 5%.
+
+```@example low_level_2
+plot(datasetpde[1][:, 2], datasetpde[1][:, 1], title = "Dataset from Analytical Solution")
+plot!(noisydataset[1][:, 2], noisydataset[1][:, 1])
+```
+
+```@example low_level_2
+# Neural network
+chain = Lux.Chain(Lux.Dense(2, 8, Lux.tanh),
+    Lux.Dense(8, 8, Lux.tanh),
+    Lux.Dense(8, 1))
+
+discretization = NeuralPDE.BayesianPINN([chain],
+    GridTraining([dx, dt]), param_estim = true, dataset = [noisydataset, nothing])
+
+@named pde_system = PDESystem(eq,
+    bcs,
+    domains,
+    [x, t],
+    [u(x, t)],
+    [α],
+    defaults = Dict([α => 0.5]))
+
+sol1 = ahmc_bayesian_pinn_pde(pde_system,
+    discretization;
+    draw_samples = 100, Kernel = AdvancedHMC.NUTS(0.8),
+    bcstd = [0.2, 0.2, 0.2, 0.2, 0.2],
+    phystd = [1.0], l2std = [0.05], param = [Distributions.LogNormal(0.5, 2)],
+    priorsNNw = (0.0, 10.0),
+    saveats = [1 / 100.0, 1 / 100.0], progress = true)
+```
+
+And some analysis:
+
+```@example low_level_2
+phi = discretization.phi[1]
+xs, ts = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, [dx / 10, dt])]
+u_predict = [[first(pmean(phi([x, t], sol1.estimated_nn_params[1]))) for x in xs]
+             for t in ts]
+u_real = [[u_analytic(x, t) for x in xs] for t in ts]
+diff_u = [[abs(u_analytic(x, t) - first(pmean(phi([x, t], sol1.estimated_nn_params[1]))))
+           for x in xs]
+          for t in ts]
+
+p1 = plot(xs, u_predict, title = "predict")
+p2 = plot(xs, u_real, title = "analytic")
+p3 = plot(xs, diff_u, title = "error")
+plot(p1, p2, p3)
+```
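
A natural follow-up to the tutorial is to check the recovered inverse-problem parameter. A hedged sketch, assuming the `BPINNsolution` returned above exposes the estimated PDE parameters through a field such as `estimated_de_params` (the field name is an assumption; see the `BPINNsolution` docs entry in the manual page above):

```julia
# Assumed accessor: estimated_de_params; verify against your NeuralPDE version.
est_α = sol1.estimated_de_params[1]
println("estimated α ≈ ", est_α, " (dataset was generated with α = 1)")
```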

src/BPINN_ode.jl

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ of the physics-informed neural network which is used as a solver for a standard
 * `Kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`

 ## Keyword Arguments
-(refer ahmc_bayesian_pinn_ode() keyword arguments.)
+(refer to the `NeuralPDE.ahmc_bayesian_pinn_ode` keyword arguments.)

 ## Example

src/NeuralPDE.jl

Lines changed: 1 addition & 1 deletion
@@ -67,6 +67,6 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
        MiniMaxAdaptiveLoss, LogOptions,
        ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
-       BPINNsolution
+       BPINNsolution, BayesianPINN

 end # module

src/PDE_BPINN.jl

Lines changed: 47 additions & 1 deletion
@@ -295,7 +295,52 @@ function inference(samples, pinnrep, saveats, numensemble, ℓπ)
     end
 end

-# priors: pdf for W,b + pdf for ODE params
+"""
+```julia
+ahmc_bayesian_pinn_pde(pde_system, discretization;
+    draw_samples = 1000,
+    bcstd = [0.01], l2std = [0.05],
+    phystd = [0.05], priorsNNw = (0.0, 2.0),
+    param = [], nchains = 1, Kernel = HMC(0.1, 30),
+    Adaptorkwargs = (Adaptor = StanHMCAdaptor,
+        Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
+    Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
+    numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
+```
+
+## NOTES
+* A dataset is required for accurate parameter estimation (in addition to solving the equations).
+* The returned solution is a `BPINNsolution` consisting of the ensemble solution and the estimated PDE and NN
+  parameters for the chosen `saveats` grid spacing, built from the last n = `numensemble` samples in the chain.
+  The complete set of samples in the MCMC chain is returned as `fullsolution`; refer to `BPINNsolution` for more details.
+
+## Positional Arguments
+* `pde_system`: ModelingToolkit-defined PDE equation or system of equations.
+* `discretization`: `BayesianPINN` discretization for the given pde_system, neural network, and training strategy.
+
+## Keyword Arguments
+* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples)
+* `bcstd`: Vector of standard deviations of BPINN prediction against initial/boundary condition equations.
+* `l2std`: Vector of standard deviations of BPINN prediction against L2 losses/dataset for each dependent variable of interest.
+* `phystd`: Vector of standard deviations of BPINN prediction against the chosen underlying PDE equations.
+* `priorsNNw`: Tuple of (mean, std) for BPINN network parameters. Weights and biases of the BPINN are Normal distributions by default.
+* `param`: Vector of the chosen PDE parameters' distributions, in case of inverse problems.
+* `nchains`: number of chains you want to sample
+
+# AdvancedHMC.jl is still developing convenience structs, so these may need changes on new releases.
+* `Kernel`: Choice of MCMC sampling algorithm object: HMC/NUTS/HMCDA (AdvancedHMC.jl implementations).
+* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+  Note: `targetacceptancerate` is the target percentage (in decimal) of iterations in which the proposals are accepted (0.8 by default).
+* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+* `saveats`: Grid spacing for each independent variable for evaluation of the ensemble solution and estimated parameters.
+* `numensemble`: Number of last samples to take for creation of the ensemble solution and estimated parameters.
+* `progress`: controls whether to show the progress meter or not.
+* `verbose`: controls the verbosity. (Sample call args in AHMC)
+"""
+# priors: pdf for W,b + pdf for PDE params
 function ahmc_bayesian_pinn_pde(pde_system, discretization;
     draw_samples = 1000,
     bcstd = [0.01], l2std = [0.05],
@@ -369,6 +414,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
     #ode parameter estimation
     nparameters = length(initial_θ)
     ninv = length(param)
+    # add init_params for NN params
     priors = [
         MvNormal(priorsNNw[1] * ones(nparameters),
             LinearAlgebra.Diagonal(abs2.(priorsNNw[2] .* ones(nparameters)))),
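
To make the `Adaptorkwargs`/`Integratorkwargs` NamedTuples from the docstring above concrete, a sketch mirroring the signature's defaults (the types are from AdvancedHMC.jl; the variable names here are illustrative):

```julia
using AdvancedHMC

# NamedTuples mirroring the docstring defaults; pass them as the
# Adaptorkwargs/Integratorkwargs keywords of ahmc_bayesian_pinn_pde.
adaptor_kwargs = (Adaptor = StanHMCAdaptor,
                  Metric = DiagEuclideanMetric,
                  targetacceptancerate = 0.8)  # accept ~80% of proposals
integrator_kwargs = (Integrator = Leapfrog,)   # trailing comma: 1-element NamedTuple
```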

src/advancedHMC_MCMC.jl

Lines changed: 8 additions & 14 deletions
@@ -436,40 +436,34 @@ Incase you are only solving the Equations for solution, do not provide dataset

 ## Keyword Arguments
 * `strategy`: The training strategy used to choose the points for the evaluations. By default GridTraining is used with given physdt discretization.
-* `dataset`: Vector containing Vectors of corresponding u,t values
 * `init_params`: intial parameter values for BPINN (ideally for multiple chains different initializations preferred)
-* `nchains`: number of chains you want to sample (random initialisation of params by default)
+* `nchains`: number of chains you want to sample
 * `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples)
-* `l2std`: standard deviation of BPINN predicition against L2 losses/Dataset
-* `phystd`: standard deviation of BPINN predicition against Chosen Underlying ODE System
-* `priorsNNw`: Vector of [mean, std] for BPINN parameter. Weights and Biases of BPINN are Normal Distributions by default
+* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
+* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
+* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default.
 * `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
 * `autodiff`: Boolean Value for choice of Derivative Backend(default is numerical)
 * `physdt`: Timestep for approximating ODE in it's Time domain. (1/20.0 by default)

 # AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
 * `Kernel`: Choice of MCMC Sampling Algorithm (AdvancedHMC.jl implemenations HMC/NUTS/HMCDA)
-* `Integratorkwargs`: A NamedTuple containing the chosen integrator and its keyword Arguments, as follows :
-    * `Integrator`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `jitter_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `tempering_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
-* `Adaptorkwargs`: A NamedTuple containing the chosen Adaptor, it's Metric and targetacceptancerate, as follows :
-    * `Adaptor`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `Metric`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `targetacceptancerate`: Target percentage(in decimal) of iterations in which the proposals were accepted(0.8 by default)
+* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+  Note: `targetacceptancerate` is the target percentage (in decimal) of iterations in which the proposals are accepted (0.8 by default).
 * `MCMCargs`: A NamedTuple containing all the chosen MCMC kernel's(HMC/NUTS/HMCDA) Arguments, as follows :
     * `n_leapfrog`: number of leapfrog steps for HMC
     * `δ`: target acceptance probability for NUTS and HMCDA
     * `λ`: target trajectory length for HMCDA
     * `max_depth`: Maximum doubling tree depth (NUTS)
     * `Δ_max`: Maximum divergence during doubling tree (NUTS)
+    Refer: https://turinglang.org/AdvancedHMC.jl/stable/
 * `progress`: controls whether to show the progress meter or not.
 * `verbose`: controls the verbosity. (Sample call args in AHMC)

 """

 """
-dataset would be (x̂,t)
 priors: pdf for W,b + pdf for ODE params
 """
 function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
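
The per-kernel `MCMCargs` NamedTuples described in the docstring above might look like the following; the argument names are exactly those listed there, while the values and combinations are illustrative only:

```julia
# Illustrative kernel-argument NamedTuples; values are examples, not defaults.
hmc_args   = (n_leapfrog = 30,)          # HMC: number of leapfrog steps
nuts_args  = (δ = 0.65, max_depth = 10)  # NUTS: target accept prob., max tree depth
hmcda_args = (δ = 0.65, λ = 0.3)         # HMCDA: accept prob., trajectory length
```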
