Skip to content

Commit 2e6ce5c

Browse files
Merge pull request #586 from MilkshakeForReal/patch-1
Don't initialize parameters twice
2 parents c8baeec + 24c60d4 commit 2e6ce5c

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/discretize.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,8 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
437437
end
438438
else
439439
x = map(chain) do x
440-
_x = ComponentArrays.ComponentArray(Lux.setup(Random.default_rng(),
441-
x)[1])
440+
_x = ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(),
441+
x))
442442
Float64.(_x) # No ComponentArray GPU support
443443
end
444444
names = ntuple(i -> depvars[i], length(chain))
@@ -451,8 +451,8 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
451451
init_params = init_params isa Array ? Float64.(init_params) :
452452
init_params
453453
else
454-
init_params = Float64.(ComponentArrays.ComponentArray(Lux.setup(Random.default_rng(),
455-
chain)[1]))
454+
init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(),
455+
chain)))
456456
end
457457
end
458458
else

src/pinn_types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ mutable struct Phi{C, S}
346346
f::C
347347
st::S
348348
function Phi(chain::Lux.AbstractExplicitLayer)
349-
ps, st = Lux.setup(Random.default_rng(), chain)
349+
st = Lux.initialstates(Random.default_rng(), chain)
350350
new{typeof(chain), typeof(st)}(chain, st)
351351
end
352352
function Phi(chain::Flux.Chain)

0 commit comments

Comments
 (0)