@@ -184,4 +184,58 @@ defmodule Axon.IntegrationTest do
184184 assert_equal ( step_state1 , step_state2 )
185185 end )
186186 end
187+
188+ test "dropout with certain optimizers regression test" do
189+ { train , _test } = get_test_data ( 100 , 0 , 10 , { 10 } , 2 , 1337 )
190+
191+ train =
192+ train
193+ |> Stream . map ( fn { xs , ys } ->
194+ { xs , one_hot ( ys , num_classes: 2 ) }
195+ end )
196+ |> Enum . to_list ( )
197+
198+ [ { x_test , _ } ] = Enum . take ( train , 1 )
199+
200+ model =
201+ Axon . input ( "input" )
202+ |> Axon . dense ( 16 )
203+ |> Axon . dropout ( rate: 0.1 )
204+ |> Axon . dense ( 2 , activation: :softmax )
205+
206+ optimizers = [
207+ Axon.Optimizers . rmsprop ( 5.0e-3 , centered: true ) ,
208+ Axon.Optimizers . rmsprop ( 5.0e-3 , centered: false ) ,
209+ :adagrad ,
210+ :yogi
211+ ]
212+
213+ ExUnit.CaptureIO . capture_io ( fn ->
214+ for optim <- optimizers do
215+ results =
216+ model
217+ |> Axon.Loop . trainer ( :categorical_cross_entropy , optim )
218+ # TODO: Fix default output transform
219+ |> Map . update ( :output_transform , nil , fn _ -> & & 1 end )
220+ |> Axon.Loop . metric ( :accuracy )
221+ |> Axon.Loop . validate ( model , train )
222+ |> Axon.Loop . run ( train , % { } , epochs: 10 )
223+
224+ assert % { step_state: % { model_state: model_state } , metrics: % { 9 => last_epoch_metrics } } =
225+ results
226+
227+ eval_results =
228+ model
229+ |> Axon.Loop . evaluator ( )
230+ |> Axon.Loop . metric ( :accuracy )
231+ |> Axon.Loop . run ( train , model_state )
232+
233+ assert % { 0 => % { "accuracy" => final_model_val_accuracy } } = eval_results
234+
235+ assert_greater_equal ( last_epoch_metrics [ "validation_accuracy" ] , 0.7 )
236+ assert_all_close ( final_model_val_accuracy , last_epoch_metrics [ "validation_accuracy" ] )
237+ assert Nx . shape ( Axon . predict ( model , model_state , x_test ) ) == { 10 , 2 }
238+ end
239+ end )
240+ end
187241end
0 commit comments