@@ -5,16 +5,16 @@ struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}}
5
5
indices:: I
6
6
end
7
7
OneHotArray {T, L, N, I} (indices) where {T, L, N, I} = OneHotArray {T, L, N, N+1, I} (indices)
8
- OneHotArray (L :: Integer , indices :: T ) where {T<: Integer } = OneHotArray {T, L, 0, T} (indices)
9
- OneHotArray (L :: Integer , indices:: AbstractArray{T, N} ) where {T, N} = OneHotArray {T, L, N, typeof(indices)} (indices)
8
+ OneHotArray (indices :: T , L :: Integer ) where {T<: Integer } = OneHotArray {T, L, 0, T} (indices)
9
+ OneHotArray (indices:: AbstractArray{T, N} , L :: Integer ) where {T, N} = OneHotArray {T, L, N, typeof(indices)} (indices)
10
10
11
11
_indices (x:: OneHotArray ) = x. indices
12
12
13
13
const OneHotVector{T, L} = OneHotArray{T, L, 0 , 1 , T}
14
14
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1 , 2 , I}
15
15
16
- OneHotVector (L, idx ) = OneHotArray (L, idx )
17
- OneHotMatrix (L, indices ) = OneHotArray (L, indices )
16
+ OneHotVector (idx, L ) = OneHotArray (idx, L )
17
+ OneHotMatrix (indices, L ) = OneHotArray (indices, L )
18
18
19
19
Base. size (x:: OneHotArray{<:Any, L} ) where L = (Int (L), size (x. indices)... )
20
20
@@ -24,7 +24,7 @@ Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
24
24
Base. getindex (x:: OneHotVector{T, L} , :: Colon ) where {T, L} = x
25
25
26
26
Base. getindex (x:: OneHotArray , i:: Integer , I... ) = _onehotindex .(x. indices[I... ], i)
27
- Base. getindex (x:: OneHotArray{<:Any, L} , :: Colon , I... ) where L = OneHotArray (L, x. indices[I... ])
27
+ Base. getindex (x:: OneHotArray{<:Any, L} , :: Colon , I... ) where L = OneHotArray (x. indices[I... ], L )
28
28
Base. getindex (x:: OneHotArray{<:Any, <:Any, <:Any, N} , :: Vararg{Colon, N} ) where N = x
29
29
Base. getindex (x:: OneHotArray , I:: CartesianIndex{N} ) where N = x[I[1 ], Tuple (I)[2 : N]. .. ]
30
30
@@ -33,23 +33,23 @@ _onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = C
33
33
34
34
function Base. cat (xs:: OneHotArray{<:Any, L} ...; dims:: Int ) where L
35
35
if isone (dims)
36
- return cat ( map (x -> convert ( _onehot_bool_type (x), x), xs) ... ; dims = 1 )
36
+ return throw ( ArgumentError ( " Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first. " ) )
37
37
else
38
- return OneHotArray (L, cat (_indices .(xs)... ; dims = dims - 1 ))
38
+ return OneHotArray (cat (_indices .(xs)... ; dims = dims - 1 ), L )
39
39
end
40
40
end
41
41
42
42
Base. hcat (xs:: OneHotArray... ) = cat (xs... ; dims = 2 )
43
43
Base. vcat (xs:: OneHotArray... ) = cat (xs... ; dims = 1 )
44
44
45
45
Base. reshape (x:: OneHotArray{<:Any, L} , dims:: Dims ) where L =
46
- (first (dims) == L) ? OneHotArray (L, reshape (x. indices, dims[2 : end ]. .. )) :
46
+ (first (dims) == L) ? OneHotArray (reshape (x. indices, dims[2 : end ]. .. ), L ) :
47
47
throw (ArgumentError (" Cannot reshape OneHotArray if first(dims) != size(x, 1)" ))
48
48
Base. _reshape (x:: OneHotArray , dims:: Tuple{Vararg{Int}} ) = reshape (x, dims)
49
49
50
- batch (xs:: AbstractArray{<:OneHotVector{<:Any, L}} ) where L = OneHotArray (L, _indices .(xs))
50
+ batch (xs:: AbstractArray{<:OneHotVector{<:Any, L}} ) where L = OneHotArray (_indices .(xs), L )
51
51
52
- Adapt. adapt_structure (T, x:: OneHotArray{<:Any, L} ) where L = OneHotArray (L, adapt (T, x. indices))
52
+ Adapt. adapt_structure (T, x:: OneHotArray{<:Any, L} ) where L = OneHotArray (adapt (T, x. indices), L )
53
53
54
54
Base. BroadcastStyle (:: Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}} ) where N = CUDA. CuArrayStyle {N} ()
55
55
0 commit comments