@@ -1,25 +1,18 @@
 using Optimisers
-using ChainRulesCore #, Functors, StaticArrays, Zygote
-using LinearAlgebra, Statistics, Test
+using ChainRulesCore, Zygote
+using Test

 import CUDA
 if CUDA.functional()
   using CUDA  # exports CuArray, etc
-  @info "starting CUDA tests"
 else
-  @info "CUDA not functional, testing via GPUArrays"
-  using GPUArrays
-  GPUArrays.allowscalar(false)
+  @info "CUDA not functional, testing with JLArrays instead"
+  using JLArrays
+  JLArrays.allowscalar(false)

-  # GPUArrays provides a fake GPU array, for testing
-  jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
-  using Random, Adapt  # loaded within jl_file
-  include(jl_file)
-  using .JLArrays
-  cu = jl
+  cu = jl32
   CuArray{T,N} = JLArray{T,N}
 end
-
 @test cu(rand(3)) .+ 1 isa CuArray

 @testset "very basics" begin
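The fallback branch above swaps the hand-rolled fake array, previously include-d from GPUArrays' own test suite, for the JLArrays package, which provides a plain-CPU array type that follows the GPUArrays interface. A minimal sketch of why that is enough for these tests, outside the diff and assuming only the `JLArray` type and the `allowscalar` switch already used above:

using JLArrays
JLArrays.allowscalar(false)     # forbid scalar indexing, as on a real GPU

x = JLArray(rand(Float32, 3))   # CPU-backed stand-in for a CuArray
y = x .+ 1                      # broadcasting stays on the fake device type
@assert y isa JLArray
# x[1] would now throw, because scalar indexing is disallowed

The context line `CuArray{T,N} = JLArray{T,N}` then aliases the two types, so the rest of the file can keep writing `isa CuArray` regardless of which branch was taken.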
@@ -89,6 +82,11 @@
   v, re = destructure(m)
   @test v isa CuArray
   @test re(2v).x isa CuArray
+
+  dm = gradient(m -> sum(abs2, destructure(m)[1]), m)[1]
+  @test dm.z isa CuArray
+  dv = gradient(v -> sum(abs2, re(v).z), cu([10, 20, 30, 40, 50.0]))[1]
+  @test dv isa CuArray
 end

 @testset "destructure mixed" begin
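The lines added in the second hunk check that gradients flow through `destructure` and the `re` closure it returns, not just the forward reconstruction. A minimal CPU sketch of the same pattern, outside the diff, using an illustrative named-tuple model `m` (the real test's model, with `x` and `z` fields, is defined earlier in the file and not shown here):

using Optimisers, Zygote

m = (x = [1.0, 2.0], z = [3.0, 4.0])   # illustrative stand-in for the test's model
v, re = destructure(m)                 # v == [1.0, 2.0, 3.0, 4.0]; re rebuilds m from a flat vector

# Gradient with respect to the structured model, through the flattening:
dm = gradient(m -> sum(abs2, destructure(m)[1]), m)[1]
# dm == (x = [2.0, 4.0], z = [6.0, 8.0])

# Gradient with respect to the flat vector, through the reconstruction:
dv = gradient(v -> sum(abs2, re(v).z), v)[1]
# dv == [0.0, 0.0, 6.0, 8.0]

The new tests run exactly this pattern on the GPU (or JLArray) side and assert that both gradients come back as `CuArray`s rather than falling back to plain `Array`s.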