From 09ca632af17d91a834c502b145ce97d5dc75fa1f Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 31 Mar 2025 10:06:10 +0200
Subject: [PATCH 1/2] julia 1.10

---
 Project.toml                |  2 +-
 src/dim_helpers/ConvDims.jl |  2 +-
 src/gemm.jl                 |  2 +-
 src/utils.jl                | 18 ------------------
 test.jl                     | 11 +++++++++++
 test/conv.jl                |  2 +-
 test/dropout.jl             |  5 +----
 test/pooling.jl             |  2 +-
 test/runtests.jl            |  2 +-
 test/testsuite/gather.jl    |  2 +-
 test/testsuite/scatter.jl   |  2 +-
 11 files changed, 20 insertions(+), 30 deletions(-)
 create mode 100644 test.jl

diff --git a/Project.toml b/Project.toml
index c9647205c..02e14607b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -48,4 +48,4 @@ ScopedValues = "1.3.0"
 SpecialFunctions = "2"
 Statistics = "1"
 cuDNN = "1"
-julia = "1.9"
+julia = "1.10"
diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl
index e8bcc08f4..9e02010d3 100644
--- a/src/dim_helpers/ConvDims.jl
+++ b/src/dim_helpers/ConvDims.jl
@@ -73,7 +73,7 @@ function im2col_dims(c::ConvDims)
         # Size of single dotproduct within convolution
         prod(kernel_size(c))*channels_in(c),
         # One workspace per thread
-        VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(),
+        Threads.nthreads(:default),
     )
 end
 
diff --git a/src/gemm.jl b/src/gemm.jl
index 9a3c6cd57..e05174d17 100644
--- a/src/gemm.jl
+++ b/src/gemm.jl
@@ -95,7 +95,7 @@ for (gemm, elt) in gemm_datatype_mappings
             strC = Base.stride(C, 3)
 
             n_threads = min(
-                VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(),
+                Threads.nthreads(:default),
                 1 + max(length(A), length(B)) ÷ 8000)
             # In some tests, size (20,20,20) is worth splitting between two threads,
             # as is size (32,32,8).
diff --git a/src/utils.jl b/src/utils.jl
index baf95c8da..6d82a81ec 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -144,21 +144,3 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f:
     rrule_via_ad(cfg, broadcast, f, x, ys...)
 end
 
-# Could get this from Compat.jl instead
-# https://github.com/JuliaLang/julia/pull/39794
-if VERSION < v"1.7.0-DEV.793"
-    struct Returns{V} <: Function
-        value::V
-        Returns{V}(value) where {V} = new{V}(value)
-        Returns(value) = new{Core.Typeof(value)}(value)
-    end
-
-    (obj::Returns)(args...; kw...) = obj.value
-    function Base.show(io::IO, obj::Returns)
-        show(io, typeof(obj))
-        print(io, "(")
-        show(io, obj.value)
-        print(io, ")")
-    end
-end
-
diff --git a/test.jl b/test.jl
new file mode 100644
index 000000000..a20105f24
--- /dev/null
+++ b/test.jl
@@ -0,0 +1,11 @@
+import Metal, NNlib, Flux
+
+dev = Flux.get_device()
+
+src, idx = Int32[1 2 3 4; 5 6 7 8], Int32[2,1,1,5]
+srcd, idxd = dev(x), dev(idx)
+y = NNlib.scatter(+, src, idx)
+yd = dev(zero(y))
+NNlib.scatter!(+, yd, srcd, idxd)
+
+
diff --git a/test/conv.jl b/test/conv.jl
index cf3232778..8e52c846a 100644
--- a/test/conv.jl
+++ b/test/conv.jl
@@ -908,7 +908,7 @@ end
     gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
 end
 
-@static if Test_Enzyme
+if NNLIB_TEST_ENZYME
 
 @testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
     x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
diff --git a/test/dropout.jl b/test/dropout.jl
index 0da70111e..65aac8b62 100644
--- a/test/dropout.jl
+++ b/test/dropout.jl
@@ -16,9 +16,6 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme
     @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)
 
     x2 = Diagonal(randn(Float32, 10)) # Just to check it runs on weird matrices.
-    if VERSION > v"1.8-" # on 1.6 this makes a sparse array.
-        @test dropout(x2, 0.3) isa Matrix{Float32} # does not infer, but that's OK?
-    end
 
     # Values
     @test dropout(x1, 0) == x1
@@ -76,7 +73,7 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme
     @test_throws ArgumentError dropout!(y1, x1, 3)
 end
 
-@static if Test_Enzyme
+if NNLIB_TEST_ENZYME
 
 @testset "EnzymeRules: dropout " begin
     rng = Random.default_rng()
diff --git a/test/pooling.jl b/test/pooling.jl
index f9d57ade7..1b11a1aea 100644
--- a/test/pooling.jl
+++ b/test/pooling.jl
@@ -948,7 +948,7 @@ end
     gradtest(x -> sum(meanpool(x, k)), x)
 end
 
-@static if Test_Enzyme
+if NNLIB_TEST_ENZYME
 
 @testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2),
     (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!))
diff --git a/test/runtests.jl b/test/runtests.jl
index b8080b6ba..6805672e7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -18,10 +18,10 @@ import ReverseDiff as RD # used in `pooling.jl`
 import Pkg
 using SpecialFunctions
 
-const Test_Enzyme = VERSION <= v"1.10-"
 
 DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)
 
+const NNLIB_TEST_ENZYME = true
 # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
 # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
 # ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
diff --git a/test/testsuite/gather.jl b/test/testsuite/gather.jl
index 92e3bfb7d..189533385 100644
--- a/test/testsuite/gather.jl
+++ b/test/testsuite/gather.jl
@@ -154,7 +154,7 @@ function gather_testsuite(Backend)
         gradtest_fn((s, i) -> gather(s, i), src, idx)
     end
 
-    @static if Test_Enzyme
+    if NNLIB_TEST_ENZYME
 
     @testset "EnzymeRules: gather! gradient for scalar index" begin
         src = device(Float64[3, 4, 5, 6, 7])
diff --git a/test/testsuite/scatter.jl b/test/testsuite/scatter.jl
index aa0b1c41e..ddbf8eb67 100644
--- a/test/testsuite/scatter.jl
+++ b/test/testsuite/scatter.jl
@@ -208,7 +208,7 @@ function scatter_testsuite(Backend)
         end
     end
 
-    @static if Test_Enzyme
+    if NNLIB_TEST_ENZYME
 
     @testset "EnzymeRules" begin
         idx = device([2, 2, 3, 4, 4])

From bcac8d6ab2e14f1a07c15b024b9db6245177519b Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 8 Apr 2025 09:22:00 +0200
Subject: [PATCH 2/2] cleanup

---
 test.jl | 11 -----------
 1 file changed, 11 deletions(-)
 delete mode 100644 test.jl

diff --git a/test.jl b/test.jl
deleted file mode 100644
index a20105f24..000000000
--- a/test.jl
+++ /dev/null
@@ -1,11 +0,0 @@
-import Metal, NNlib, Flux
-
-dev = Flux.get_device()
-
-src, idx = Int32[1 2 3 4; 5 6 7 8], Int32[2,1,1,5]
-srcd, idxd = dev(x), dev(idx)
-y = NNlib.scatter(+, src, idx)
-yd = dev(zero(y))
-NNlib.scatter!(+, yd, srcd, idxd)
-
-
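
Note on the scratch file test.jl (added in PATCH 1/2, removed again in PATCH 2/2): as committed it does not run, because `srcd, idxd = dev(x), dev(idx)` references an undefined `x`; `src` was presumably intended. The following is a minimal CPU-only sketch of the same scatter check, assuming that fix and dropping the Metal/Flux device transfer; it is an illustration, not part of either patch.

    using NNlib

    src = Int32[1 2 3 4; 5 6 7 8]   # 2x4 source values
    idx = Int32[2, 1, 1, 5]         # destination column for each source column

    # Out-of-place: allocates a 2x5 destination (5 = maximum(idx)),
    # starting from the neutral element of +, i.e. zeros.
    y = NNlib.scatter(+, src, idx)

    # In-place: accumulate into a preallocated destination of the same shape.
    # NOTE: the original script used dev(x)/dev(src) here; this sketch stays on the CPU.
    yd = zero(y)
    NNlib.scatter!(+, yd, src, idx)
    @assert y == yd

Source columns whose index repeats (here columns 2 and 3, both mapped to destination column 1) are combined by the `+` reduction, which is the behaviour the GPU testsuite in test/testsuite/scatter.jl exercises.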