
Commit ebd37d6

bors[bot] and darsnack authored

Merge #1448
1448: Arbitrary dimension one-hot arrays r=DhairyaLGandhi a=darsnack

This supersedes #1447. It should address the same issues:

- fix #1445, #1229
- probably also fix #864, #556, #189

This PR introduces a new one-hot N-dimensional array type, `OneHotArray`. Like #1447, this approach avoids the pointer allocations associated with `OneHotMatrix` being an array of `OneHotVector`s. It also lifts the "height" into the type parameter to avoid unnecessary allocation. Unlike #1447, this approach does not introduce a new primitive type. Instead, a "one-hot vector" is represented with a single subtype of `Integer` that is configurable by the user; by default, the exposed API uses `UInt32`. Fundamentally, the primitive type is only necessary because wrapping a `UInt32` as a `OneHotVector` incurs memory penalties when you create an `Array{<:OneHotVector}`. But if we begin by designing for N dimensions, then `OneHotVector` is just the specialized 1D case (similar to how `Vector{T} = Array{T, 1}`).

## Performance

I compared against the same tests mentioned in #1447. Please suggest more if you want to.

1. #189

```jl
# master
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  5.095 μs (13 allocations: 50.86 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  24.948 μs (86 allocations: 3.11 KiB)

# 1447
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  5.312 μs (3 allocations: 50.36 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  8.466 μs (61 allocations: 1.69 KiB)

# this PR
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);

julia> W = rand(128, 100);

julia> @btime $W * $x;
  4.708 μs (3 allocations: 50.56 KiB)

julia> cW, cx = cu(W), cu(x);

julia> @btime $cW * $cx;
  8.576 μs (63 allocations: 1.73 KiB)
```

2. #556

```jl
# master
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  365.712 μs (1131 allocations: 38.16 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:43
  1.330 s (781248 allocations: 31.59 MiB)

# 1447
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  524.767 μs (8 allocations: 4.00 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
  27.563 μs (169 allocations: 5.56 KiB)

# this PR
julia> valY = randn(1000, 128);

julia> @btime Flux.onecold($valY);
  493.017 μs (8 allocations: 4.53 KiB)

julia> @btime Flux.onecold($(gpu(valY)));
  26.702 μs (171 allocations: 5.61 KiB)
```

## Summary

This should basically be #1447, but simpler to maintain with fewer changes. Tests are passing, though I think we should add more tests for one-hot data (currently our test set seems pretty sparse). Performance matches #1447 where I have tested, but please suggest more performance tests. In theory, any performance difference between #1447 and this PR should be recoverable.

### PR Checklist

- [ ] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from @DhairyaLGandhi (for API changes)

cc @CarloLucibello @chengchingwen

Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
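To make the PR's core idea concrete, here is a minimal, self-contained sketch (the `MiniOneHot` name and its methods are hypothetical, invented for illustration; the PR's actual `OneHotArray` definition is in the `src/onehot.jl` diff below). The point is that the one-hot "height" `L` lives in the type, so each one-hot column costs exactly one stored integer:

```jl
# Hypothetical sketch of the design in this PR, not Flux code.
struct MiniOneHot{L, N, I} <: AbstractArray{Bool, N}
    indices::I  # one Integer (vector case) or an Integer array (batched case)
end

MiniOneHot(ix::Integer, L) = MiniOneHot{L, 1, typeof(ix)}(ix)
MiniOneHot(ix::AbstractArray{<:Integer, M}, L) where {M} = MiniOneHot{L, M + 1, typeof(ix)}(ix)

# The height L is a type parameter, so size() needs no stored field for it.
Base.size(x::MiniOneHot{L}) where {L} = (L, size(x.indices)...)
Base.getindex(x::MiniOneHot, i::Integer, I::Integer...) = x.indices[I...] == i

v = MiniOneHot(2, 3)          # like onehot: Bool[0, 1, 0]
m = MiniOneHot([2, 1, 2], 3)  # like onehotbatch: one integer per column
@assert size(m) == (3, 3) && m[2, 1] && !m[1, 1]
```

Because a one-hot vector is just the scalar-indices specialization, no separate primitive type is needed, mirroring how `Vector{T} = Array{T, 1}` in Base.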
2 parents ef7aba3 + aa63a5a commit ebd37d6

File tree

4 files changed: +144, -53 lines

- docs/src/data/onehot.md
- src/onehot.jl
- test/cuda/cuda.jl
- test/onehot.jl

docs/src/data/onehot.md

Lines changed: 3 additions & 3 deletions

```diff
@@ -6,13 +6,13 @@ It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`
 julia> using Flux: onehot, onecold
 
 julia> onehot(:b, [:a, :b, :c])
-3-element Flux.OneHotVector:
+3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
 0
 1
 0
 
 julia> onehot(:c, [:a, :b, :c])
-3-element Flux.OneHotVector:
+3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
 0
 0
 1
@@ -44,7 +44,7 @@ Flux.onecold
 julia> using Flux: onehotbatch
 
 julia> onehotbatch([:b, :a, :b], [:a, :b, :c])
-3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
+3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}:
 0 1 0
 1 0 1
 0 0 0
```
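As a complement to the doctests above, `onecold` maps a one-hot (or probability) vector back to a label; the second call below is taken from the `onecold` docstring in the `src/onehot.jl` diff further down, and the round trip through `onehot` is a small illustration under the same API:

```jl
julia> using Flux: onehot, onecold

julia> onecold(onehot(:b, [:a, :b, :c]), [:a, :b, :c])
:b

julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```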

src/onehot.jl

Lines changed: 62 additions & 47 deletions

````diff
@@ -1,52 +1,61 @@
-import Base: *
+import Adapt
+import .CUDA
 
-struct OneHotVector <: AbstractVector{Bool}
-  ix::UInt32
-  of::UInt32
+struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} <: AbstractArray{Bool, var"N+1"}
+  indices::I
 end
+OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
+OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices)
+OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)
 
-Base.size(xs::OneHotVector) = (Int64(xs.of),)
+_indices(x::OneHotArray) = x.indices
 
-Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
+const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
+const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
 
-Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)
+OneHotVector(idx, L) = OneHotArray(idx, L)
+OneHotMatrix(indices, L) = OneHotArray(indices, L)
 
-function Base.:*(A::AbstractMatrix, b::OneHotVector)
-  if size(A, 2) != b.of
-    throw(DimensionMismatch("Matrix column must correspond with OneHotVector size"))
-  end
-  return A[:, b.ix]
-end
+Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
 
-struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
-  height::Int
-  data::A
-end
+_onehotindex(x, i) = (x == i)
 
-Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
+Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
+Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
 
-Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
-Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
-Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
-Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
+Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
+Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
+Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
+Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
 
-Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
+_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
+_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
+
+function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
+  if isone(dims)
+    return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first."))
+  else
+    return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
+  end
+end
 
-# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed
-A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))]
+Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
+Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)
 
-Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
+Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
+  (first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
+                       throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
+Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
 
-batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
+batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
 
-import Adapt: adapt, adapt_structure
+Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L)
 
-adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
+Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
 
-import .CUDA: CuArray, CuArrayStyle, cudaconvert
-import Base.Broadcast: BroadcastStyle, ArrayStyle
-BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = CuArrayStyle{2}()
-cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
+Base.argmax(x::OneHotArray; dims = Colon()) =
+  (dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
+                argmax(convert(_onehot_bool_type(x), x); dims = dims)
 
 """
     onehot(l, labels[, unk])
@@ -60,13 +69,13 @@ If `l` is not found in labels and `unk` is present, the function returns
 # Examples
 ```jldoctest
 julia> Flux.onehot(:b, [:a, :b, :c])
-3-element Flux.OneHotVector:
+3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
 0
 1
 0
 
 julia> Flux.onehot(:c, [:a, :b, :c])
-3-element Flux.OneHotVector:
+3-element Flux.OneHotArray{UInt32,3,0,1,UInt32}:
 0
 0
 1
@@ -75,13 +84,13 @@ julia> Flux.onehot(:c, [:a, :b, :c])
 function onehot(l, labels)
   i = something(findfirst(isequal(l), labels), 0)
   i > 0 || error("Value $l is not in labels")
-  OneHotVector(i, length(labels))
+  OneHotVector{UInt32, length(labels)}(i)
 end
 
 function onehot(l, labels, unk)
   i = something(findfirst(isequal(l), labels), 0)
   i > 0 || return onehot(unk, labels)
-  OneHotVector(i, length(labels))
+  OneHotVector{UInt32, length(labels)}(i)
 end
 
 """
@@ -95,16 +104,13 @@ return [`onehot(unk, labels)`](@ref) ; otherwise the function will raise an erro
 # Examples
 ```jldoctest
 julia> Flux.onehotbatch([:b, :a, :b], [:a, :b, :c])
-3×3 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
+3×3 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}:
 0 1 0
 1 0 1
 0 0 0
 ```
 """
-onehotbatch(ls, labels, unk...) =
-  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
-
-Base.argmax(xs::OneHotVector) = xs.ix
+onehotbatch(ls, labels, unk...) = batch([onehot(l, labels, unk...) for l in ls])
 
 """
     onecold(y[, labels = 1:length(y)])
@@ -120,11 +126,20 @@ julia> Flux.onecold([0.3, 0.2, 0.5], [:a, :b, :c])
 :c
 ```
 """
-onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
+onecold(y::AbstractVector, labels = 1:length(y)) = labels[argmax(y)]
+function onecold(y::AbstractArray, labels = 1:size(y, 1))
+  indices = _fast_argmax(y)
+  xs = isbits(labels) ? indices : collect(indices) # non-bit type cannot be handled by CUDA
 
-onecold(y::AbstractMatrix, labels...) =
-  dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
+  return map(xi -> labels[xi[1]], xs)
+end
 
-onecold(y::OneHotMatrix, labels...) = map(x -> Flux.onecold(x, labels...), y.data)
+_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
+_fast_argmax(x::OneHotArray) = x.indices
 
-@nograd onecold, onehot, onehotbatch
+@nograd OneHotArray, onecold, onehot, onehotbatch
+
+function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
+  size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
+  return A[:, onecold(B)]
+end
````
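A usage note on the `Base.:(*)` overload at the end of this diff: multiplying a weight matrix by a one-hot batch lowers to plain column indexing via `onecold`, which is where the benchmark speedups in the PR description come from. A short sketch, assuming this PR's branch of Flux:

```jl
using Flux

W = rand(128, 100)
x = Flux.onehotbatch(rand(1:100, 50), 1:100)  # 100×50 one-hot batch

y = W * x                            # dispatches to the overload above
@assert y == W[:, Flux.onecold(x)]   # i.e. selecting one column of W per sample
@assert size(y) == (128, 50)
```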

test/cuda/cuda.jl

Lines changed: 3 additions & 1 deletion

```diff
@@ -13,7 +13,7 @@ using LinearAlgebra: I, cholesky, Cholesky
 
 x = Flux.onehotbatch([1, 2, 3], 1:3)
 cx = gpu(x)
-@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
+@test cx isa Flux.OneHotMatrix && cx.indices isa CuArray
 @test (cx .+ 1) isa CuArray
 
 m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
@@ -40,8 +40,10 @@ end
 
 @testset "onecold gpu" begin
   y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
+  l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
   @test Flux.onecold(y) isa CuArray
   @test y[3,:] isa CuArray
+  @test Flux.onecold(y, l) == ['a', 'a', 'a']
 end
 
 @testset "restructure gpu" begin
```

test/onehot.jl

Lines changed: 76 additions & 2 deletions

```diff
@@ -27,9 +27,83 @@ end
 
 @testset "abstractmatrix onehotvector multiplication" begin
   A = [1 3 5; 2 4 6; 3 6 9]
-  b1 = Flux.OneHotVector(1,3)
-  b2 = Flux.OneHotVector(3,5)
+  b1 = Flux.OneHotVector(1, 3)
+  b2 = Flux.OneHotVector(3, 5)
 
   @test A*b1 == A[:,1]
   @test_throws DimensionMismatch A*b2
+end
+
+@testset "OneHotArray" begin
+  using Flux: OneHotArray, OneHotVector, OneHotMatrix
+
+  ov = OneHotVector(rand(1:10), 10)
+  om = OneHotMatrix(rand(1:10, 5), 10)
+  oa = OneHotArray(rand(1:10, 5, 5), 10)
+
+  # sizes
+  @testset "Base.size" begin
+    @test size(ov) == (10,)
+    @test size(om) == (10, 5)
+    @test size(oa) == (10, 5, 5)
+  end
+
+  @testset "Indexing" begin
+    # vector indexing
+    @test ov[3] == (ov.indices == 3)
+    @test ov[:] == ov
+
+    # matrix indexing
+    @test om[3, 3] == (om.indices[3] == 3)
+    @test om[:, 3] == OneHotVector(om.indices[3], 10)
+    @test om[3, :] == (om.indices .== 3)
+    @test om[:, :] == om
+
+    # array indexing
+    @test oa[3, 3, 3] == (oa.indices[3, 3] == 3)
+    @test oa[:, 3, 3] == OneHotVector(oa.indices[3, 3], 10)
+    @test oa[3, :, 3] == (oa.indices[:, 3] .== 3)
+    @test oa[3, :, :] == (oa.indices .== 3)
+    @test oa[:, 3, :] == OneHotMatrix(oa.indices[3, :], 10)
+    @test oa[:, :, :] == oa
+
+    # cartesian indexing
+    @test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3]
+  end
+
+  @testset "Concatenating" begin
+    # vector cat
+    @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
+    @test_throws ArgumentError vcat(ov, ov)
+    @test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)
+
+    # matrix cat
+    @test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
+    @test_throws ArgumentError vcat(om, om)
+    @test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)
+
+    # array cat
+    @test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
+    @test_throws ArgumentError cat(oa, oa; dims = 1)
+  end
+
+  @testset "Base.reshape" begin
+    # reshape test
+    @test reshape(oa, 10, 25) isa OneHotArray
+    @test reshape(oa, 10, :) isa OneHotArray
+    @test reshape(oa, :, 25) isa OneHotArray
+    @test_throws ArgumentError reshape(oa, 50, :)
+    @test_throws ArgumentError reshape(oa, 5, 10, 5)
+    @test reshape(oa, (10, 25)) isa OneHotArray
+  end
+
+  @testset "Base.argmax" begin
+    # argmax test
+    @test argmax(ov) == argmax(convert(Array{Bool}, ov))
+    @test argmax(om) == argmax(convert(Array{Bool}, om))
+    @test argmax(om; dims = 1) == argmax(convert(Array{Bool}, om); dims = 1)
+    @test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2)
+    @test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1)
+    @test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3)
+  end
 end
```
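A side note on the `@test_throws ArgumentError vcat(...)` cases above: concatenating along the first (one-hot) dimension is rejected because the result could hold more than one `true` per column, so it cannot be represented by indices alone; the error message suggests collecting to a `Bool` array first. A minimal sketch of that workaround, assuming this PR's branch of Flux:

```jl
using Flux: OneHotVector

ov = OneHotVector(2, 3)               # Bool[0, 1, 0]
# vcat(ov, ov)                        # would throw ArgumentError on this branch
v = vcat(collect(ov), collect(ov))    # collect to Vector{Bool} first
@assert v == [false, true, false, false, true, false]
```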
