Skip to content

Commit 307beef

Browse files
committed
use JLArrays, add gradient test
1 parent 9e8a0b4 commit 307beef

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ Zygote = "0.6.40"
1717
julia = "1.6"
1818

1919
[extras]
20-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2120
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
22-
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
21+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2322
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2423
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2524
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2625

2726
[targets]
28-
test = ["Adapt", "CUDA", "GPUArrays", "StaticArrays", "Test", "Zygote"]
27+
test = ["CUDA", "JLArrays", "StaticArrays", "Test", "Zygote"]

test/gpuarrays.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,19 @@
11
using Optimisers
2-
using ChainRulesCore #, Functors, StaticArrays, Zygote
3-
using LinearAlgebra, Statistics, Test
2+
using ChainRulesCore, Zygote
3+
using Test
44

55
import CUDA
66
if CUDA.functional()
77
using CUDA # exports CuArray, etc
8-
@info "starting CUDA tests"
8+
CUDA.allowscalar(false)
99
else
10-
@info "CUDA not functional, testing via GPUArrays"
11-
using GPUArrays
12-
GPUArrays.allowscalar(false)
10+
@info "CUDA not functional, testing with JLArrays instead"
11+
using JLArrays
12+
JLArrays.allowscalar(false)
1313

14-
# GPUArrays provides a fake GPU array, for testing
15-
jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
16-
using Random, Adapt # loaded within jl_file
17-
include(jl_file)
18-
using .JLArrays
1914
cu = jl
2015
CuArray{T,N} = JLArray{T,N}
2116
end
22-
2317
@test cu(rand(3)) .+ 1 isa CuArray
2418

2519
@testset "very basics" begin
@@ -89,6 +83,11 @@ end
8983
v, re = destructure(m)
9084
@test v isa CuArray
9185
@test re(2v).x isa CuArray
86+
87+
dm = gradient(m -> sum(abs2, destructure(m)[1]), m)[1]
88+
@test dm.z isa CuArray
89+
dv = gradient(v -> sum(abs2, re(v).z), cu([10, 20, 30, 40, 50.0]))[1]
90+
@test dv isa CuArray
9291
end
9392

9493
@testset "destructure mixed" begin

0 commit comments

Comments
 (0)