Skip to content

Commit f69a180

Browse files
committed
use JLArrays, add gradient test
1 parent 9e8a0b4 commit f69a180

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ Zygote = "0.6.40"
1717
julia = "1.6"
1818

1919
[extras]
20-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2120
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
22-
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
21+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2322
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2423
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2524
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2625

2726
[targets]
28-
test = ["Adapt", "CUDA", "GPUArrays", "StaticArrays", "Test", "Zygote"]
27+
test = ["CUDA", "JLArrays", "StaticArrays", "Test", "Zygote"]

test/gpuarrays.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
using Optimisers
2-
using ChainRulesCore #, Functors, StaticArrays, Zygote
3-
using LinearAlgebra, Statistics, Test
2+
using ChainRulesCore, Zygote
3+
using Test
44

55
import CUDA
66
if CUDA.functional()
77
using CUDA # exports CuArray, etc
8-
@info "starting CUDA tests"
98
else
10-
@info "CUDA not functional, testing via GPUArrays"
11-
using GPUArrays
12-
GPUArrays.allowscalar(false)
9+
@info "CUDA not functional, testing with JLArrays instead"
10+
using JLArrays
11+
JLArrays.allowscalar(false)
1312

14-
# GPUArrays provides a fake GPU array, for testing
15-
jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
16-
using Random, Adapt # loaded within jl_file
17-
include(jl_file)
18-
using .JLArrays
19-
cu = jl
13+
cu = jl32
2014
CuArray{T,N} = JLArray{T,N}
2115
end
22-
2316
@test cu(rand(3)) .+ 1 isa CuArray
2417

2518
@testset "very basics" begin
@@ -89,6 +82,11 @@ end
8982
v, re = destructure(m)
9083
@test v isa CuArray
9184
@test re(2v).x isa CuArray
85+
86+
dm = gradient(m -> sum(abs2, destructure(m)[1]), m)[1]
87+
@test dm.z isa CuArray
88+
dv = gradient(v -> sum(abs2, re(v).z), cu([10, 20, 30, 40, 50.0]))[1]
89+
@test dv isa CuArray
9290
end
9391

9492
@testset "destructure mixed" begin

0 commit comments

Comments
 (0)