|
1 | 1 | using Optimisers
|
2 |
| -using ChainRulesCore #, Functors, StaticArrays, Zygote |
3 |
| -using LinearAlgebra, Statistics, Test |
| 2 | +using ChainRulesCore, Zygote |
| 3 | +using Test |
4 | 4 |
|
# Pick the array backend used by the GPU test suite.
# On a machine with a working CUDA device we exercise real CuArrays;
# everywhere else we substitute JLArrays, a CPU-resident array type that
# imitates GPU semantics (in particular, it errors on scalar indexing).
import CUDA

if !CUDA.functional()
    @info "CUDA not functional, testing with JLArrays instead"
    using JLArrays
    JLArrays.allowscalar(false)  # forbid scalar indexing, matching real GPU behaviour

    # Mirror the CUDA.jl surface the tests rely on: `cu` converts to the
    # device array type, and `CuArray` names that type in `isa` checks.
    cu = jl
    CuArray{T,N} = JLArray{T,N}
else
    using CUDA  # brings CuArray, cu, etc. into scope
    CUDA.allowscalar(false)  # scalar indexing would hide performance bugs
end

# Smoke test: whichever backend was chosen, broadcasting stays on-device.
@test (cu(rand(3)) .+ 1) isa CuArray
|
24 | 18 |
|
25 | 19 | @testset "very basics" begin
|
|
89 | 83 | v, re = destructure(m)
|
90 | 84 | @test v isa CuArray
|
91 | 85 | @test re(2v).x isa CuArray
|
| 86 | + |
| 87 | + dm = gradient(m -> sum(abs2, destructure(m)[1]), m)[1] |
| 88 | + @test dm.z isa CuArray |
| 89 | + dv = gradient(v -> sum(abs2, re(v).z), cu([10, 20, 30, 40, 50.0]))[1] |
| 90 | + @test dv isa CuArray |
92 | 91 | end
|
93 | 92 |
|
94 | 93 | @testset "destructure mixed" begin
|
|
0 commit comments