1- """
2- ODE-LSTM: A Complete Julia Implementation for DiffEqFlux.jl
3-
4- This implementation converts the original PyTorch ODE-LSTM to Julia, providing:
5- - Multiple solver support (Tsit5, Euler, Heun, RK4)
6- - Dataset handling for Person, ET-MNIST, and XOR tasks
7- - Training and evaluation functionality
8- - Integration with the SciML ecosystem
9-
10- Original paper: https://arxiv.org/abs/2006.04418
11- """
12-
13- module ODELSTM
14-
using DifferentialEquations
# Qualified-only imports: Flux and Lux both export `Chain`/`Dense`, so Flux names are
# always written as `Flux.…` in this file.
import Flux
# NOTE(review): not referenced by name; kept from the previous import list on the
# assumption that it pulls in the sensitivity rules needed to differentiate `solve` —
# confirm before dropping.
import DiffEqFlux
using Lux
using ComponentArrays
using Optimisers
using LinearAlgebra
using Printf
using Random
using Statistics
@@ -27,240 +13,144 @@ export PersonData, ETSMnistData, XORData
2713export train_model!, evaluate_model, load_dataset
2814
"""
    ODELSTMCell(input_size, hidden_size, solver=Tsit5())
    ODELSTMCell(input_size, hidden_size, solver_type::Symbol)

Recurrent cell that applies a `Lux.LSTMCell` to each input and then evolves the resulting
hidden state through the network `f_node`, used as the right-hand side of an ODE that is
integrated over the elapsed time `ts`.

The type is declared as a Lux container layer over `lstm_cell` and `f_node`, so the
parameter and state trees produced for it have sub-trees with exactly those two names.

# Fields
- `lstm_cell`: the `Lux.LSTMCell` built as `input_size => hidden_size`.
- `f_node`: two `Lux.Dense` layers (`tanh`, then linear) mapping `hidden_size` to itself.
- `solver`: ODE algorithm instance passed on to `solve`.
- `input_size`, `hidden_size`: the sizes given to the constructor.
"""
mutable struct ODELSTMCell{F,S,L<:Lux.LSTMCell} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :f_node)}
    # NOTE(review): `AbstractLuxContainerLayer` is the LuxCore 1.x name; older releases
    # call it `AbstractExplicitContainerLayer` — confirm against the pinned Lux version.
    lstm_cell::L
    f_node::F
    solver::S
    input_size::Int
    hidden_size::Int
end

"""
    get_solver(solver_type::Symbol)

Translate a solver name (`:dopri5`, `:tsit5`, `:euler`, `:heun`, `:rk4`) into an algorithm
instance. `:dopri5` maps to `Tsit5()`, and any unrecognised name also falls back to
`Tsit5()`.
"""
function get_solver(solver_type::Symbol)
    solver_map = Dict(
        :dopri5 => Tsit5(),
        :tsit5 => Tsit5(),
        :euler => Euler(),
        :heun => Heun(),
        :rk4 => RK4(),
    )
    return get(solver_map, solver_type, Tsit5())
end

function ODELSTMCell(input_size::Int, hidden_size::Int, solver=Tsit5())
    # Lux's recurrent cells are constructed from an `in => out` pair.
    lstm_cell = Lux.LSTMCell(input_size => hidden_size)
    f_node = Lux.Chain(
        Lux.Dense(hidden_size => hidden_size, tanh),
        Lux.Dense(hidden_size => hidden_size),
    )
    return ODELSTMCell(lstm_cell, f_node, solver, input_size, hidden_size)
end

# Backward-compatible entry point for callers that still pass a solver name.
function ODELSTMCell(input_size::Int, hidden_size::Int, solver_type::Symbol)
    return ODELSTMCell(input_size, hidden_size, get_solver(solver_type))
end

"""
    (cell::ODELSTMCell)(x, state, ts, p, st)

Run one recurrent step and return `(evolved_h, (evolved_h, new_c), st)`.

# Arguments
- `x`: input for this step, passed to the LSTM cell.
- `state`: tuple `(h, c)` of hidden and memory state.
- `ts`: elapsed time forwarded to the ODE integration helper.
- `p`, `st`: parameters and states of the cell; both are read through their `lstm_cell`
  and `f_node` properties, so they must carry sub-trees of those names.

`Tsit5`, `DP5` and `BS3` solvers go through `solve_adaptive`; every other solver goes
through `solve_fixed_step`.
"""
function (cell::ODELSTMCell)(x, state, ts, p, st)
    h, c = state
    # Lux recurrent cells take `(input, carry)` as one tuple and return
    # `((output, carry), state)`.
    (_, (new_h, new_c)), st_lstm = cell.lstm_cell((x, (h, c)), p.lstm_cell, st.lstm_cell)

    evolve = cell.solver isa Union{Tsit5,DP5,BS3} ? solve_adaptive : solve_fixed_step
    evolved_h, st_node = evolve(cell, new_h, ts, p.f_node, st.f_node)

    return evolved_h, (evolved_h, new_c), (lstm_cell=st_lstm, f_node=st_node)
end
7144
"""
    solve_fixed_step(cell::ODELSTMCell, h, ts, p, st; substeps=3)

Evolve the hidden state `h` through `cell.f_node` for the elapsed time `ts` using
`cell.solver` with a fixed step size, and return `(h_evolved, st)`.

Time is rescaled to the unit interval: the right-hand side is multiplied by `ts`, and the
problem is solved on `[0, 1]` with `dt = 1 / substeps`. Each sample therefore takes
`substeps` steps of length `ts / substeps` of its own elapsed time, and a whole batch with
different elapsed times is handled by a single solve.

# Arguments
- `h`: hidden state; a vector, or a matrix with one column per sample.
- `ts`: elapsed time; a scalar, or a vector with one entry per column of `h`.
- `p`, `st`: parameters and state forwarded to `cell.f_node(u, p, st)`. `st` is returned
  unchanged.

# Keywords
- `substeps`: number of equal solver steps (default `3`).

# Throws
- `ArgumentError` if `substeps` is not positive.
- `DimensionMismatch` if `ts` is a vector that does not match the columns of `h`.
"""
function solve_fixed_step(cell::ODELSTMCell, h, ts, p, st; substeps::Integer=3)
    substeps > 0 || throw(ArgumentError("substeps must be positive, got $substeps"))

    scale = if ts isa AbstractVector
        if ndims(h) != 2 || size(h, 2) != length(ts)
            throw(DimensionMismatch(
                "ts has $(length(ts)) entries but h has size $(size(h))"))
        end
        reshape(Float32.(ts), 1, :)  # row vector: one elapsed time per column of h
    else
        Float32(ts)
    end

    # The parameters are handed to the problem and received back as `θ`; the closure
    # argument must not shadow the outer `p`, otherwise `f_node` is called with the
    # problem's empty default parameters.
    rhs(u, θ, s) = first(cell.f_node(u, θ, st)) .* scale

    prob = ODEProblem(rhs, h, (0.0f0, 1.0f0), p)
    # Fixed-step algorithms need an explicit `dt`.
    sol = solve(prob, cell.solver; dt=1.0f0 / substeps, adaptive=false,
                save_everystep=false)
    return sol.u[end], st
end
8650
"""
    solve_adaptive(cell::ODELSTMCell, h, ts, p, st)

Evolve the hidden state `h` through `cell.f_node` for the elapsed time `ts` using
`cell.solver`, and return `(h_evolved, st)`.

When `h` is a matrix, each column is integrated on its own from `0` to its elapsed time
(`ts[i]` if `ts` is a vector, otherwise `ts`) and the results are concatenated
column-wise. Any other `h` is integrated in one solve up to `ts`.

# Arguments
- `h`: hidden state; a vector, or a matrix with one column per sample.
- `ts`: elapsed time; a scalar, or a vector indexed by column.
- `p`, `st`: parameters and state forwarded to `cell.f_node(u, p, st)`. `st` is returned
  unchanged.
"""
function solve_adaptive(cell::ODELSTMCell, h, ts, p, st)
    # The parameters are handed to the problem and received back as `θ`; the closure
    # argument must not shadow the outer `p`.
    rhs(u, θ, t) = first(cell.f_node(u, θ, st))

    function integrate(h0, elapsed)
        t_end = Float32(elapsed)
        # Integrating over zero time is the identity; skip the zero-length time span.
        if iszero(t_end)
            return copy(h0)
        end
        prob = ODEProblem(rhs, h0, (0.0f0, t_end), p)
        sol = solve(prob, cell.solver; saveat=[t_end], dense=false)
        return sol.u[end]
    end

    if ndims(h) != 2
        return integrate(h, ts), st
    end
    if size(h, 2) == 0
        return copy(h), st
    end

    # Build the result without writing into a preallocated array, so the function stays
    # free of in-place mutation.
    columns = map(axes(h, 2)) do i
        integrate(h[:, i], ts isa AbstractVector ? ts[i] : ts)
    end
    return reduce(hcat, columns), st
end
11873
"""
    euler_step(f, y, dt)

Take one explicit Euler step of size `dt` from `y`, with `f(y)` as the derivative.
"""
euler_step(f, y, dt) = y + dt * f(y)
123-
"""
    heun_step(f, y, dt)

Take one Heun step of size `dt` from `y`: average the slope at `y` with the slope at the
Euler-predicted point.
"""
function heun_step(f, y, dt)
    slope_start = f(y)
    slope_end = f(y + dt * slope_start)
    return y + dt * 0.5f0 * (slope_start + slope_end)
end
129-
"""
    rk4_step(f, y, dt)

Take one classical fourth-order Runge-Kutta step of size `dt` from `y`, with `f(y)` as
the derivative.
"""
function rk4_step(f, y, dt)
    s1 = f(y)
    s2 = f(y + s1 * dt * 0.5f0)
    s3 = f(y + s2 * dt * 0.5f0)
    s4 = f(y + s3 * dt)
    weighted = s1 + 2 * s2 + 2 * s3 + s4
    return y + dt * weighted / 6.0f0
end
137-
"""
    ODELSTMModel(in_features, hidden_size, out_features;
                 return_sequences=true, solver=Tsit5())

Sequence model made of an `ODELSTMCell` followed by a `Lux.Dense` output layer.

The type is declared as a Lux container layer over `rnn_cell` and `output_layer`, so the
parameter and state trees produced for it have sub-trees with exactly those two names.

# Fields
- `rnn_cell`: the recurrent cell applied at every time step.
- `output_layer`: layer mapping the hidden state to `out_features` values.
- `return_sequences`: whether the forward pass returns every step or only the last one.
"""
struct ODELSTMModel{C,O} <: Lux.AbstractLuxContainerLayer{(:rnn_cell, :output_layer)}
    # NOTE(review): `AbstractLuxContainerLayer` is the LuxCore 1.x name; older releases
    # call it `AbstractExplicitContainerLayer` — confirm against the pinned Lux version.
    rnn_cell::C
    output_layer::O
    return_sequences::Bool
end

function ODELSTMModel(in_features::Int, hidden_size::Int, out_features::Int;
                      return_sequences::Bool=true, solver=Tsit5())
    rnn_cell = ODELSTMCell(in_features, hidden_size, solver)
    output_layer = Lux.Dense(hidden_size => out_features)
    return ODELSTMModel(rnn_cell, output_layer, return_sequences)
end

"""
    (model::ODELSTMModel)(x, timespans, mask, p, st)

Run the model over a whole sequence and return `(y, st)`.

# Arguments
- `x`: 3-D array read as `(batch, time, features)`.
- `timespans`: indexed as `timespans[:, t]` to get the elapsed times for step `t`.
- `mask`: `nothing`, or an array indexed as `mask[:, t]`. Where the mask is `1` the
  step's output replaces the running "last output"; where it is `0` the previous value is
  kept (zeros before the first unmasked step).
- `p`, `st`: parameters and states; both are read through their `rnn_cell` and
  `output_layer` properties.

# Returns
- `y`: with `return_sequences` a `(out_features, batch, time)` array of all step outputs,
  otherwise the `(out_features, batch)` last output. The feature axis comes first so the
  result can be compared directly with one-hot targets laid out as `(classes, samples)`.
- `st`: updated states with the same two sub-trees.

# Throws
- `ArgumentError` if `x` has no time steps.
"""
function (model::ODELSTMModel)(x, timespans, mask, p, st)
    batch_size, seq_len, _ = size(x)
    seq_len > 0 || throw(ArgumentError("x must contain at least one time step"))

    hidden_size = model.rnn_cell.hidden_size
    h = zeros(Float32, hidden_size, batch_size)
    c = zeros(Float32, hidden_size, batch_size)

    st_cell = st.rnn_cell
    st_out = st.output_layer

    # Outputs are gathered in a tuple rather than pushed into an array, keeping the
    # forward pass free of in-place mutation.
    outputs = ()
    last_output = nothing

    for t in axes(x, 2)
        inputs = permutedims(x[:, t, :])  # (features, batch)
        _, (h, c), st_cell = model.rnn_cell(inputs, (h, c), timespans[:, t],
                                            p.rnn_cell, st_cell)
        current, st_out = model.output_layer(h, p.output_layer, st_out)
        outputs = (outputs..., current)

        if mask === nothing
            last_output = current
        else
            keep = reshape(mask[:, t], 1, :)
            previous = last_output === nothing ? zero(current) : last_output
            last_output = keep .* current .+ (1.0f0 .- keep) .* previous
        end
    end

    new_st = (rnn_cell=st_cell, output_layer=st_out)
    if model.return_sequences
        return cat(outputs...; dims=3), new_st
    else
        return last_output, new_st
    end
end
184-
"""
    IrregularSequenceLearner(model, lr=0.005f0; rng=Random.default_rng())

Training harness that bundles a model with its optimiser and its training history.

# Fields
- `model`: the model being trained.
- `optimizer`: the `Optimisers.Adam` rule built from `lr`.
- `states`: model states from `Lux.setup`; replaced after every training step.
- `train_losses`, `val_losses`, `train_accs`, `val_accs`: history vectors, created empty.
- `parameters`: the most recent parameters, stored as a `ComponentArray`. Initialised from
  `Lux.setup` and overwritten by every `train_step!`.
- `opt_state`: optimiser state created by `Optimisers.setup` for `parameters`.
"""
mutable struct IrregularSequenceLearner{M,O,S,P,T}
    model::M
    optimizer::O
    states::S
    train_losses::Vector{Float32}
    val_losses::Vector{Float32}
    train_accs::Vector{Float32}
    val_accs::Vector{Float32}
    parameters::P
    opt_state::T
end

function IrregularSequenceLearner(model, lr=0.005f0; rng=Random.default_rng())
    optimizer = Optimisers.Adam(lr)
    ps, st = Lux.setup(rng, model)
    parameters = ComponentArray(ps)
    # `Optimisers.update` works on a state tree built by `setup`, not on the bare rule.
    opt_state = Optimisers.setup(optimizer, parameters)
    return IrregularSequenceLearner(
        model, optimizer, st,
        Float32[], Float32[], Float32[], Float32[],
        parameters, opt_state,
    )
end

"""
    train_step!(learner::IrregularSequenceLearner, batch, p)

Perform one optimisation step on `batch`, starting from parameters `p`, and return
`(loss, acc, p_new)`.

`batch` is `(x, t, y)` or `(x, t, y, mask)`. The model output is treated as unnormalised
class scores with classes along the first axis; labels in `y` are zero-based integers.
A 3-D sequence output is flattened to `(classes, samples)` before the loss is computed.

Side effects: `learner.states`, `learner.opt_state` and `learner.parameters` are replaced
with their updated values.

NOTE(review): `p` must have the same structure as the `ComponentArray` the learner was
constructed with, because the stored optimiser state was set up for that structure.
"""
function train_step!(learner::IrregularSequenceLearner, batch, p)
    if length(batch) == 4
        x, t, y, mask = batch
    else
        x, t, y = batch
        mask = nothing
    end
    labels = reshape(y, :)

    # Names inside the closure are kept distinct from the outer ones so that no captured
    # variable is reassigned. When the closure returns a tuple, `withgradient`
    # differentiates its first element and passes the rest through.
    (loss, scores, new_states), grads = Flux.withgradient(p) do θ
        ŷ, st_out = learner.model(x, t, mask, θ, learner.states)
        logits = ndims(ŷ) == 3 ? reshape(ŷ, size(ŷ, 1), :) : ŷ
        # The class range comes from the model output, not from the labels that happen
        # to be present in this batch.
        targets = Flux.onehotbatch(labels, 0:(size(logits, 1) - 1))
        return Flux.logitcrossentropy(logits, targets), logits, st_out
    end

    learner.states = new_states
    learner.opt_state, p = Optimisers.update(learner.opt_state, p, grads[1])
    learner.parameters = p

    preds = Flux.onecold(scores) .- 1
    acc = mean(preds .== labels)

    return loss, acc, p
end
233127
"""
    validation_step(learner::IrregularSequenceLearner, batch, p)

Evaluate the model with parameters `p` on one batch and return `(loss, acc)`. Nothing on
`learner` is modified.

`batch` is `(x, t, y)` or `(x, t, y, mask)`. The model output is treated as unnormalised
class scores with classes along the first axis; labels in `y` are zero-based integers.
A 3-D sequence output is flattened to `(classes, samples)` before the loss is computed.
"""
function validation_step(learner::IrregularSequenceLearner, batch, p)
    if length(batch) == 4
        x, t, y, mask = batch
    else
        x, t, y = batch
        mask = nothing
    end

    y_hat, _ = learner.model(x, t, mask, p, learner.states)

    logits = ndims(y_hat) == 3 ? reshape(y_hat, size(y_hat, 1), :) : y_hat
    labels = reshape(y, :)

    # The class range comes from the model output, not from the labels that happen to be
    # present in this batch.
    targets = Flux.onehotbatch(labels, 0:(size(logits, 1) - 1))
    loss = Flux.logitcrossentropy(logits, targets)
    preds = Flux.onecold(logits) .- 1
    acc = mean(preds .== labels)

    return loss, acc
end
255144
256145function train_model! (learner:: IrregularSequenceLearner , train_loader, val_loader= nothing ; epochs= 100 )
146+ p = ComponentArray (Lux. initialparameters (Random. default_rng (), learner. model))
257147 for epoch in 1 : epochs
258148 train_loss_epoch = 0.0f0
259149 train_acc_epoch = 0.0f0
260150 train_batches = 0
261151
262152 for batch in train_loader
263- loss, acc = train_step! (learner, batch)
153+ loss, acc, p = train_step! (learner, batch, p )
264154 train_loss_epoch += loss
265155 train_acc_epoch += acc
266156 train_batches += 1
@@ -278,7 +168,7 @@ function train_model!(learner::IrregularSequenceLearner, train_loader, val_loade
278168 val_batches = 0
279169
280170 for batch in val_loader
281- loss, acc = validation_step (learner, batch)
171+ loss, acc = validation_step (learner, batch, p )
282172 val_loss_epoch += loss
283173 val_acc_epoch += acc
284174 val_batches += 1
@@ -303,26 +193,6 @@ function train_model!(learner::IrregularSequenceLearner, train_loader, val_loade
303193 end
304194end
305195
"""
    evaluate_model(learner::IrregularSequenceLearner, test_loader, p=learner.parameters)

Average `validation_step` over every batch of `test_loader`, print the result, and return
`Dict("test_loss" => avg_loss, "val_acc" => avg_acc)`.

# Arguments
- `test_loader`: iterable of batches accepted by `validation_step`.
- `p`: parameters to evaluate. The default reads `learner.parameters`; pass `p`
  explicitly when the learner does not carry that field.

An empty `test_loader` yields `NaN` averages, because the totals are divided by a batch
count of zero.
"""
function evaluate_model(learner::IrregularSequenceLearner, test_loader,
                        p=learner.parameters)
    total_loss = 0.0f0
    total_acc = 0.0f0
    num_batches = 0

    for batch in test_loader
        loss, acc = validation_step(learner, batch, p)
        # Convert explicitly so the accumulators keep their Float32 type.
        total_loss += Float32(loss)
        total_acc += Float32(acc)
        num_batches += 1
    end

    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches

    @printf("Test Results: Loss=%.4f, Accuracy=%.4f\n", avg_loss, avg_acc)

    return Dict("test_loss" => avg_loss, "val_acc" => avg_acc)
end
325-
326196struct PersonData
327197 train_x:: Array{Float32,3}
328198 train_y:: Array{Int32,2}
@@ -335,8 +205,6 @@ struct PersonData
335205end
336206
337207function PersonData (; seq_len:: Int = 32 )
338- @warn " PersonData using synthetic data. Implement actual data loading for production use."
339-
340208 n_train, n_test = 1000 , 200
341209 feature_size = 7
342210 num_classes = 7
@@ -612,30 +480,17 @@ function load_dataset(dataset_name::String; kwargs...)
612480 return train_data, test_data, in_features, num_classes, return_sequences
613481end
614482
"""
    main_training_loop(; dataset="person", solver=Tsit5(), size=64, epochs=100, lr=0.01f0)

Load `dataset`, build an `ODELSTMModel` with hidden size `size` and the given `solver`,
train it for `epochs` epochs with learning rate `lr`, evaluate it on the test loader, and
return `(learner, results)`.
"""
function main_training_loop(; dataset="person", solver=Tsit5(), size=64, epochs=100,
                            lr=0.01f0)
    train_loader, test_loader, in_features, num_classes, return_sequences =
        load_dataset(dataset)

    learner = IrregularSequenceLearner(
        ODELSTMModel(in_features, size, num_classes;
                     return_sequences=return_sequences, solver=solver),
        lr,
    )

    train_model!(learner, train_loader; epochs=epochs)

    return learner, evaluate_model(learner, test_loader)
end
0 commit comments