Skip to content

Commit d8a5eb1

Browse files
committed
test with GPUArrays
1 parent dd571ca commit d8a5eb1

File tree

3 files changed

+112
-3
lines changed

3 files changed

+112
-3
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Optimisers"
22
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
33
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
4-
version = "0.2.2"
4+
version = "0.2.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -16,9 +16,12 @@ Functors = "0.2.8"
1616
julia = "1.6"
1717

1818
[extras]
19-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
19+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
20+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
21+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
2022
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
23+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2124
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2225

2326
[targets]
24-
test = ["Test", "StaticArrays", "Zygote"]
27+
test = ["Adapt", "CUDA", "GPUArrays", "StaticArrays", "Test", "Zygote"]

test/gpuarrays.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
using Optimisers
2+
using ChainRulesCore #, Functors, StaticArrays, Zygote
3+
using LinearAlgebra, Statistics, Test
4+
5+
import CUDA
6+
if CUDA.functional()
7+
using CUDA # exports CuArray, etc
8+
@info "starting CUDA tests"
9+
else
10+
@info "CUDA not functional, testing via GPUArrays"
11+
using GPUArrays
12+
GPUArrays.allowscalar(false)
13+
14+
# GPUArrays provides a fake GPU array, for testing
15+
jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
16+
using Random, Adapt # loaded within jl_file
17+
include(jl_file)
18+
using .JLArrays
19+
cu = jl
20+
CuArray{T,N} = JLArray{T,N}
21+
end
22+
23+
@test cu(rand(3)) .+ 1 isa CuArray
24+
25+
@testset "very basics" begin
26+
m = (cu([1.0, 2.0]),)
27+
mid = objectid(m[1])
28+
g = (cu([25, 33]),)
29+
o = Descent(0.1f0)
30+
s = Optimisers.setup(o, m)
31+
32+
s2, m2 = Optimisers.update(s, m, g)
33+
@test Array(m[1]) == 1:2 # not mutated
34+
@test m2[1] isa CuArray
35+
@test Array(m2[1]) [1,2] .- 0.1 .* [25, 33] atol=1e-6
36+
37+
s3, m3 = Optimisers.update!(s, m, g)
38+
@test objectid(m3[1]) == mid
39+
@test Array(m3[1]) [1,2] .- 0.1 .* [25, 33] atol=1e-6
40+
41+
g4 = Tangent{typeof(m)}(g...)
42+
s4, m4 = Optimisers.update!(s, (cu([1.0, 2.0]),), g4)
43+
@test Array(m4[1]) [1,2] .- 0.1 .* [25, 33] atol=1e-6
44+
end
45+
46+
@testset "basic mixed" begin
47+
# Works trivially as every element of the tree is either here or there
48+
m = (device = cu([1.0, 2.0]), host = [3.0, 4.0], neither = (5, 6, sin))
49+
s = Optimisers.setup(ADAM(0.1), m)
50+
@test s.device.state[1] isa CuArray
51+
@test s.host.state[1] isa Array
52+
53+
g = (device = cu([1, 0.1]), host = [1, 10], neither = nothing)
54+
s2, m2 = Optimisers.update(s, m, g)
55+
56+
@test m2.device isa CuArray
57+
@test Array(m2.device) [0.9, 1.9] atol=1e-6
58+
59+
@test m2.host isa Array
60+
@test m2.host [2.9, 3.9]
61+
end
62+
63+
RULES = [
64+
# Just a selection:
65+
Descent(), ADAM(), RMSProp(), NADAM(),
66+
# A few chained combinations:
67+
OptimiserChain(WeightDecay(), ADAM(0.001)),
68+
OptimiserChain(ClipNorm(), ADAM(0.001)),
69+
OptimiserChain(ClipGrad(0.5), Momentum()),
70+
]
71+
72+
name(o) = typeof(o).name.name # just for printing testset headings
73+
name(o::OptimiserChain) = join(name.(o.opts), "")
74+
75+
@testset "rules: simple sum" begin
76+
@testset "$(name(o))" for o in RULES
77+
m = cu(shuffle!(reshape(1:64, 8, 8) .+ 0.0))
78+
s = Optimisers.setup(o, m)
79+
for _ in 1:10
80+
g = Zygote.gradient(x -> sum(abs2, x + x'), m)[1]
81+
s, m = Optimisers.update!(s, m, g)
82+
end
83+
@test sum(m) < sum(1:64)
84+
end
85+
end
86+
87+
@testset "destructure GPU" begin
88+
m = (x = cu(Float32[1,2,3]), y = (0, 99), z = cu(Float32[4,5]))
89+
v, re = destructure(m)
90+
@test v isa CuArray
91+
@test re(2v).x isa CuArray
92+
end
93+
94+
@testset "destructure mixed" begin
95+
# Not sure what should happen here!
96+
m_c1 = (x = cu(Float32[1,2,3]), y = Float32[4,5])
97+
v, re = destructure(m_c1)
98+
@test re(2v).x isa CuArray
99+
@test_broken re(2v).y isa Array
100+
101+
m_c2 = (x = Float32[1,2,3], y = cu(Float32[4,5]))
102+
@test_skip destructure(m_c2) # ERROR: Scalar indexing
103+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
172172
@testset verbose=true "Optimisation Rules" begin
173173
include("rules.jl")
174174
end
175+
@testset verbose=true "GPU" begin
176+
include("gpuarrays.jl")
177+
end
175178
end

0 commit comments

Comments
 (0)