Skip to content

Commit aa63a5a

Browse files
committed
Make constructors backwards compatible and throw error on vcat.
1 parent ef1ecb3 commit aa63a5a

File tree

2 files changed

+26
-26
lines changed

2 files changed

+26
-26
lines changed

src/onehot.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@ struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}}
55
indices::I
66
end
77
OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices)
8-
OneHotArray(L::Integer, indices::T) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices)
9-
OneHotArray(L::Integer, indices::AbstractArray{T, N}) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)
8+
OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices)
9+
OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)
1010

1111
_indices(x::OneHotArray) = x.indices
1212

1313
const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
1414
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}
1515

16-
OneHotVector(L, idx) = OneHotArray(L, idx)
17-
OneHotMatrix(L, indices) = OneHotArray(L, indices)
16+
OneHotVector(idx, L) = OneHotArray(idx, L)
17+
OneHotMatrix(indices, L) = OneHotArray(indices, L)
1818

1919
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)
2020

@@ -24,7 +24,7 @@ Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
2424
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x
2525

2626
Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
27-
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(L, x.indices[I...])
27+
Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
2828
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
2929
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]
3030

@@ -33,23 +33,23 @@ _onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = C
3333

3434
function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
3535
if isone(dims)
36-
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = 1)
36+
return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first."))
3737
else
38-
return OneHotArray(L, cat(_indices.(xs)...; dims = dims - 1))
38+
return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
3939
end
4040
end
4141

4242
Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
4343
Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)
4444

4545
Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
46-
(first(dims) == L) ? OneHotArray(L, reshape(x.indices, dims[2:end]...)) :
46+
(first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
4747
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
4848
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
4949

50-
batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(L, _indices.(xs))
50+
batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)
5151

52-
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(L, adapt(T, x.indices))
52+
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L)
5353

5454
Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
5555

test/onehot.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ end
2727

2828
@testset "abstractmatrix onehotvector multiplication" begin
2929
A = [1 3 5; 2 4 6; 3 6 9]
30-
b1 = Flux.OneHotVector{eltype(A), 3}(1)
31-
b2 = Flux.OneHotVector{eltype(A), 5}(3)
30+
b1 = Flux.OneHotVector(1, 3)
31+
b2 = Flux.OneHotVector(3, 5)
3232

3333
@test A*b1 == A[:,1]
3434
@test_throws DimensionMismatch A*b2
@@ -37,9 +37,9 @@ end
3737
@testset "OneHotArray" begin
3838
using Flux: OneHotArray, OneHotVector, OneHotMatrix
3939

40-
ov = OneHotVector(10, rand(1:10))
41-
om = OneHotMatrix(10, rand(1:10, 5))
42-
oa = OneHotArray(10, rand(1:10, 5, 5))
40+
ov = OneHotVector(rand(1:10), 10)
41+
om = OneHotMatrix(rand(1:10, 5), 10)
42+
oa = OneHotArray(rand(1:10, 5, 5), 10)
4343

4444
# sizes
4545
@testset "Base.size" begin
@@ -55,16 +55,16 @@ end
5555

5656
# matrix indexing
5757
@test om[3, 3] == (om.indices[3] == 3)
58-
@test om[:, 3] == OneHotVector(10, om.indices[3])
58+
@test om[:, 3] == OneHotVector(om.indices[3], 10)
5959
@test om[3, :] == (om.indices .== 3)
6060
@test om[:, :] == om
6161

6262
# array indexing
6363
@test oa[3, 3, 3] == (oa.indices[3, 3] == 3)
64-
@test oa[:, 3, 3] == OneHotVector(10, oa.indices[3, 3])
64+
@test oa[:, 3, 3] == OneHotVector(oa.indices[3, 3], 10)
6565
@test oa[3, :, 3] == (oa.indices[:, 3] .== 3)
6666
@test oa[3, :, :] == (oa.indices .== 3)
67-
@test oa[:, 3, :] == OneHotMatrix(10, oa.indices[3, :])
67+
@test oa[:, 3, :] == OneHotMatrix(oa.indices[3, :], 10)
6868
@test oa[:, :, :] == oa
6969

7070
# cartesian indexing
@@ -73,18 +73,18 @@ end
7373

7474
@testset "Concatenating" begin
7575
# vector cat
76-
@test hcat(ov, ov) == OneHotMatrix(10, vcat(ov.indices, ov.indices))
77-
@test vcat(ov, ov) == vcat(convert(Array{Bool}, ov), convert(Array{Bool}, ov))
78-
@test cat(ov, ov; dims = 3) == OneHotArray(10, cat(ov.indices, ov.indices; dims = 2))
76+
@test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
77+
@test_throws ArgumentError vcat(ov, ov)
78+
@test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)
7979

8080
# matrix cat
81-
@test hcat(om, om) == OneHotMatrix(10, vcat(om.indices, om.indices))
82-
@test vcat(om, om) == vcat(convert(Array{Bool}, om), convert(Array{Bool}, om))
83-
@test cat(om, om; dims = 3) == OneHotArray(10, cat(om.indices, om.indices; dims = 2))
81+
@test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
82+
@test_throws ArgumentError vcat(om, om)
83+
@test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)
8484

8585
# array cat
86-
@test cat(oa, oa; dims = 3) == OneHotArray(10, cat(oa.indices, oa.indices; dims = 2))
87-
@test cat(oa, oa; dims = 1) == cat(convert(Array{Bool}, oa), convert(Array{Bool}, oa); dims = 1)
86+
@test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
87+
@test_throws ArgumentError cat(oa, oa; dims = 1)
8888
end
8989

9090
@testset "Base.reshape" begin

0 commit comments

Comments
 (0)