From c3a02b519b3e2b2ac77f7c8c2ad1c1786540ad8f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 27 Apr 2023 10:51:14 -0400 Subject: [PATCH 1/3] add _norm --- Project.toml | 2 +- src/rules.jl | 44 +++++++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 14 ++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index cbe10920..d3f7ea1c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Optimisers" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" authors = ["Mike J Innes "] -version = "0.2.18" +version = "0.2.19" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rules.jl b/src/rules.jl index 82237eba..5e5fd269 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω, init(o::ClipNorm, x::AbstractArray) = nothing function apply!(o::ClipNorm, state, x, dx) - nrm = norm(dx, o.p) + nrm = _norm(dx, o.p) if o.throw && !isfinite(nrm) throw(DomainError("gradient has $(o.p)-norm $nrm, for array $(summary(x))")) end @@ -620,6 +620,48 @@ function apply!(o::ClipNorm, state, x, dx) return state, @lazy dx * λ end +_norm(dx::AbstractArray, p::Real) = norm(dx, p) # LinearAlgebra, CUDA +function _norm(dx::Broadcast.Broadcasted, p::Real) + if p == 2 + # This lacks the underflow/overflow tests of LinearAlgebra's version + sqrt(sum(abs2, dx)) + elseif p == 1 + float(sum(abs, dx)) + elseif p == Inf + float(maximum(abs, dx)) + elseif p == 0 + cnt = count(!iszero, dx) + T = Base.@default_eltype dx + T <: Number ? 
convert(float(T), cnt) : cnt + elseif p == -Inf + float(minimum(abs, dx)) + else + # This isn't optimally fast but does ensure p::Float64 doesn't promote + tmp = abs.(dx) + q = convert(float(eltype(tmp)), p) + sum(tmp .^ q) ^ (1/q) + end +end + +#= + +julia> using Metal + +julia> using Base.Broadcast: broadcasted, instantiate + +julia> bc = instantiate(broadcasted(+, MtlArray(rand(Float32, 3)), 1)); + +julia> norm(bc) +┌ Warning: Performing scalar indexing + +└ @ Metal ~/.julia/packages/Metal/TtPHW/src/compiler/compilation.jl:77 +ERROR: NSError: Undefined symbols: + llvm.maximum.f32, referenced from: _Z24partial_mapreduce_device8identity3max7Float323ValILi1024EES2_I22CartesianIndices__3___ES2_I22CartesianIndices__1___ES2_ILi1EES2_ILi1EES2_ILitrueEE14MtlDeviceArrayIS1_Li2ELi1EE11BroadcastedI13MtlArrayStyleILi1EE5TupleI5OneToI5Int64EE4normS6_IS4_IS5_ILi1EES6_IS7_IS8_EE1_S6_IS3_IS1_Li1ELi1EES8_EEEE + +julia> Metal.allowscalar(false) + +=# + """ OptimiserChain(opts...) diff --git a/test/runtests.jl b/test/runtests.jl index 7ee578bc..8f69d787 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Optimisers using ChainRulesCore, Functors, StaticArrays, Zygote, Yota using LinearAlgebra, Statistics, Test, Random using Optimisers: @.., @lazy +using Base.Broadcast: broadcasted, instantiate, Broadcasted Random.seed!(1) @@ -506,6 +507,19 @@ y2z(x) = x y = Optimisers.subtract!(x, nothing) @test y === x end + + @testset "_norm(dx, p) works" begin + bc = instantiate(broadcasted(+, randn(Float32, 10), randn(Float32, 10)')); + arr = collect(bc) + bc2 = instantiate(broadcasted(+, [1, 0, -3, 4], 0)) + arr2 = collect(bc2) + for p in (-Inf, -3, -1, 0, 0.5, 1, 1.5, 2, 3f0, Inf32) + @test Optimisers._norm(bc, p) ≈ norm(arr, p) + @test Optimisers._norm(bc, p) isa Float32 + @test Optimisers._norm(bc2, p) ≈ norm(arr2, p) + @test Optimisers._norm(bc2, p) isa Float64 + end + end end @testset verbose=true "Destructure" begin include("destructure.jl") From 
32024531ce72b1e7648a1d9557a646b03390ca05 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 29 Apr 2023 14:12:42 -0400 Subject: [PATCH 2/3] fix test which wasn't doing the right thing before --- test/runtests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8f69d787..4e02f4d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,7 +90,8 @@ y2z(x) = x _, m2 = Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,100],)) @test only(m.α[1] .- m2.α[1]) ≈ 0.1 @test norm(m.γ .- m2.γ) ≈ 10 - @test_throws DomainError Optimisers.update(s2, m, (α = [0.1], γ = [1,10,NaN],)) + # This error is thrown by apply! due to NaN input. + @test_throws DomainError Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,NaN],)) s3 = Optimisers.setup(ClipNorm(5, 1; throw=false), m) _, m3 = Optimisers.update(s3, m, (α = ([0.1], nothing), γ = [1,10,100],)) From d73a0ee5cef0f7f2bf247ce5d1db90bb0f736182 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 29 Apr 2023 15:24:41 -0400 Subject: [PATCH 3/3] rm Metal comment --- src/rules.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 5e5fd269..e994b740 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -643,25 +643,6 @@ function _norm(dx::Broadcast.Broadcasted, p::Real) end end -#= - -julia> using Metal - -julia> using Base.Broadcast: broadcasted, instantiate - -julia> bc = instantiate(broadcasted(+, MtlArray(rand(Float32, 3)), 1)); - -julia> norm(bc) -┌ Warning: Performing scalar indexing - -└ @ Metal ~/.julia/packages/Metal/TtPHW/src/compiler/compilation.jl:77 -ERROR: NSError: Undefined symbols: - llvm.maximum.f32, referenced from: 
_Z24partial_mapreduce_device8identity3max7Float323ValILi1024EES2_I22CartesianIndices__3___ES2_I22CartesianIndices__1___ES2_ILi1EES2_ILi1EES2_ILitrueEE14MtlDeviceArrayIS1_Li2ELi1EE11BroadcastedI13MtlArrayStyleILi1EE5TupleI5OneToI5Int64EE4normS6_IS4_IS5_ILi1EES6_IS7_IS8_EE1_S6_IS3_IS1_Li1ELi1EES8_EEEE - -julia> Metal.allowscalar(false) - -=# - """ OptimiserChain(opts...)