1448: Arbitrary dimension one-hot arrays r=DhairyaLGandhi a=darsnack
This supersedes #1447. It should address the same issues:
- Fixes #1445 and #1229.
- Probably also fixes #864, #556, and #189.
This PR introduces a new N-dimensional one-hot array type, `OneHotArray`. Like #1447, this approach avoids the pointer allocations associated with `OneHotMatrix` being an array of `OneHotVector`s, and it lifts the "height" into a type parameter to avoid unnecessary allocation. Unlike #1447, this approach does not introduce a new primitive type. Instead, a "one-hot vector" is represented by a single index whose type is a user-configurable subtype of `Integer`; by default, the exposed API uses `UInt32`.
Fundamentally, #1447's primitive type was necessary because wrapping a `UInt32` in a `OneHotVector` struct incurs memory penalties when you create an `Array{<:OneHotVector}`. But if we design for N dimensions from the start, then `OneHotVector` is just the specialized 1D case (similar to how `Vector{T} = Array{T, 1}`).
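To make the layout concrete, here is a minimal sketch of the shape such a type could take. The names and type parameters are illustrative assumptions, not the PR's exact definition (in particular, the index type is hardcoded to `UInt32` here for brevity, whereas the PR makes it configurable):
```jl
# Illustrative sketch only, not the exact code in this PR.
# `L` is the one-hot "height", lifted into the type; `indices` stores the
# position of the hot entry along the first dimension; and `M = N + 1` is
# carried as an explicit parameter, since Julia cannot compute it in the
# type signature.
struct OneHotArray{L, N, M, I <: Union{UInt32, AbstractArray{UInt32, N}}} <: AbstractArray{Bool, M}
    indices::I
end

Base.size(x::OneHotArray{L}) where {L} = (L, size(x.indices)...)
Base.getindex(x::OneHotArray, i::Integer, I::Integer...) = x.indices[I...] == i

# The 1D case falls out for free, like `Vector{T} = Array{T, 1}`:
const OneHotVector{L} = OneHotArray{L, 0, 1, UInt32}
```
With a layout like this, a one-hot batch of 50 samples over 100 classes stores just 50 integers plus a type-level height, rather than 50 heap-allocated vectors.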
## Performance
I compared against the same tests mentioned in #1447; please suggest more if you have them. (Sketches of the methods being exercised follow the benchmarks below.)
1. #189 (`W * x` with a one-hot batch)
```jl
# master
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);
julia> W = rand(128, 100);
julia> @btime $W * $x;
5.095 μs (13 allocations: 50.86 KiB)
julia> cW, cx = cu(W), cu(x);
julia> @btime $cW * $cx;
24.948 μs (86 allocations: 3.11 KiB)

# 1447
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);
julia> W = rand(128, 100);
julia> @btime $W * $x;
5.312 μs (3 allocations: 50.36 KiB)
julia> cW, cx = cu(W), cu(x);
julia> @btime $cW * $cx;
8.466 μs (61 allocations: 1.69 KiB)

# this PR
julia> x = Flux.onehotbatch(rand(1:100, 50), 1:100);
julia> W = rand(128, 100);
julia> @btime $W * $x;
4.708 μs (3 allocations: 50.56 KiB)
julia> cW, cx = cu(W), cu(x);
julia> @btime $cW * $cx;
8.576 μs (63 allocations: 1.73 KiB)
```
2. #556 (`Flux.onecold`)
```jl
# master
julia> valY = randn(1000, 128);
julia> @btime Flux.onecold($valY);
365.712 μs (1131 allocations: 38.16 KiB)
julia> @btime Flux.onecold($(gpu(valY)));
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/jhRU7/src/host/indexing.jl:43
1.330 s (781248 allocations: 31.59 MiB)

# 1447
julia> valY = randn(1000, 128);
julia> @btime Flux.onecold($valY);
524.767 μs (8 allocations: 4.00 KiB)
julia> @btime Flux.onecold($(gpu(valY)));
27.563 μs (169 allocations: 5.56 KiB)

# this PR
julia> valY = randn(1000, 128);
julia> @btime Flux.onecold($valY);
493.017 μs (8 allocations: 4.53 KiB)
julia> @btime Flux.onecold($(gpu(valY)));
26.702 μs (171 allocations: 5.61 KiB)
```
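For intuition about where the multiplication numbers come from: multiplying by a one-hot matrix is just column selection. A hedged sketch, continuing the illustrative type above (the PR's actual method may differ):
```jl
# Illustrative only: selecting columns of W avoids ever materializing a
# Bool matrix on either CPU or GPU. Assumes size(W, 2) == L.
Base.:*(W::AbstractMatrix, x::OneHotArray{L, 1}) where {L} = W[:, x.indices]
```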
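Likewise, `onecold` is cheap because a one-hot array can simply read back its stored indices, and the dense fallback is a single `argmax` per column. Again a sketch under the same illustrative type (the PR's GPU path presumably uses a parallel reduction rather than this scalar comprehension):
```jl
# For a one-hot array, onecold just maps the stored indices to labels.
onecold(x::OneHotArray, labels = 1:size(x, 1)) = labels[x.indices]

# For a dense array, it reduces to an argmax over the first dimension.
onecold(y::AbstractMatrix, labels = 1:size(y, 1)) =
    [labels[argmax(view(y, :, j))] for j in axes(y, 2)]
```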
## Summary
This should be essentially #1447, but simpler to maintain, with fewer changes. Tests are passing, though I think we should add more tests for one-hot data (our current test set seems pretty sparse). Performance matches #1447 in the cases I tested, but please suggest more performance tests. In theory, any performance difference between #1447 and this PR should be recoverable.
### PR Checklist
- [ ] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from @DhairyaLGandhi (for API changes).
cc @CarloLucibello @chengchingwen
Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>