@@ -174,20 +174,16 @@ defmodule Axon.Initializers do
174174
175175 """
176176 def uniform ( opts \\ [ ] ) do
177+ opts = Keyword . validate! ( opts , scale: 1.0e-2 )
178+ scale = Keyword . fetch! ( opts , :scale )
179+
177180 fn shape , type , key ->
178- scale = opts [ :scale ] || 1.0e-2
179- uniform_impl ( key , shape: shape , type: type , scale: scale )
181+ uniform_impl ( key , scale , shape: shape , type: type )
180182 end
181183 end
182184
183- defnp uniform_impl ( key , opts \\ [ ] ) do
184- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 1.0e-2 ] )
185- shape = Nx . shape ( opts [ :shape ] )
186-
187- Nx.Random . uniform_split ( key , Nx . negate ( opts [ :scale ] ) , opts [ :scale ] ,
188- type: opts [ :type ] ,
189- shape: shape
190- )
185+ defnp uniform_impl ( key , scale , opts ) do
186+ Nx.Random . uniform_split ( key , Nx . negate ( scale ) , scale , opts )
191187 end
192188
193189 @ doc """
@@ -216,18 +212,15 @@ defmodule Axon.Initializers do
216212
217213 """
218214 def normal ( opts \\ [ ] ) do
215+ opts = Keyword . validate! ( opts , scale: 1.0e-2 , mean: 0.0 )
216+ scale = Keyword . fetch! ( opts , :scale )
217+ mean = Keyword . fetch! ( opts , :mean )
218+
219219 fn shape , type , key ->
220- scale = opts [ :scale ] || 1.0e-2
221- mean = opts [ :mean ] || 0.0
222- normal_impl ( key , shape: shape , type: type , scale: scale , mean: mean )
220+ Nx.Random . normal_split ( key , mean , scale , type: type , shape: shape )
223221 end
224222 end
225223
226- defnp normal_impl ( key , opts \\ [ ] ) do
227- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 1.0e-2 , mean: 0.0 ] )
228- Nx.Random . normal_split ( key , opts [ :mean ] , opts [ :scale ] , shape: opts [ :shape ] , type: opts [ :type ] )
229- end
230-
231224 @ doc """
232225 Initializes parameters with the Lecun uniform initializer.
233226
@@ -261,25 +254,21 @@ defmodule Axon.Initializers do
261254
262255 """
263256 def lecun_uniform ( opts \\ [ ] ) do
257+ opts = Keyword . validate! ( opts , scale: 1.0 )
258+ scale = Keyword . fetch! ( opts , :scale )
259+
264260 fn shape , type , key ->
265- scale = opts [ :scale ] || 1.0
266- lecun_uniform_impl ( key , shape: shape , type: type , scale: scale )
261+ variance_scaling_impl (
262+ key ,
263+ scale ,
264+ shape: shape ,
265+ type: type ,
266+ mode: :fan_in ,
267+ distribution: :uniform
268+ )
267269 end
268270 end
269271
270- defnp lecun_uniform_impl ( key , opts \\ [ ] ) do
271- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 1.0 ] )
272-
273- variance_scaling_impl (
274- key ,
275- shape: opts [ :shape ] ,
276- type: opts [ :type ] ,
277- scale: opts [ :scale ] ,
278- mode: :fan_in ,
279- distribution: :uniform
280- )
281- end
282-
283272 @ doc """
284273 Initializes parameters with the Lecun normal initializer.
285274
@@ -313,25 +302,21 @@ defmodule Axon.Initializers do
313302
314303 """
315304 def lecun_normal ( opts \\ [ ] ) do
305+ opts = Keyword . validate! ( opts , scale: 1.0 )
306+ scale = Keyword . fetch! ( opts , :scale )
307+
316308 fn shape , type , key ->
317- scale = opts [ :scale ] || 1.0
318- lecun_normal_impl ( key , shape: shape , type: type , scale: scale )
309+ variance_scaling_impl (
310+ key ,
311+ scale ,
312+ shape: shape ,
313+ type: type ,
314+ mode: :fan_in ,
315+ distribution: :truncated_normal
316+ )
319317 end
320318 end
321319
322- defnp lecun_normal_impl ( key , opts \\ [ ] ) do
323- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 1.0 ] )
324-
325- variance_scaling_impl (
326- key ,
327- shape: opts [ :shape ] ,
328- type: opts [ :type ] ,
329- scale: opts [ :scale ] ,
330- mode: :fan_in ,
331- distribution: :truncated_normal
332- )
333- end
334-
335320 @ doc """
336321 Initializes parameters with the Glorot uniform initializer.
337322
@@ -368,25 +353,21 @@ defmodule Axon.Initializers do
368353
369354 """
370355 def glorot_uniform ( opts \\ [ ] ) do
356+ opts = Keyword . validate! ( opts , scale: 1.0 )
357+ scale = Keyword . fetch! ( opts , :scale )
358+
371359 fn shape , type , key ->
372- scale = opts [ :scale ] || 1.0
373- glorot_uniform_impl ( key , shape: shape , type: type , scale: scale )
360+ variance_scaling_impl (
361+ key ,
362+ scale ,
363+ shape: shape ,
364+ type: type ,
365+ mode: :fan_avg ,
366+ distribution: :uniform
367+ )
374368 end
375369 end
376370
377- defnp glorot_uniform_impl ( key , opts \\ [ ] ) do
378- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 1.0 ] )
379-
380- variance_scaling_impl (
381- key ,
382- shape: opts [ :shape ] ,
383- type: opts [ :type ] ,
384- scale: opts [ :scale ] ,
385- mode: :fan_avg ,
386- distribution: :uniform
387- )
388- end
389-
390371 @ doc """
391372 Initializes parameters with the Glorot normal initializer.
392373
@@ -423,25 +404,21 @@ defmodule Axon.Initializers do
423404
424405 """
425406 def glorot_normal ( opts \\ [ ] ) do
407+ opts = Keyword . validate! ( opts , scale: 1.0 )
408+ scale = Keyword . fetch! ( opts , :scale )
409+
426410 fn shape , type , key ->
427- scale = opts [ :scale ] || 1.0
428- glorot_normal_impl ( key , shape: shape , type: type , scale: scale )
411+ variance_scaling_impl (
412+ key ,
413+ scale ,
414+ shape: shape ,
415+ type: type ,
416+ mode: :fan_avg ,
417+ distribution: :truncated_normal
418+ )
429419 end
430420 end
431421
432- defnp glorot_normal_impl ( key , opts \\ [ ] ) do
433- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 1.0 ] )
434-
435- variance_scaling_impl (
436- key ,
437- shape: opts [ :shape ] ,
438- type: opts [ :type ] ,
439- scale: opts [ :scale ] ,
440- mode: :fan_avg ,
441- distribution: :truncated_normal
442- )
443- end
444-
445422 @ doc """
446423 Initializes parameters with the He uniform initializer.
447424
@@ -475,25 +452,21 @@ defmodule Axon.Initializers do
475452
476453 """
477454 def he_uniform ( opts \\ [ ] ) do
455+ opts = Keyword . validate! ( opts , scale: 2.0 )
456+ scale = Keyword . fetch! ( opts , :scale )
457+
478458 fn shape , type , key ->
479- scale = opts [ :scale ] || 2.0
480- he_uniform_impl ( key , shape: shape , type: type , scale: scale )
459+ variance_scaling_impl (
460+ key ,
461+ scale ,
462+ shape: shape ,
463+ type: type ,
464+ mode: :fan_in ,
465+ distribution: :uniform
466+ )
481467 end
482468 end
483469
484- defnp he_uniform_impl ( key , opts \\ [ ] ) do
485- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 2.0 ] )
486-
487- variance_scaling_impl (
488- key ,
489- shape: opts [ :shape ] ,
490- type: opts [ :type ] ,
491- scale: opts [ :scale ] ,
492- mode: :fan_in ,
493- distribution: :uniform
494- )
495- end
496-
497470 @ doc """
498471 Initializes parameters with the He normal initializer.
499472
@@ -527,25 +500,21 @@ defmodule Axon.Initializers do
527500
528501 """
529502 def he_normal ( opts \\ [ ] ) do
503+ opts = Keyword . validate! ( opts , scale: 2.0 )
504+ scale = Keyword . fetch! ( opts , :scale )
505+
530506 fn shape , type , key ->
531- scale = opts [ :scale ] || 2.0
532- he_normal_impl ( key , shape: shape , type: type , scale: scale )
507+ variance_scaling_impl (
508+ key ,
509+ scale ,
510+ shape: shape ,
511+ type: type ,
512+ mode: :fan_in ,
513+ distribution: :truncated_normal
514+ )
533515 end
534516 end
535517
536- defnp he_normal_impl ( key , opts \\ [ ] ) do
537- opts = keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 2.0 ] )
538-
539- variance_scaling_impl (
540- key ,
541- shape: opts [ :shape ] ,
542- type: opts [ :type ] ,
543- scale: opts [ :scale ] ,
544- mode: :fan_in ,
545- distribution: :truncated_normal
546- )
547- end
548-
549518 @ doc """
550519 Initializes parameters with variance scaling according to
551520 the given distribution and mode.
@@ -586,30 +555,29 @@ defmodule Axon.Initializers do
586555
587556 """
588557 def variance_scaling ( opts \\ [ ] ) do
589- fn shape , type , key ->
590- scale = opts [ :scale ] || 1.0
591- mode = opts [ :mode ] || :fan_in
592- distribution = opts [ :distribution ] || :normal
558+ opts = Keyword . validate! ( opts , scale: 1.0 , mode: :fan_in , distribution: :normal )
559+ scale = Keyword . fetch! ( opts , :scale )
560+ mode = Keyword . fetch! ( opts , :mode )
561+ distribution = Keyword . fetch! ( opts , :distribution )
593562
563+ fn shape , type , key ->
594564 variance_scaling_impl (
595565 key ,
566+ scale ,
596567 shape: shape ,
597568 type: type ,
598- scale: scale ,
599569 mode: mode ,
600570 distribution: distribution
601571 )
602572 end
603573 end
604574
605- defnp variance_scaling_impl ( key , opts \\ [ ] ) do
606- opts =
607- keyword! ( opts , [ :shape , type: { :f , 32 } , scale: 1.0 , mode: :fan_in , distribution: :normal ] )
575+ defnp variance_scaling_impl ( key , scale , opts \\ [ ] ) do
576+ opts = keyword! ( opts , [ :shape , type: { :f , 32 } , mode: :fan_in , distribution: :normal ] )
608577
609578 fans = compute_fans ( opts [ :shape ] )
610579 denominator = compute_denominator ( fans , opts [ :mode ] )
611-
612- variance = Nx . divide ( Nx . tensor ( opts [ :scale ] , type: opts [ :type ] ) , Nx . max ( denominator , 1.0 ) )
580+ variance = Nx . as_type ( scale , opts [ :type ] ) / Nx . max ( denominator , 1.0 )
613581
614582 apply_distribution ( key , opts [ :distribution ] , variance , shape: opts [ :shape ] , type: opts [ :type ] )
615583 end
0 commit comments