Skip to content

Commit 674a8fd

Browse files
committed
start on CategoricalArrays extension
1 parent 4e2d74a commit 674a8fd

File tree

4 files changed

+44
-2
lines changed

4 files changed

+44
-2
lines changed

Project.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,22 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212

13+
[weakdeps]
14+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
15+
16+
[extensions]
17+
OneHotArraysCategoricalArraysExt = "CategoricalArrays"
18+
1319
[compat]
1420
Adapt = "3.0, 4"
1521
CUDA = "4, 5"
22+
CategoricalArrays = "0.10.8"
1623
ChainRulesCore = "1.13"
1724
Compat = "4.2"
1825
GPUArraysCore = "0.1, 0.2"
1926
NNlib = "0.8, 0.9"
2027
Zygote = "0.6.35"
21-
julia = "1.6"
28+
julia = "1.10"
2229

2330
[extras]
2431
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -28,4 +35,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2835
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2936

3037
[targets]
31-
test = ["Test", "CUDA", "JLArrays", "Random", "Zygote"]
38+
test = ["Test", "CategoricalArrays", "CUDA", "JLArrays", "Random", "Zygote"]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module OneHotArraysCategoricalArraysExt
2+
3+
println("loading?")
4+
5+
using OneHotArrays, CategoricalArrays
6+
7+
OneHotArrays.OneHotArray(cv::CategoricalValue) = OneHotVector(cv.ref, length(cv.pool.levels))
8+
9+
OneHotArrays.OneHotArray(ca::CategoricalArray) = OneHotArray(ca.refs, length(ca.pool))
10+
11+
end # module

test/ext_categorical.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Test, OneHotArrays, CategoricalArrays
2+
3+
@testset "CategoricalArrays -> OneHotArrays" begin
4+
cval = CategoricalArrays.CategoricalValue('b', CategoricalArray('a':'z'))
5+
6+
@test OneHotArray(cval) isa OneHotVector
7+
@test OneHotArray(cval) == (('a':'z') .== 'b')
8+
9+
@test_broken OneHotVector(cval) isa OneHotVector # surely if OneHotArray works, subtypes should too
10+
@test_broken convert(OneHotArray, cval) isa OneHotVector
11+
@test_broken onehot(cval) isa OneHotVector # possibly we should define this? Instead?
12+
13+
cvec = categorical(string.([:a, :b, :b, :c, :d, :e]))
14+
15+
@test OneHotArray(cvec) isa OneHotMatrix
16+
@test size(OneHotArray(cvec)) == (5, 6)
17+
@test onecold(OneHotArray(cvec)) == [1, 2, 2, 3, 4, 5]
18+
19+
@test_broken onehotbatch(cvec) isa OneHotMatrix # possibly we should define this? Instead?
20+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ end
1414
include("linalg.jl")
1515
end
1616

17+
@testset "Extensions" begin
18+
include("ext_categorical.jl")
19+
end
20+
1721
using Zygote
1822
import CUDA
1923
if CUDA.functional()

0 commit comments

Comments
 (0)