Skip to content

Commit 414e397

Browse files
committed
rebase: fixes
1 parent: 177e137 · commit: 414e397

File tree

2 files changed: +16 −63 lines changed

src/discretize.jl

+8-58
Original file line numberDiff line numberDiff line change
@@ -613,31 +613,30 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
613613
function get_likelihood_estimate_function(discretization::PhysicsInformedNN)
614614
function full_loss_function(θ, p)
615615
# the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
616-
pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
617-
bc_losses = [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
616+
# we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions.
617+
pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
618+
asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions]
619+
bc_losses = num_bc_losses == 0 ? adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
618620

619621
# this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized
620622
# that's why we prefer the user to maintain the increment in the outer loop callback during optimization
621623
ChainRulesCore.@ignore_derivatives if self_increment
622624
iteration[1] += 1
623625
end
624-
# the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
625-
# we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions.
626-
pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
627-
asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions]
628-
bc_losses = num_bc_losses == 0 ? adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
629626

630627
ChainRulesCore.@ignore_derivatives begin
631628
reweight_losses_func(θ, pde_losses,
632-
bc_losses)
629+
asl_losses, bc_losses)
633630
end
634631

635632
weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses
633+
weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses
636634
weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses
637635

638636
sum_weighted_pde_losses = sum(weighted_pde_losses)
637+
sum_weighted_asl_losses = sum(weighted_asl_losses)
639638
sum_weighted_bc_losses = sum(weighted_bc_losses)
640-
weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses
639+
weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses
641640

642641
full_weighted_loss = if additional_loss isa Nothing
643642
weighted_loss_before_additional
@@ -694,21 +693,12 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
694693
iteration[1])
695694
end
696695
end
697-
ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses,
698-
asl_losses, bc_losses) end
699696

700697
return full_weighted_loss
701698
end
702-
weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses
703-
weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses
704-
weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses
705699

706700
return full_loss_function
707701
end
708-
sum_weighted_pde_losses = sum(weighted_pde_losses)
709-
sum_weighted_asl_losses = sum(weighted_asl_losses)
710-
sum_weighted_bc_losses = sum(weighted_bc_losses)
711-
weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses
712702

713703
function get_likelihood_estimate_function(discretization::BayesianPINN)
714704
dataset_pde, dataset_bc = discretization.dataset
@@ -796,46 +786,6 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
796786
end
797787

798788
return full_loss_function
799-
ChainRulesCore.@ignore_derivatives begin if iteration[1] % log_frequency == 0
800-
logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses",
801-
iteration[1])
802-
logvector(pinnrep.logger, asl_losses, "unweighted_loss/asl_losses",
803-
iteration[1])
804-
logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[1])
805-
logvector(pinnrep.logger, weighted_pde_losses,
806-
"weighted_loss/weighted_pde_losses",
807-
iteration[1])
808-
logvector(pinnrep.logger, weighted_asl_losses,
809-
"weighted_loss/weighted_asl_losses",
810-
iteration[1])
811-
logvector(pinnrep.logger, weighted_bc_losses,
812-
"weighted_loss/weighted_bc_losses",
813-
iteration[1])
814-
if !(additional_loss isa Nothing)
815-
logscalar(pinnrep.logger, weighted_additional_loss_val,
816-
"weighted_loss/weighted_additional_loss", iteration[1])
817-
end
818-
logscalar(pinnrep.logger, sum_weighted_pde_losses,
819-
"weighted_loss/sum_weighted_pde_losses", iteration[1])
820-
logscalar(pinnrep.logger, sum_weighted_bc_losses,
821-
"weighted_loss/sum_weighted_bc_losses", iteration[1])
822-
logscalar(pinnrep.logger, sum_weighted_asl_losses,
823-
"weighted_loss/sum_weighted_asl_losses", iteration[1])
824-
logscalar(pinnrep.logger, full_weighted_loss,
825-
"weighted_loss/full_weighted_loss",
826-
iteration[1])
827-
logvector(pinnrep.logger, adaloss.pde_loss_weights,
828-
"adaptive_loss/pde_loss_weights",
829-
iteration[1])
830-
logvector(pinnrep.logger, adaloss.asl_loss_weights,
831-
"adaptive_loss/asl_loss_weights",
832-
iteration[1])
833-
logvector(pinnrep.logger, adaloss.bc_loss_weights,
834-
"adaptive_loss/bc_loss_weights",
835-
iteration[1])
836-
end end
837-
838-
return full_weighted_loss
839789
end
840790

841791
full_loss_function = get_likelihood_estimate_function(discretization)

src/pinn_types.jl

+8-5
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPIN
103103
derivative = nothing,
104104
param_estim = false,
105105
additional_loss = nothing,
106-
additional_symb_loss = [],
106+
additional_symb_loss = [],
107107
adaptive_loss = nothing,
108108
logger = nothing,
109109
log_options = LogOptions(),
@@ -136,14 +136,14 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPIN
136136

137137
new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative),
138138
typeof(param_estim),
139-
typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain,
139+
typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain,
140140
strategy,
141141
init_params,
142142
_phi,
143143
_derivative,
144144
param_estim,
145145
additional_loss,
146-
additional_symb_loss,
146+
additional_symb_loss,
147147
adaptive_loss,
148148
logger,
149149
log_options,
@@ -162,6 +162,7 @@ BayesianPINN(chain,
162162
phi = nothing,
163163
param_estim = false,
164164
additional_loss = nothing,
165+
additional_symb_loss = nothing,
165166
adaptive_loss = nothing,
166167
logger = nothing,
167168
log_options = LogOptions(),
@@ -211,14 +212,15 @@ methodology.
211212
* `iteration`: used to control the iteration counter???
212213
* `kwargs`: Extra keyword arguments.
213214
"""
214-
struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
215+
struct BayesianPINN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, D, K} <: AbstractPINN
215216
chain::Any
216217
strategy::T
217218
init_params::P
218219
phi::PH
219220
derivative::DER
220221
param_estim::PE
221222
additional_loss::AL
223+
additional_symb_loss::ASL
222224
adaptive_loss::ADA
223225
logger::LOG
224226
log_options::LogOptions
@@ -235,6 +237,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
235237
derivative = nothing,
236238
param_estim = false,
237239
additional_loss = nothing,
240+
additional_symb_loss = nothing,
238241
adaptive_loss = nothing,
239242
logger = nothing,
240243
log_options = LogOptions(),
@@ -272,7 +275,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
272275

273276
new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative),
274277
typeof(param_estim),
275-
typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset),
278+
typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset),
276279
typeof(kwargs)}(chain,
277280
strategy,
278281
init_params,

0 commit comments

Comments (0)