@@ -22,7 +22,7 @@ function convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init =
2222end
2323
2424"""
25- convnext(depths::Vector {<:Integer}, planes::Vector {<:Integer};
25+ convnext(depths::AbstractVector {<:Integer}, planes::AbstractVector {<:Integer};
2626 drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3,
2727 nclasses::Integer = 1000)
2828
@@ -31,27 +31,27 @@ Creates the layers for a ConvNeXt model.
3131
3232# Arguments
3333
34- - `inchannels`: number of input channels.
3534 - `depths`: list with configuration for depth of each block
3635 - `planes`: list with configuration for number of output channels in each block
3736 - `drop_path_rate`: Stochastic depth rate.
3837 - `layerscale_init`: Initial value for [`LayerScale`](#)
3938 ([reference](https://arxiv.org/abs/2103.17239))
39+ - `inchannels`: number of input channels.
4040 - `nclasses`: number of output classes
4141"""
42- function convnext (depths:: Vector {<:Integer} , planes:: Vector {<:Integer} ;
42+ function convnext (depths:: AbstractVector {<:Integer} , planes:: AbstractVector {<:Integer} ;
4343 drop_path_rate = 0.0 , layerscale_init = 1.0f-6 , inchannels:: Integer = 3 ,
4444 nclasses:: Integer = 1000 )
4545 @assert length (depths) == length (planes)
4646 " `planes` should have exactly one value for each block"
4747 downsample_layers = []
48- stem = Chain ( Conv (( 4 , 4 ), inchannels => planes[ 1 ]; stride = 4 ) ,
49- ChannelLayerNorm ( planes[1 ]))
50- push! (downsample_layers, stem )
48+ push! (downsample_layers ,
49+ Chain ( conv_norm (( 4 , 4 ), inchannels => planes[1 ]; stride = 4 ,
50+ norm_layer = ChannelLayerNorm) ... ) )
5151 for m in 1 : (length (depths) - 1 )
52- downsample_layer = Chain ( ChannelLayerNorm (planes[m]) ,
53- Conv (( 2 , 2 ), planes[m] => planes[m + 1 ]; stride = 2 ))
54- push! (downsample_layers, downsample_layer )
52+ push! (downsample_layers ,
53+ Chain ( conv_norm (( 2 , 2 ), planes[m] => planes[m + 1 ]; stride = 2 ,
54+ norm_layer = ChannelLayerNorm, revnorm = true ) ... ) )
5555 end
5656 stages = []
5757 dp_rates = linear_scheduler (drop_path_rate; depth = sum (depths))
@@ -64,8 +64,7 @@ function convnext(depths::Vector{<:Integer}, planes::Vector{<:Integer};
6464 end
6565 backbone = collect (Iterators. flatten (Iterators. flatten (zip (downsample_layers, stages))))
6666 classifier = Chain (GlobalMeanPool (), MLUtils. flatten,
67- LayerNorm (planes[end ]),
68- Dense (planes[end ], nclasses))
67+ LayerNorm (planes[end ]), Dense (planes[end ], nclasses))
6968 return Chain (Chain (backbone... ), classifier)
7069end
7170
@@ -77,13 +76,14 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
7776 :xlarge => ([3 , 3 , 27 , 3 ], [256 , 512 , 1024 , 2048 ]))
7877
7978"""
80- ConvNeXt(mode ::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
79+ ConvNeXt(config ::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
8180
8281Creates a ConvNeXt model.
8382([reference](https://arxiv.org/abs/2201.03545))
8483
8584# Arguments
8685
86+ - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
8787 - `inchannels`: The number of channels in the input.
8888 - `nclasses`: number of output classes
8989
@@ -94,9 +94,9 @@ struct ConvNeXt
9494end
9595@functor ConvNeXt
9696
97- function ConvNeXt (mode :: Symbol ; inchannels:: Integer = 3 , nclasses:: Integer = 1000 )
98- _checkconfig (mode , keys (CONVNEXT_CONFIGS))
99- layers = convnext (CONVNEXT_CONFIGS[mode ]. .. ; inchannels, nclasses)
97+ function ConvNeXt (config :: Symbol ; inchannels:: Integer = 3 , nclasses:: Integer = 1000 )
98+ _checkconfig (config , keys (CONVNEXT_CONFIGS))
99+ layers = convnext (CONVNEXT_CONFIGS[config ]. .. ; inchannels, nclasses)
100100 return ConvNeXt (layers)
101101end
102102
0 commit comments