Skip to content

Commit 27b9b53

Browse files
authored
Fix mismatch for certain optimizers with dropout (#458)
1 parent 76d5a45 commit 27b9b53

File tree

4 files changed: 63 additions and 7 deletions.

lib/axon/shared.ex

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -135,11 +135,13 @@ defmodule Axon.Shared do
135135
@doc """
136136
Creates a fulls-like tuple of inputs.
137137
"""
138-
deftransform fulls_like(params, value) do
# The replacement head adds an optional `opts` keyword list. It defaults to
# [], so existing two-argument calls keep working.
138+
deftransform fulls_like(params, value, opts \\ []) do
# Keyword.validate!/2 rejects any option key other than :type.
139+
opts = Keyword.validate!(opts, [:type])
139140
fun = Axon.Initializers.full(value)
140141

141142
deep_new(params, fn x ->
# Removed line: the type passed to `fun` was always the type of the input
# tensor `x`.
142-
fun.(Nx.shape(x), Nx.type(x))
# Added lines: an explicit :type option wins; when it is absent (or nil),
# the type of `x` is used as before. The shape always comes from `x`.
143+
type = opts[:type] || Nx.type(x)
144+
fun.(Nx.shape(x), type)
143145
end)
144146
end
145147

lib/axon/updates.ex

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,7 @@ defmodule Axon.Updates do
219219
end
220220

221221
defnp init_scale_by_rss(params, value) do
# The initial sum-of-squares state is now requested as :f32 explicitly
# instead of relying on fulls_like's default type.
222-
sum_of_squares = fulls_like(params, value)
222+
sum_of_squares = fulls_like(params, value, type: :f32)
223223
%{sum_of_squares: sum_of_squares}
224224
end
225225

@@ -278,7 +278,7 @@ defmodule Axon.Updates do
278278
end
279279

280280
defnp init_scale_by_rms(params, scale) do
# The initial `nu` state is now requested as :f32 explicitly instead of
# relying on fulls_like's default type.
281-
nu = fulls_like(params, scale)
281+
nu = fulls_like(params, scale, type: :f32)
282282
%{nu: nu}
283283
end
284284

@@ -395,7 +395,7 @@ defmodule Axon.Updates do
395395

396396
defnp init_scale_by_stddev(params, value) do
397397
mu = zeros_like(params, type: :f32)
# `nu` is now requested as :f32 as well, the same type already requested
# for `mu` on the line above.
398-
nu = fulls_like(params, value)
398+
nu = fulls_like(params, value, type: :f32)
399399
%{mu: mu, nu: nu}
400400
end
401401

@@ -860,7 +860,7 @@ defmodule Axon.Updates do
860860
end
861861

862862
defnp init_scale_by_yogi(params, value) do
863-
value = fulls_like(params, value)
863+
value = fulls_like(params, value, type: :f32)
864864
mu = value
865865
nu = value
866866
count = Nx.tensor(0)

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ defmodule Axon.MixProject do
22
use Mix.Project
33

44
# Canonical upstream repository; the host must be github.com, not a mirror/proxy domain.
@source_url "https://github.com/elixir-nx/axon"
5-
@version "0.4.0"
5+
@version "0.4.1"
66

77
def project do
88
[

test/axon/integration_test.exs

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,4 +184,58 @@ defmodule Axon.IntegrationTest do
184184
assert_equal(step_state1, step_state2)
185185
end)
186186
end
187+
188+
test "dropout with certain optimizers regression test" do
# For each optimizer listed below, train a small model containing a dropout
# layer and check that the validation accuracy recorded in the last training
# epoch matches the accuracy obtained by re-evaluating the returned model
# state on the same data.
189+
{train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337)
190+
# One-hot encode the labels into 2 classes and materialize the stream into a
# list so it can be reused for training, validation and evaluation.
191+
train =
192+
train
193+
|> Stream.map(fn {xs, ys} ->
194+
{xs, one_hot(ys, num_classes: 2)}
195+
end)
196+
|> Enum.to_list()
197+
# Inputs of the first batch, kept for the prediction shape check at the end.
198+
[{x_test, _}] = Enum.take(train, 1)
199+
200+
model =
201+
Axon.input("input")
202+
|> Axon.dense(16)
203+
|> Axon.dropout(rate: 0.1)
204+
|> Axon.dense(2, activation: :softmax)
205+
206+
optimizers = [
207+
Axon.Optimizers.rmsprop(5.0e-3, centered: true),
208+
Axon.Optimizers.rmsprop(5.0e-3, centered: false),
209+
:adagrad,
210+
:yogi
211+
]
212+
# Run inside capture_io so anything the loops write to standard output does
# not appear in the test output.
213+
ExUnit.CaptureIO.capture_io(fn ->
214+
for optim <- optimizers do
215+
results =
216+
model
217+
|> Axon.Loop.trainer(:categorical_cross_entropy, optim)
218+
# TODO: Fix default output transform
# Swap the loop's :output_transform for an identity function; the result is
# then matched below as a map with :step_state and :metrics keys.
219+
|> Map.update(:output_transform, nil, fn _ -> & &1 end)
220+
|> Axon.Loop.metric(:accuracy)
# Validation runs over the same `train` batches used for training.
221+
|> Axon.Loop.validate(model, train)
222+
|> Axon.Loop.run(train, %{}, epochs: 10)
223+
# With epochs: 10, the metrics of the last epoch are read from key 9.
224+
assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} =
225+
results
226+
# Re-evaluate the trained model state with a fresh evaluator loop on the
# same data.
227+
eval_results =
228+
model
229+
|> Axon.Loop.evaluator()
230+
|> Axon.Loop.metric(:accuracy)
231+
|> Axon.Loop.run(train, model_state)
232+
# The evaluator's metrics are read from key 0.
233+
assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results
234+
235+
assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7)
# Core regression check: accuracy reported during training must agree with
# the accuracy of the final model state.
236+
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
237+
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
238+
end
239+
end)
240+
end
187241
end

0 commit comments

Comments
 (0)