Skip to content

Commit 2ca1e5b

Browse files
committed
Updated as given feedback
1 parent 55d061b commit 2ca1e5b

File tree

1 file changed

+51
-196
lines changed

1 file changed

+51
-196
lines changed

src/Odelstm.jl

Lines changed: 51 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,6 @@
1-
"""
2-
ODE-LSTM: A Complete Julia Implementation for DiffEqFlux.jl
3-
4-
This implementation converts the original PyTorch ODE-LSTM to Julia, providing:
5-
- Multiple solver support (Tsit5, Euler, Heun, RK4)
6-
- Dataset handling for Person, ET-MNIST, and XOR tasks
7-
- Training and evaluation functionality
8-
- Integration with the SciML ecosystem
9-
10-
Original paper: https://arxiv.org/abs/2006.04418
11-
"""
12-
13-
module ODELSTM
14-
15-
using DifferentialEquations
16-
using Flux
17-
using DiffEqFlux
1+
using Lux
2+
using ComponentArrays
3+
using Optimisers
184
using LinearAlgebra
195
using Random
206
using Statistics
@@ -27,240 +13,144 @@ export PersonData, ETSMnistData, XORData
2713
export train_model!, evaluate_model, load_dataset
2814

2915
mutable struct ODELSTMCell{F,S}
30-
lstm_cell::Flux.LSTMCell
16+
lstm_cell::Lux.LSTMCell
3117
f_node::F
32-
solver_type::Symbol
3318
solver::S
3419
input_size::Int
3520
hidden_size::Int
3621
end
3722

38-
function ODELSTMCell(input_size::Int, hidden_size::Int, solver_type::Symbol=:dopri5)
39-
lstm_cell = Flux.LSTMCell(input_size => hidden_size)
23+
function ODELSTMCell(input_size::Int, hidden_size::Int, solver)
24+
lstm_cell = Lux.LSTMCell(input_size, hidden_size)
4025
f_node = Chain(
41-
Dense(hidden_size => hidden_size, tanh),
42-
Dense(hidden_size => hidden_size)
43-
)
44-
solver = get_solver(solver_type)
45-
return ODELSTMCell(lstm_cell, f_node, solver_type, solver, input_size, hidden_size)
46-
end
47-
48-
function get_solver(solver_type::Symbol)
49-
solver_map = Dict(
50-
:dopri5 => Tsit5(),
51-
:tsit5 => Tsit5(),
52-
:euler => Euler(),
53-
:heun => Heun(),
54-
:rk4 => RK4()
26+
Dense(hidden_size, hidden_size, tanh),
27+
Dense(hidden_size, hidden_size)
5528
)
56-
return get(solver_map, solver_type, Tsit5())
29+
return ODELSTMCell(lstm_cell, f_node, solver, input_size, hidden_size)
5730
end
5831

59-
function (cell::ODELSTMCell)(x, state, ts)
32+
function (cell::ODELSTMCell)(x, state, ts, p, st)
6033
h, c = state
61-
new_h, new_c = cell.lstm_cell(x, (h, c))
34+
new_h, new_c, st = cell.lstm_cell(x, (h, c), p, st)
6235

63-
if cell.solver_type in [:euler, :heun, :rk4]
64-
evolved_h = solve_fixed_step(cell, new_h, ts)
36+
if !(cell.solver isa Union{Tsit5,DP5,BS3})
37+
evolved_h, st = solve_fixed_step(cell, new_h, ts, p, st)
6538
else
66-
evolved_h = solve_adaptive(cell, new_h, ts)
39+
evolved_h, st = solve_adaptive(cell, new_h, ts, p, st)
6740
end
6841

69-
return evolved_h, (evolved_h, new_c)
42+
return evolved_h, (evolved_h, new_c), st
7043
end
7144

72-
function solve_fixed_step(cell::ODELSTMCell, h, ts)
73-
dt = ts / 3.0
74-
h_evolved = h
75-
for i in 1:3
76-
if cell.solver_type == :euler
77-
h_evolved = euler_step(cell.f_node, h_evolved, dt)
78-
elseif cell.solver_type == :heun
79-
h_evolved = heun_step(cell.f_node, h_evolved, dt)
80-
elseif cell.solver_type == :rk4
81-
h_evolved = rk4_step(cell.f_node, h_evolved, dt)
82-
end
83-
end
84-
return h_evolved
45+
function solve_fixed_step(cell::ODELSTMCell, h, ts, p, st)
46+
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h, (0.0f0, Float32(ts)))
47+
sol = solve(prob, cell.solver; adaptive=false)
48+
return sol.u[end], st
8549
end
8650

87-
function solve_adaptive(cell::ODELSTMCell, h, ts)
51+
function solve_adaptive(cell::ODELSTMCell, h, ts, p, st)
8852
if ndims(h) == 2
8953
batch_size = size(h, 2)
9054
results = similar(h)
9155

9256
for i in 1:batch_size
9357
h_i = h[:, i]
9458
ts_i = ts isa AbstractVector ? ts[i] : ts
95-
t_span = (0.0f0, Float32(ts_i) + 1f-6 * i)
96-
97-
function ode_func!(dh, h_state, p, t)
98-
dh .= cell.f_node(h_state)
99-
end
59+
t_span = (0.0f0, Float32(ts_i))
10060

101-
prob = ODEProblem(ode_func!, h_i, t_span)
61+
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h_i, t_span)
10262
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false)
10363
results[:, i] = sol.u[end]
10464
end
105-
return results
65+
return results, st
10666
else
10767
t_span = (0.0f0, Float32(ts))
108-
109-
function ode_func!(dh, h_state, p, t)
110-
dh .= cell.f_node(h_state)
111-
end
112-
113-
prob = ODEProblem(ode_func!, h, t_span)
68+
prob = ODEProblem((u,p,t)->cell.f_node(u,p,st)[1], h, t_span)
11469
sol = solve(prob, cell.solver, saveat=[t_span[2]], dense=false)
115-
return sol.u[end]
70+
return sol.u[end], st
11671
end
11772
end
11873

119-
function euler_step(f, y, dt)
120-
dy = f(y)
121-
return y + dt * dy
122-
end
123-
124-
function heun_step(f, y, dt)
125-
k1 = f(y)
126-
k2 = f(y + dt * k1)
127-
return y + dt * 0.5f0 * (k1 + k2)
128-
end
129-
130-
function rk4_step(f, y, dt)
131-
k1 = f(y)
132-
k2 = f(y + k1 * dt * 0.5f0)
133-
k3 = f(y + k2 * dt * 0.5f0)
134-
k4 = f(y + k3 * dt)
135-
return y + dt * (k1 + 2*k2 + 2*k3 + k4) / 6.0f0
136-
end
137-
13874
struct ODELSTMModel{C,O}
13975
rnn_cell::C
14076
output_layer::O
14177
return_sequences::Bool
14278
end
14379

14480
function ODELSTMModel(in_features::Int, hidden_size::Int, out_features::Int;
145-
return_sequences::Bool=true, solver_type::Symbol=:dopri5)
146-
rnn_cell = ODELSTMCell(in_features, hidden_size, solver_type)
147-
output_layer = Dense(hidden_size => out_features)
81+
return_sequences::Bool=true, solver)
82+
rnn_cell = ODELSTMCell(in_features, hidden_size, solver)
83+
output_layer = Dense(hidden_size, out_features)
14884
return ODELSTMModel(rnn_cell, output_layer, return_sequences)
14985
end
15086

151-
Flux.@functor ODELSTMModel
152-
153-
function (model::ODELSTMModel)(x, timespans, mask=nothing)
154-
batch_size, seq_len, input_size = size(x)
155-
156-
h = zeros(Float32, model.rnn_cell.hidden_size, batch_size)
157-
c = zeros(Float32, model.rnn_cell.hidden_size, batch_size)
158-
159-
outputs = []
160-
last_output = zeros(Float32, size(model.output_layer.weight, 1), batch_size)
161-
162-
for t in 1:seq_len
163-
inputs = x[:, t, :]'
164-
ts = timespans[:, t]
165-
166-
h, (h, c) = model.rnn_cell(inputs, (h, c), ts)
167-
current_output = model.output_layer(h)
168-
push!(outputs, current_output)
169-
170-
if mask !== nothing
171-
cur_mask = mask[:, t]'
172-
last_output = cur_mask .* current_output + (1.0f0 .- cur_mask) .* last_output
173-
else
174-
last_output = current_output
175-
end
176-
end
177-
178-
if model.return_sequences
179-
return cat(outputs..., dims=3)
180-
else
181-
return last_output'
182-
end
183-
end
184-
185-
mutable struct IrregularSequenceLearner{M,O}
87+
mutable struct IrregularSequenceLearner{M,O,S}
18688
model::M
18789
optimizer::O
90+
states::S
18891
train_losses::Vector{Float32}
18992
val_losses::Vector{Float32}
19093
train_accs::Vector{Float32}
19194
val_accs::Vector{Float32}
19295
end
19396

194-
function IrregularSequenceLearner(model, lr=0.005f0)
195-
optimizer = ADAM(lr)
97+
function IrregularSequenceLearner(model, lr=0.005f0; rng=Random.default_rng())
98+
optimizer = Optimisers.Adam(lr)
99+
ps, st = Lux.setup(rng, model)
196100
return IrregularSequenceLearner(
197-
model, optimizer,
101+
model, optimizer, st,
198102
Float32[], Float32[], Float32[], Float32[]
199103
)
200104
end
201105

202-
function train_step!(learner::IrregularSequenceLearner, batch)
106+
function train_step!(learner::IrregularSequenceLearner, batch, p)
203107
if length(batch) == 4
204108
x, t, y, mask = batch
205109
else
206110
x, t, y = batch
207111
mask = nothing
208112
end
209113

210-
params = Flux.params(learner.model)
211-
212-
loss, grads = Flux.withgradient(params) do
213-
y_hat = learner.model(x, t, mask)
214-
if ndims(y_hat) == 3
215-
y_hat = reshape(y_hat, size(y_hat, 1), :)
216-
end
217-
y_flat = reshape(y, :)
218-
Flux.crossentropy(y_hat, Flux.onehotbatch(y_flat, 0:maximum(y_flat)))
114+
(loss, (y_hat, st)), grads = value_and_gradient(p) do p
115+
y_hat, st = learner.model(x, t, mask, p, learner.states)
116+
Lux.crossentropy(y_hat, Flux.onehotbatch(reshape(y, :), 0:maximum(y)))
219117
end
220118

221-
Flux.update!(learner.optimizer, params, grads)
119+
learner.states = st
120+
p = Optimisers.update(learner.optimizer, p, grads)
222121

223-
y_hat = learner.model(x, t, mask)
224-
if ndims(y_hat) == 3
225-
y_hat = reshape(y_hat, size(y_hat, 1), :)
226-
end
227-
y_flat = reshape(y, :)
228122
preds = Flux.onecold(y_hat) .- 1
229-
acc = mean(preds .== y_flat)
123+
acc = mean(preds .== reshape(y, :))
230124

231-
return loss, acc
125+
return loss, acc, p
232126
end
233127

234-
function validation_step(learner::IrregularSequenceLearner, batch)
128+
function validation_step(learner::IrregularSequenceLearner, batch, p)
235129
if length(batch) == 4
236130
x, t, y, mask = batch
237131
else
238132
x, t, y = batch
239133
mask = nothing
240134
end
241135

242-
y_hat = learner.model(x, t, mask)
136+
y_hat, st = learner.model(x, t, mask, p, learner.states)
243137

244-
if ndims(y_hat) == 3
245-
y_hat = reshape(y_hat, size(y_hat, 1), :)
246-
end
247-
y_flat = reshape(y, :)
248-
249-
loss = Flux.crossentropy(y_hat, Flux.onehotbatch(y_flat, 0:maximum(y_flat)))
138+
loss = Lux.crossentropy(y_hat, Flux.onehotbatch(reshape(y, :), 0:maximum(y)))
250139
preds = Flux.onecold(y_hat) .- 1
251-
acc = mean(preds .== y_flat)
140+
acc = mean(preds .== reshape(y, :))
252141

253142
return loss, acc
254143
end
255144

256145
function train_model!(learner::IrregularSequenceLearner, train_loader, val_loader=nothing; epochs=100)
146+
p = ComponentArray(Lux.initialparameters(Random.default_rng(), learner.model))
257147
for epoch in 1:epochs
258148
train_loss_epoch = 0.0f0
259149
train_acc_epoch = 0.0f0
260150
train_batches = 0
261151

262152
for batch in train_loader
263-
loss, acc = train_step!(learner, batch)
153+
loss, acc, p = train_step!(learner, batch, p)
264154
train_loss_epoch += loss
265155
train_acc_epoch += acc
266156
train_batches += 1
@@ -278,7 +168,7 @@ function train_model!(learner::IrregularSequenceLearner, train_loader, val_loade
278168
val_batches = 0
279169

280170
for batch in val_loader
281-
loss, acc = validation_step(learner, batch)
171+
loss, acc = validation_step(learner, batch, p)
282172
val_loss_epoch += loss
283173
val_acc_epoch += acc
284174
val_batches += 1
@@ -303,26 +193,6 @@ function train_model!(learner::IrregularSequenceLearner, train_loader, val_loade
303193
end
304194
end
305195

306-
function evaluate_model(learner::IrregularSequenceLearner, test_loader)
307-
total_loss = 0.0f0
308-
total_acc = 0.0f0
309-
num_batches = 0
310-
311-
for batch in test_loader
312-
loss, acc = validation_step(learner, batch)
313-
total_loss += loss
314-
total_acc += acc
315-
num_batches += 1
316-
end
317-
318-
avg_loss = total_loss / num_batches
319-
avg_acc = total_acc / num_batches
320-
321-
@printf("Test Results: Loss=%.4f, Accuracy=%.4f\n", avg_loss, avg_acc)
322-
323-
return Dict("test_loss" => avg_loss, "val_acc" => avg_acc)
324-
end
325-
326196
struct PersonData
327197
train_x::Array{Float32,3}
328198
train_y::Array{Int32,2}
@@ -335,8 +205,6 @@ struct PersonData
335205
end
336206

337207
function PersonData(; seq_len::Int=32)
338-
@warn "PersonData using synthetic data. Implement actual data loading for production use."
339-
340208
n_train, n_test = 1000, 200
341209
feature_size = 7
342210
num_classes = 7
@@ -612,30 +480,17 @@ function load_dataset(dataset_name::String; kwargs...)
612480
return train_data, test_data, in_features, num_classes, return_sequences
613481
end
614482

615-
function main_training_loop(; dataset="person", solver=:dopri5, size=64, epochs=100, lr=0.01f0)
483+
function main_training_loop(; dataset="person", solver=Tsit5(), size=64, epochs=100, lr=0.01f0)
616484
train_loader, test_loader, in_features, num_classes, return_sequences = load_dataset(dataset)
617485

618-
println("Dataset: $dataset")
619-
println("Input features: $in_features")
620-
println("Number of classes: $num_classes")
621-
println("Return sequences: $return_sequences")
622-
println("Hidden size: $size")
623-
println("Solver: $solver")
624-
625486
model = ODELSTMModel(in_features, size, num_classes;
626-
return_sequences=return_sequences, solver_type=solver)
487+
return_sequences=return_sequences, solver=solver)
627488

628489
learner = IrregularSequenceLearner(model, lr)
629490

630-
println("Starting training...")
631491
train_model!(learner, train_loader; epochs=epochs)
632492

633-
println("Evaluating model...")
634493
results = evaluate_model(learner, test_loader)
635494

636-
println("Final accuracy: $(results["val_acc"])")
637-
638495
return learner, results
639-
end
640-
641496
end

0 commit comments

Comments
 (0)