From 1686f1250835702a764cebf80ffccc8773b3c8bd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 24 Jul 2022 22:27:42 -0400 Subject: [PATCH 1/2] add AbstractArray construction methods --- src/onehot.jl | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 6231f3c..a2d75dd 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -83,20 +83,36 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl """ onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) -function _onehotbatch(data, labels) - indices = UInt32[something(_findval(i, labels), 0) for i in data] - if 0 in indices +function _onehotbatch(data, labels) # this accepts any iterator + indices = UInt32[something(_findval(x, labels), 0) for x in data] + if any(iszero, indices) for x in data - isnothing(_findval(x, labels)) && error("Value $x not found in labels") + isnothing(_findval(x, labels)) && throw(ArgumentError("Value x = $x not found in labels = $labels")) end end return OneHotArray(indices, length(labels)) end +function _onehotbatch(data::AbstractArray, labels) # this works for GPUArrays too + indices = similar(data, UInt32) + map!(x -> something(_findval(x, labels), 0), indices, data) + if any(iszero, indices) + badx = @allowscalar data[findfirst(iszero, indices)] + throw(ArgumentError("Value x = $badx not found in labels = $labels")) + end + return OneHotArray(indices, length(labels)) +end function _onehotbatch(data, labels, default) default_index = _findval(default, labels) - isnothing(default_index) && error("Default value $default is not in labels") - indices = UInt32[something(_findval(i, labels), default_index) for i in data] + isnothing(default_index) && throw(ArgumentError("Default value $default is not in labels = $labels")) + indices = UInt32[something(_findval(x, labels), default_index) for x in data] + return OneHotArray(indices, length(labels)) +end +function _onehotbatch(data::AbstractArray, labels, default) + default_index = _findval(default, labels) + isnothing(default_index) && throw(ArgumentError("Default value $default is not in labels = $labels")) + indices = similar(data, UInt32) + map!(x -> something(_findval(x, labels), default_index), indices, data) return OneHotArray(indices, length(labels)) end From f05824432ef4c807ee45645b4dc54842df1d810f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 24 Jul 2022 22:27:50 -0400 Subject: [PATCH 2/2] tests --- test/gpu.jl | 49 +++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/test/gpu.jl b/test/gpu.jl index 13c208c..1cb508a 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -1,29 +1,23 @@ -# Tests from Flux, probably not the optimal testset organisation! - -@testset "CUDA" begin - x = randn(5, 5) - cx = cu(x) - @test cx isa CuArray - - @test_skip onecold(cu([1.0, 2.0, 3.0])) == 3 # passes with CuArray with Julia 1.6, but fails with JLArray - - x = onehotbatch([1, 2, 3], 1:3) - cx = cu(x) - @test cx isa OneHotMatrix && cx.indices isa CuArray - @test (cx .+ 1) isa CuArray - +@testset "onehotbatch gpu" begin + # move to GPU after construction + x = onehotbatch([1, 2, 3, 2], 1:3) + @test cu(x) isa OneHotMatrix + @test cu(x).indices isa CuArray + + # broadcast style works: + @test (cu(x) .+ 1) isa CuArray xs = rand(5, 5) - ys = onehotbatch(1:5,1:5) + ys = onehotbatch(rand(1:5, 5), 1:5) @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys) -end -@testset "onehot gpu" begin - y = onehotbatch(ones(3), 1:2) |> cu; - @test (repr("text/plain", y); true) - - gA = rand(3, 2) |> cu; - @test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote? + # move to GPU before construction + z1 = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), (1.0, 2f0, 3)) + @test z1.indices isa CuArray + z2 = onehotbatch(cu([3f0, 1f0, 2f0, 2f0]), [1, 2], 2) # with default + @test z2.indices isa CuArray + @test_throws ArgumentError onehotbatch(cu([1, 2, 3]), [1, 2]) # friendly error, not scalar indexing + @test_throws ArgumentError onehotbatch(cu([1, 2, 3]), [1, 2], 5) end @testset "onecold gpu" begin @@ -32,6 +26,17 @@ end @test onecold(y) isa CuArray @test y[3,:] isa CuArray @test onecold(y, l) == ['a', 'a', 'a'] + + @test_skip onecold(cu([1.0, 2.0, 3.0])) == 3 # passes with CuArray with Julia 1.6, but fails with JLArray +end + +@testset "matrix multiplication gpu" begin + y = onehotbatch([1, 2, 1], [1, 2]) |> cu; + A = rand(3, 2) |> cu; + + @test_broken collect(A * y) ≈ collect(A) * collect(y) + + @test_broken gradient(A -> sum(abs, A * y), A)[1] isa CuArray # gather!(dst::JLArray, ...) fails end @testset "onehot forward map to broadcast" begin