diff --git a/src/onehot.jl b/src/onehot.jl index d2d5e9d..dc58bdd 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -48,7 +48,7 @@ end _findval(val, labels::Tuple{}, i::Integer) = nothing """ - onehotbatch(xs, labels, [default]) + onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1}) Returns a [`OneHotMatrix`](@ref) where `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot). This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the @@ -64,6 +64,8 @@ i.e. `result[:, k...] == onehot(xs[k...], labels)`. Note that `xs` can be any iterable, such as a string. And that using a tuple for `labels` will often speed up construction, certainly for less than 32 classes. +If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one. + # Examples ```jldoctest julia> oh = onehotbatch("abracadabra", 'a':'e', 'e') @@ -79,44 +81,73 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl 1 4 13 1 7 1 10 1 4 13 1 2 5 14 2 8 2 11 2 5 14 2 3 6 15 3 9 3 12 3 6 15 3 + +# One hot vectors on the second axis +julia> onehotbatch([0, 0, 7], 0:9; dims=Val(2)) +3×10 PermutedDimsArray(OneHotMatrix(::Vector{UInt32}), (2, 1)) with eltype Bool: + 1 0 0 0 0 0 0 0 0 0 + 1 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 1 0 0 ``` """ -onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...) -function _onehotbatch(data, labels) - indices = UInt32[something(_findval(i, labels), 0) for i in data] - if 0 in indices - for x in data - isnothing(_findval(x, labels)) && error("Value $x not found in labels") - end - end - return OneHotArray(indices, length(labels)) -end +# developer note: +# onehotbatch is intended as the api and includes bounds checks +# _onehotbatch is intended as the implementation which includes membership checks +# _onehotbatch_fast same as above but without membership checks which would be slow on GPU -function _onehotbatch(data, labels, default) - default_index = _findval(default, labels) - isnothing(default_index) && error("Default value $default is not in labels") - indices = UInt32[something(_findval(i, labels), default_index) for i in data] - return OneHotArray(indices, length(labels)) +function onehotbatch(data::String, labels, default...; dims::Val{D} = Val(1)) where D + _onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...) end -function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) +function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer}, default...; dims::Val{D} = Val(1)) where D lo, hi = extrema(data) lo < first(labels) && error("Value $lo not found in labels") hi > last(labels) && error("Value $hi not found in labels") offset = 1 - first(labels) indices = UInt32.(data .+ offset) - return OneHotArray(indices, length(labels)) + _onehotbatch(dims, indices, length(labels) < 32 ? Tuple(labels) : labels) end + # That bounds check with extrema synchronises on GPU, much slower than rest of the function, # hence add a special method, with a less helpful error message: -function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer}) +function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer}, default...; dims::Val{D} = Val(1)) where D offset = 1 - first(labels) indices = map(data) do datum i = UInt32(datum + offset) checkbounds(labels, i) i end + _onehotbatch_fast(dims, indices, length(labels) < 32 ? Tuple(labels) : labels) +end +# _onehotbatch_fast does not have the bounds checks in _onehotbatch which would slow down GPU, but allows permute +_onehotbatch_fast(dims::Val{D}, indices, labels) where D = _permute(dims, _onehotbatch_fast(Val(1), indices, labels)) +_onehotbatch_fast(::Val{1}, indices, labels) = OneHotArray(indices, length(labels)) + +_onehotbatch(dims::Val, data, labels, default...) = _permute(dims, _onehotbatch(Val(1), data, labels, default...)) + +_permute(::Val{2}, array::OneHotArray{<:Any, 1, 2}) = transpose(array) +function _permute(::Val{d}, array::OneHotArray{<:Any, N,M}) where {d, N, M} + perm = Tuple(ntuple(d -> d==D ? 1 : (d==1 ? D : d), M)) + # need to use obtuse PermutedDimsArray constructor in order to stabilise permuation types + iperm = invperm(perm) + PermutedDimsArray{eltype(out),M,(perm...,),(iperm...,),typeof(out)}(out) +end + +function _onehotbatch(::Val{1}, data, labels) + indices = UInt32[something(_findval(i, labels), 0) for i in data] + if 0 in indices + for x in data + isnothing(_findval(x, labels)) && error("Value $x not found in labels") + end + end + return OneHotArray(indices, length(labels)) +end + +function _onehotbatch(::Val{1}, data, labels, default) + default_index = _findval(default, labels) + isnothing(default_index) && error("Default value $default is not in labels") + indices = UInt32[something(_findval(i, labels), default_index) for i in data] return OneHotArray(indices, length(labels)) end diff --git a/test/onehot.jl b/test/onehot.jl index fffac19..27c142c 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -69,3 +69,17 @@ end @test y[:,1] isa OneHotVector @test y[:,:] isa OneHotMatrix end + +@testset "onehotbatch dims" begin + # basic tests + @test onehotbatch([20, 10], 10:10:30; dims=Val(2)) == Bool[0 1 0; 1 0 0] + @test onehotbatch([10, 20], [30, 40, 50], 30; dims=Val(2)) == Bool[1 0 0; 1 0 0] + # higher dimensions + @test size(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=Val(2))) == (3, 12, 4) # test shape + @test sum(onehotbatch(reshape(collect(1:12), 3, 4), 1:12; dims=Val(2)), dims=2)[:] == ones(12) # test onehot on the second dim + # works with strings + @test onehotbatch("ba", 'a':'c'; dims=Val(2)) == Bool[0 1 0; 1 0 0] + + @test @inferred(onehotbatch([20, 10], 10:10:30; dims=Val(2))) == Bool[0 1 0; 1 0 0] + @test @inferred(onehotbatch([40, 10], (10,20,30), 20; dims=Val(2))) == Bool[0 1 0; 1 0 0] +end \ No newline at end of file