Make ClipNorm work on GPU Broadcasted #144

Merged (3 commits), Jul 10, 2023
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.2.18"
version = "0.2.19"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44 changes: 43 additions & 1 deletion src/rules.jl
@@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω,
init(o::ClipNorm, x::AbstractArray) = nothing

function apply!(o::ClipNorm, state, x, dx)
nrm = norm(dx, o.p)
nrm = _norm(dx, o.p)
if o.throw && !isfinite(nrm)
throw(DomainError("gradient has $(o.p)-norm $nrm, for array $(summary(x))"))
end
@@ -620,6 +620,48 @@ function apply!(o::ClipNorm, state, x, dx)
return state, @lazy dx * λ
end

_norm(dx::AbstractArray, p::Real) = norm(dx, p) # LinearAlgebra, CUDA
function _norm(dx::Broadcast.Broadcasted, p::Real)
if p == 2
# This lacks the underflow/overflow tests of LinearAlgebra's version
sqrt(sum(abs2, dx))
elseif p == 1
float(sum(abs, dx))
elseif p == Inf
float(maximum(abs, dx))
elseif p == 0
cnt = count(!iszero, dx)
T = Base.@default_eltype dx
T <: Number ? convert(float(T), cnt) : cnt
elseif p == -Inf
float(minimum(abs, dx))
else
# This isn't optimally fast but does ensure p::Float64 doesn't promote
tmp = abs.(dx)
q = convert(float(eltype(tmp)), p)
sum(tmp .^ q) ^ (1/q)
end
end

#=

julia> using Metal

julia> using Base.Broadcast: broadcasted, instantiate

julia> bc = instantiate(broadcasted(+, MtlArray(rand(Float32, 3)), 1));

julia> norm(bc)
┌ Warning: Performing scalar indexing

└ @ Metal ~/.julia/packages/Metal/TtPHW/src/compiler/compilation.jl:77
ERROR: NSError: Undefined symbols:
llvm.maximum.f32, referenced from: _Z24partial_mapreduce_device8identity3max7Float323ValILi1024EES2_I22CartesianIndices__3___ES2_I22CartesianIndices__1___ES2_ILi1EES2_ILi1EES2_ILitrueEE14MtlDeviceArrayIS1_Li2ELi1EE11BroadcastedI13MtlArrayStyleILi1EE5TupleI5OneToI5Int64EE4normS6_IS4_IS5_ILi1EES6_IS7_IS8_EE1_S6_IS3_IS1_Li1ELi1EES8_EEEE

julia> Metal.allowscalar(false)

=#

"""
OptimiserChain(opts...)

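For context (not part of the diff): ClipNorm's apply! can receive its gradient as a lazy Broadcast.Broadcasted rather than a plain array, for example when it runs after another rule whose @lazy output has not yet been materialised. LinearAlgebra.norm has no specialised method for such objects, which on GPU leads to the scalar-indexing warning and kernel error shown in the Metal log above; the new _norm helper computes the same p-norms through reductions (sum, maximum, minimum, count) that GPU broadcasts support. A minimal CPU sketch of the rule's user-facing behaviour, using the public setup/update interface (the values are illustrative, not taken from the PR):

using Optimisers

x  = zeros(Float32, 3)
dx = Float32[3, 4, 0]                    # gradient with 2-norm 5

st = Optimisers.setup(ClipNorm(1f0), x)  # rescale gradients so their 2-norm is at most 1
st, x2 = Optimisers.update(st, x, dx)    # the step uses dx scaled by 1/5

x2 ≈ -dx ./ 5                            # ≈ Float32[-0.6, -0.8, 0.0]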
14 changes: 14 additions & 0 deletions test/runtests.jl
@@ -2,6 +2,7 @@ using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted

Random.seed!(1)

@@ -506,6 +507,19 @@ y2z(x) = x
y = Optimisers.subtract!(x, nothing)
@test y === x
end

@testset "_norm(dx, p) works" begin
Member

can we test the interface instead of an internal method?

Member Author

At present there are no GPU tests, and norm(::Broadcasted{..., Array}) works without error. So other tests which didn't fail before do call _norm.

Member Author

#71 was going to add some tests, but did not run into this failure. I haven't checked whether JLArray will in fact see it, but Metal's array type (Apple M1) does.

bc = instantiate(broadcasted(+, randn(Float32, 10), randn(Float32, 10)'));
arr = collect(bc)
bc2 = instantiate(broadcasted(+, [1, 0, -3, 4], 0))
arr2 = collect(bc2)
for p in (-Inf, -3, -1, 0, 0.5, 1, 1.5, 2, 3f0, Inf32)
@test Optimisers._norm(bc, p) ≈ norm(arr, p)
@test Optimisers._norm(bc, p) isa Float32
@test Optimisers._norm(bc2, p) ≈ norm(arr2, p)
@test Optimisers._norm(bc2, p) isa Float64
end
end
end
@testset verbose=true "Destructure" begin
include("destructure.jl")
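Following up on the JLArray question in the review thread above: whether JLArrays (the CPU-based GPU-array emulator) reproduces the Metal failure is, as the author notes, unchecked. A quick check might look like the sketch below; it is not part of the PR, and the JLArrays usage is an assumption.

using JLArrays, LinearAlgebra, Optimisers
using Base.Broadcast: broadcasted, instantiate

bc = instantiate(broadcasted(+, JLArray(rand(Float32, 3)), 1f0));

norm(bc)                 # the call that failed on Metal; may or may not fail under JLArrays
Optimisers._norm(bc, 2)  # the new fallback, which avoids norm(::Broadcasted) entirely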