@@ -613,31 +613,30 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
     function get_likelihood_estimate_function(discretization::PhysicsInformedNN)
         function full_loss_function(θ, p)
             # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
-            pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
-            bc_losses = [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
+            # we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions.
+            pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
+            asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions]
+            bc_losses = num_bc_losses == 0 ? adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
 
             # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized
             # that's why we prefer the user to maintain the increment in the outer loop callback during optimization
             ChainRulesCore.@ignore_derivatives if self_increment
                 iteration[1] += 1
             end
-            # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
-            # we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions.
-            pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
-            asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions]
-            bc_losses = num_bc_losses == 0 ? adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
 
             ChainRulesCore.@ignore_derivatives begin
                 reweight_losses_func(θ, pde_losses,
-                    bc_losses)
+                    asl_losses, bc_losses)
             end
 
             weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses
+            weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses
             weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses
 
             sum_weighted_pde_losses = sum(weighted_pde_losses)
+            sum_weighted_asl_losses = sum(weighted_asl_losses)
             sum_weighted_bc_losses = sum(weighted_bc_losses)
-            weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses
+            weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses
 
             full_weighted_loss = if additional_loss isa Nothing
                 weighted_loss_before_additional
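A note on the typed empty vectors above (adaloss_T[]): an untyped empty vector in Julia is a Vector{Any}, which cannot be summed because zero(Any) is undefined, and that failure also surfaces when autodiff differentiates the aggregated loss while one of the loss groups is empty. A rough standalone sketch of the idea, assuming Zygote as the AD backend; the toy pde_fns/bc_fns names are made up for illustration and are not the package's code:

    using Zygote

    pde_fns = [θ -> sum(abs2, θ)]   # one toy "PDE" loss
    bc_fns  = Function[]            # pretend no boundary conditions were supplied

    function total_loss(θ)
        # eltype-annotated empty vectors keep an empty loss group summable and inferable
        pde_losses = isempty(pde_fns) ? Float64[] : map(f -> f(θ), pde_fns)
        bc_losses  = isempty(bc_fns)  ? Float64[] : map(f -> f(θ), bc_fns)
        # sum(Float64[]) == 0.0, whereas summing an untyped [] throws
        sum(pde_losses) + sum(bc_losses)
    end

    Zygote.gradient(total_loss, rand(3))   # succeeds even though bc_losses is empty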
@@ -694,21 +693,12 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
                         iteration[1])
                 end
             end
-            ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses,
-                asl_losses, bc_losses) end
 
             return full_weighted_loss
         end
-        weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses
-        weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses
-        weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses
 
         return full_loss_function
     end
-    sum_weighted_pde_losses = sum(weighted_pde_losses)
-    sum_weighted_asl_losses = sum(weighted_asl_losses)
-    sum_weighted_bc_losses = sum(weighted_bc_losses)
-    weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses
 
     function get_likelihood_estimate_function(discretization::BayesianPINN)
         dataset_pde, dataset_bc = discretization.dataset
@@ -796,46 +786,6 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
         end
 
         return full_loss_function
-        ChainRulesCore.@ignore_derivatives begin if iteration[1] % log_frequency == 0
-            logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses",
-                iteration[1])
-            logvector(pinnrep.logger, asl_losses, "unweighted_loss/asl_losses",
-                iteration[1])
-            logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[1])
-            logvector(pinnrep.logger, weighted_pde_losses,
-                "weighted_loss/weighted_pde_losses",
-                iteration[1])
-            logvector(pinnrep.logger, weighted_asl_losses,
-                "weighted_loss/weighted_asl_losses",
-                iteration[1])
-            logvector(pinnrep.logger, weighted_bc_losses,
-                "weighted_loss/weighted_bc_losses",
-                iteration[1])
-            if !(additional_loss isa Nothing)
-                logscalar(pinnrep.logger, weighted_additional_loss_val,
-                    "weighted_loss/weighted_additional_loss", iteration[1])
-            end
-            logscalar(pinnrep.logger, sum_weighted_pde_losses,
-                "weighted_loss/sum_weighted_pde_losses", iteration[1])
-            logscalar(pinnrep.logger, sum_weighted_bc_losses,
-                "weighted_loss/sum_weighted_bc_losses", iteration[1])
-            logscalar(pinnrep.logger, sum_weighted_asl_losses,
-                "weighted_loss/sum_weighted_asl_losses", iteration[1])
-            logscalar(pinnrep.logger, full_weighted_loss,
-                "weighted_loss/full_weighted_loss",
-                iteration[1])
-            logvector(pinnrep.logger, adaloss.pde_loss_weights,
-                "adaptive_loss/pde_loss_weights",
-                iteration[1])
-            logvector(pinnrep.logger, adaloss.asl_loss_weights,
-                "adaptive_loss/asl_loss_weights",
-                iteration[1])
-            logvector(pinnrep.logger, adaloss.bc_loss_weights,
-                "adaptive_loss/bc_loss_weights",
-                iteration[1])
-        end end
-
-        return full_weighted_loss
     end
 
     full_loss_function = get_likelihood_estimate_function(discretization)
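A note on the ChainRulesCore.@ignore_derivatives blocks inside full_loss_function: the iteration counter bump and the reweight_losses_func call are evaluated on the forward pass but hidden from AD, so the adaptive loss weights act as constants when the gradient of the weighted loss is taken. A minimal standalone sketch of that pattern, assuming Zygote; the toy weights vector and the reciprocal update rule here are placeholders, not the package's adaptive-loss scheme:

    using ChainRulesCore, Zygote

    weights = [1.0]   # stands in for adaloss.pde_loss_weights etc.

    function loss(θ)
        raw = sum(abs2, θ)
        # the weight update runs during the forward pass but is invisible to AD,
        # so `weights` is treated as a constant in the gradient of the weighted loss
        ChainRulesCore.@ignore_derivatives weights[1] = 1 / (raw + 1e-8)
        weights[1] * raw
    end

    Zygote.gradient(loss, [2.0, 3.0])   # ≈ (weights[1] * 2 .* [2.0, 3.0],)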