Commit 3a81f8f

Fix deprecation warnings with new Nx (#454)
1 parent 36f160e commit 3a81f8f

File tree: 6 files changed (+98, -140 lines)


lib/axon/activations.ex

Lines changed: 2 additions & 1 deletion
@@ -525,7 +525,8 @@ defmodule Axon.Activations do
   defn relu(x) do
     custom_grad(
       Nx.max(x, 0),
-      fn _ans, g -> [{x, Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))}] end
+      [x],
+      fn g -> [{x, Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))}] end
     )
   end

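
For reference, the call-pattern change applied in this hunk, as a minimal standalone sketch: newer Nx takes the differentiated inputs as an explicit list, and the callback receives only the upstream gradient g rather than _ans and g. The module and function names below are illustrative, not part of this commit.

defmodule CustomGradSketch do
  import Nx.Defn

  # Same rule as the relu/1 change above: list the inputs to differentiate,
  # and return one {input, gradient} pair per listed input.
  defn my_relu(x) do
    custom_grad(
      Nx.max(x, 0),
      [x],
      fn g -> [{x, Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))}] end
    )
  end
end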

lib/axon/compiler.ex

Lines changed: 1 addition & 1 deletion
@@ -889,7 +889,7 @@ defmodule Axon.Compiler do

     if event? and mode? do
       if on_event == :backward do
-        Nx.Defn.Kernel.custom_grad(expr, fn _ans, g ->
+        Nx.Defn.Kernel.custom_grad(expr, [expr], fn g ->
           hooked_g = Nx.Defn.Kernel.hook(g, hook_fn)
           [{expr, hooked_g}]
         end)
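
The same custom_grad migration, here used to attach a backward hook to an expression. A hedged sketch of the general pattern outside the compiler (BackwardHookSketch, watch_grad/1, and the IO.inspect hook are illustrative, not Axon's API):

defmodule BackwardHookSketch do
  import Nx.Defn

  # Mirrors the compiler change above: the expression is listed as its own
  # custom_grad input, and the upstream gradient is passed through hook/2
  # before being handed back to the differentiator.
  defn watch_grad(x) do
    expr = Nx.sin(x)

    custom_grad(expr, [expr], fn g ->
      [{expr, hook(g, fn t -> IO.inspect(t, label: "backward") end)}]
    end)
  end
end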

lib/axon/initializers.ex

Lines changed: 86 additions & 118 deletions
@@ -174,20 +174,16 @@ defmodule Axon.Initializers do

   """
   def uniform(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 1.0e-2)
+    scale = Keyword.fetch!(opts, :scale)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 1.0e-2
-      uniform_impl(key, shape: shape, type: type, scale: scale)
+      uniform_impl(key, scale, shape: shape, type: type)
     end
   end

-  defnp uniform_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0e-2])
-    shape = Nx.shape(opts[:shape])
-
-    Nx.Random.uniform_split(key, Nx.negate(opts[:scale]), opts[:scale],
-      type: opts[:type],
-      shape: shape
-    )
+  defnp uniform_impl(key, scale, opts) do
+    Nx.Random.uniform_split(key, Nx.negate(scale), scale, opts)
   end

   @doc """
@@ -216,18 +212,15 @@ defmodule Axon.Initializers do

   """
   def normal(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 1.0e-2, mean: 0.0)
+    scale = Keyword.fetch!(opts, :scale)
+    mean = Keyword.fetch!(opts, :mean)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 1.0e-2
-      mean = opts[:mean] || 0.0
-      normal_impl(key, shape: shape, type: type, scale: scale, mean: mean)
+      Nx.Random.normal_split(key, mean, scale, type: type, shape: shape)
     end
   end

-  defnp normal_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0e-2, mean: 0.0])
-    Nx.Random.normal_split(key, opts[:mean], opts[:scale], shape: opts[:shape], type: opts[:type])
-  end
-
   @doc """
   Initializes parameters with the Lecun uniform initializer.

@@ -261,25 +254,21 @@ defmodule Axon.Initializers do

   """
   def lecun_uniform(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 1.0)
+    scale = Keyword.fetch!(opts, :scale)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 1.0
-      lecun_uniform_impl(key, shape: shape, type: type, scale: scale)
+      variance_scaling_impl(
+        key,
+        scale,
+        shape: shape,
+        type: type,
+        mode: :fan_in,
+        distribution: :uniform
+      )
     end
   end

-  defnp lecun_uniform_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0])
-
-    variance_scaling_impl(
-      key,
-      shape: opts[:shape],
-      type: opts[:type],
-      scale: opts[:scale],
-      mode: :fan_in,
-      distribution: :uniform
-    )
-  end
-
   @doc """
   Initializes parameters with the Lecun normal initializer.

@@ -313,25 +302,21 @@ defmodule Axon.Initializers do

   """
   def lecun_normal(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 1.0)
+    scale = Keyword.fetch!(opts, :scale)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 1.0
-      lecun_normal_impl(key, shape: shape, type: type, scale: scale)
+      variance_scaling_impl(
+        key,
+        scale,
+        shape: shape,
+        type: type,
+        mode: :fan_in,
+        distribution: :truncated_normal
+      )
     end
   end

-  defnp lecun_normal_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0])
-
-    variance_scaling_impl(
-      key,
-      shape: opts[:shape],
-      type: opts[:type],
-      scale: opts[:scale],
-      mode: :fan_in,
-      distribution: :truncated_normal
-    )
-  end
-
   @doc """
   Initializes parameters with the Glorot uniform initializer.

@@ -368,25 +353,21 @@ defmodule Axon.Initializers do

   """
   def glorot_uniform(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 1.0)
+    scale = Keyword.fetch!(opts, :scale)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 1.0
-      glorot_uniform_impl(key, shape: shape, type: type, scale: scale)
+      variance_scaling_impl(
+        key,
+        scale,
+        shape: shape,
+        type: type,
+        mode: :fan_avg,
+        distribution: :uniform
+      )
     end
   end

-  defnp glorot_uniform_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0])
-
-    variance_scaling_impl(
-      key,
-      shape: opts[:shape],
-      type: opts[:type],
-      scale: opts[:scale],
-      mode: :fan_avg,
-      distribution: :uniform
-    )
-  end
-
   @doc """
   Initializes parameters with the Glorot normal initializer.

@@ -423,25 +404,21 @@ defmodule Axon.Initializers do

   """
   def glorot_normal(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 1.0)
+    scale = Keyword.fetch!(opts, :scale)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 1.0
-      glorot_normal_impl(key, shape: shape, type: type, scale: scale)
+      variance_scaling_impl(
+        key,
+        scale,
+        shape: shape,
+        type: type,
+        mode: :fan_avg,
+        distribution: :truncated_normal
+      )
     end
   end

-  defnp glorot_normal_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0])
-
-    variance_scaling_impl(
-      key,
-      shape: opts[:shape],
-      type: opts[:type],
-      scale: opts[:scale],
-      mode: :fan_avg,
-      distribution: :truncated_normal
-    )
-  end
-
   @doc """
   Initializes parameters with the He uniform initializer.

@@ -475,25 +452,21 @@ defmodule Axon.Initializers do

   """
   def he_uniform(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 2.0)
+    scale = Keyword.fetch!(opts, :scale)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 2.0
-      he_uniform_impl(key, shape: shape, type: type, scale: scale)
+      variance_scaling_impl(
+        key,
+        scale,
+        shape: shape,
+        type: type,
+        mode: :fan_in,
+        distribution: :uniform
+      )
     end
   end

-  defnp he_uniform_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 2.0])
-
-    variance_scaling_impl(
-      key,
-      shape: opts[:shape],
-      type: opts[:type],
-      scale: opts[:scale],
-      mode: :fan_in,
-      distribution: :uniform
-    )
-  end
-
   @doc """
   Initializes parameters with the He normal initializer.

@@ -527,25 +500,21 @@ defmodule Axon.Initializers do

   """
   def he_normal(opts \\ []) do
+    opts = Keyword.validate!(opts, scale: 2.0)
+    scale = Keyword.fetch!(opts, :scale)
+
     fn shape, type, key ->
-      scale = opts[:scale] || 2.0
-      he_normal_impl(key, shape: shape, type: type, scale: scale)
+      variance_scaling_impl(
+        key,
+        scale,
+        shape: shape,
+        type: type,
+        mode: :fan_in,
+        distribution: :truncated_normal
+      )
     end
   end

-  defnp he_normal_impl(key, opts \\ []) do
-    opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 2.0])
-
-    variance_scaling_impl(
-      key,
-      shape: opts[:shape],
-      type: opts[:type],
-      scale: opts[:scale],
-      mode: :fan_in,
-      distribution: :truncated_normal
-    )
-  end
-
   @doc """
   Initializes parameters with variance scaling according to
   the given distribution and mode.
@@ -586,30 +555,29 @@ defmodule Axon.Initializers do

   """
   def variance_scaling(opts \\ []) do
-    fn shape, type, key ->
-      scale = opts[:scale] || 1.0
-      mode = opts[:mode] || :fan_in
-      distribution = opts[:distribution] || :normal
+    opts = Keyword.validate!(opts, scale: 1.0, mode: :fan_in, distribution: :normal)
+    scale = Keyword.fetch!(opts, :scale)
+    mode = Keyword.fetch!(opts, :mode)
+    distribution = Keyword.fetch!(opts, :distribution)

+    fn shape, type, key ->
       variance_scaling_impl(
         key,
+        scale,
         shape: shape,
         type: type,
-        scale: scale,
         mode: mode,
         distribution: distribution
       )
     end
   end

-  defnp variance_scaling_impl(key, opts \\ []) do
-    opts =
-      keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0, mode: :fan_in, distribution: :normal])
+  defnp variance_scaling_impl(key, scale, opts \\ []) do
+    opts = keyword!(opts, [:shape, type: {:f, 32}, mode: :fan_in, distribution: :normal])

     fans = compute_fans(opts[:shape])
     denominator = compute_denominator(fans, opts[:mode])
-
-    variance = Nx.divide(Nx.tensor(opts[:scale], type: opts[:type]), Nx.max(denominator, 1.0))
+    variance = Nx.as_type(scale, opts[:type]) / Nx.max(denominator, 1.0)

     apply_distribution(key, opts[:distribution], variance, shape: opts[:shape], type: opts[:type])
   end
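
As the hunks above show, each initializer factory now returns a closure of the form fn shape, type, key -> tensor end with its options resolved up front. A usage sketch under that assumption (shape, type, and seed below are arbitrary):

key = Nx.Random.key(42)
init = Axon.Initializers.glorot_uniform()
kernel = init.({128, 64}, {:f, 32}, key)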

lib/axon/layers.ex

Lines changed: 1 addition & 1 deletion
@@ -1609,7 +1609,7 @@ defmodule Axon.Layers do
   @doc type: :dropout
   defn dropout(input, key, opts \\ []) do
     opts = keyword!(opts, [:rate, noise_shape: Nx.shape(input), mode: :inference])
-    keep_prob = Nx.tensor(1, type: Nx.type(input)) - Nx.tensor(opts[:rate], type: Nx.type(input))
+    keep_prob = Nx.tensor(1, type: Nx.type(input)) - Nx.as_type(opts[:rate], Nx.type(input))

     {rand, new_key} =
       Nx.Random.uniform(key, 0, 1, shape: opts[:noise_shape], type: Nx.type(input))
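
The one-line change above avoids handing a value that may already be a tensor to Nx.tensor/2, which newer Nx reports as deprecated; Nx.as_type/2 casts numbers and existing tensors alike. A small illustration of the distinction (values are arbitrary):

rate = Nx.tensor(0.5, type: {:f, 32})  # building a tensor from a number is still fine
Nx.as_type(rate, {:f, 16})             # casting an existing tensor (or a number) to a type
# Nx.tensor(rate, type: {:f, 16})      # warns on newer Nx once `rate` is already a tensor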

lib/axon/schedules.ex

Lines changed: 2 additions & 3 deletions
@@ -22,7 +22,6 @@ defmodule Axon.Schedules do
   """

   import Nx.Defn
-  import Axon.Shared

   @doc """
   Linear decay schedule.
@@ -86,7 +85,7 @@ defmodule Axon.Schedules do

     init_value = opts[:init_value]
     rate = opts[:decay_rate]
-    staircase? = to_predicate(opts[:staircase])
+    staircase? = opts[:staircase]
     k = opts[:transition_steps]
     start = opts[:transition_begin]

@@ -166,7 +165,7 @@ defmodule Axon.Schedules do

   defnp apply_constant(_step, opts \\ []) do
     opts = keyword!(opts, init_value: 0.01)
-    Nx.tensor(opts[:init_value])
+    opts[:init_value]
   end

   @doc ~S"""
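
Both schedule changes rely on the same point: with newer Nx, plain numbers (and the boolean :staircase option) can be used directly inside defn, so the Nx.tensor/1 and to_predicate wrapping goes away, which is also why the Axon.Shared import could be dropped. A minimal sketch of the constant case under that assumption (SchedulesSketch is illustrative, not part of Axon):

defmodule SchedulesSketch do
  import Nx.Defn

  # Returning the option directly: defn lifts the plain number to a tensor
  # constant, so the old Nx.tensor(opts[:init_value]) wrapper is unnecessary.
  defn constant_schedule(_step, opts \\ []) do
    opts = keyword!(opts, init_value: 0.01)
    opts[:init_value]
  end
end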
