Skip to content

Commit 1354395

Browse files
authored
Update ArrayInterface compatibility to v7 (#61)
1 parent b1e5e23 commit 1354395

8 files changed

+79
-72
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HybridArrays"
22
uuid = "1baab800-613f-4b0a-84e4-9cd3431bfbb9"
33
authors = ["Mateusz Baran <mateuszbaran89@gmail.com>"]
4-
version = "0.4.14"
4+
version = "0.4.15"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -20,7 +20,8 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
2020
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2222
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
23+
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
2324
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2425

2526
[targets]
26-
test = ["Test", "Random", "ArrayInterface", "EllipsisNotation", "ForwardDiff", "Static"]
27+
test = ["Test", "Random", "ArrayInterface", "EllipsisNotation", "ForwardDiff", "Static", "StaticArrayInterface"]

src/HybridArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ function __init__()
184184
@require ArrayInterface="4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" begin
185185
include("array_interface_compat.jl")
186186
end
187+
@require StaticArrayInterface="0d7ed370-da01-4f52-bd93-41d350b8b718" begin
188+
include("static_array_interface_compat.jl")
189+
end
187190
end
188191

189192
end # module

src/abstractarray.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ Base.dataids(sa::HybridArray) = Base.dataids(parent(sa))
55

66
@inline Base.elsize(sa::HybridArray) = Base.elsize(parent(sa))
77

8-
@inline size(sa::HybridArray{S,T,N,N}) where {S,T,N} = size(parent(sa))
8+
@inline size(sa::HybridArray{S,T,N,N}) where {S<:Tuple,T,N} = size(parent(sa))
99

1010
@inline length(sa::HybridArray) = length(parent(sa))
1111

12-
@inline strides(sa::HybridArray{S,T,N,N}) where {S,T,N} = strides(parent(sa))
12+
@inline strides(sa::HybridArray{S,T,N,N}) where {S<:Tuple,T,N} = strides(parent(sa))
1313

1414
@inline pointer(sa::HybridArray) = pointer(parent(sa))
1515

16-
@generated function _sized_abstract_array_axes(::Type{S}, ax::Tuple) where S<:Tuple
16+
@generated function _sized_abstract_array_axes(::Type{S}, ax::Tuple) where {S<:Tuple}
1717
exprs = Any[]
1818
map(enumerate(S.parameters)) do (i, si)
1919
if isa(si, Dynamic)
@@ -25,18 +25,18 @@ Base.dataids(sa::HybridArray) = Base.dataids(parent(sa))
2525
return Expr(:tuple, exprs...)
2626
end
2727

28-
function axes(sa::HybridArray{S}) where S
28+
function axes(sa::HybridArray{S}) where {S<:Tuple}
2929
ax = axes(parent(sa))
3030
return _sized_abstract_array_axes(S, ax)
3131
end
3232

3333

34-
function promote_rule(::Type{<:HybridArray{S,T,N,M,TDataA}}, ::Type{<:HybridArray{S,U,N,M,TDataB}}) where {S,T,U,N,M,TDataA,TDataB}
34+
function promote_rule(::Type{<:HybridArray{S,T,N,M,TDataA}}, ::Type{<:HybridArray{S,U,N,M,TDataB}}) where {S<:Tuple,T,U,N,M,TDataA,TDataB}
3535
TU = promote_type(T,U)
3636
HybridArray{S,TU,N,M,promote_type(TDataA, TDataB)::Type{<:AbstractArray{TU}}}
3737
end
3838

39-
@inline copy(a::HybridArray{S, T, N, M}) where {S, T, N, M} = begin
39+
@inline copy(a::HybridArray{S, T, N, M}) where {S<:Tuple, T, N, M} = begin
4040
parentcopy = copy(parent(a))
4141
HybridArray{S, T, N, M, typeof(parentcopy)}(parentcopy)
4242
end

src/array_interface_compat.jl

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,49 +14,3 @@ end
1414
function ArrayInterface.restructure(x::HybridArray{S}, y) where {S}
1515
return HybridArray{S}(reshape(convert(Array, y), size(x)...))
1616
end
17-
18-
function ArrayInterface.strides(x::HybridArray)
19-
return ArrayInterface.strides(parent(x))
20-
end
21-
22-
@generated function ArrayInterface.strides(x::HybridArray{S,T,N,N,Array{T,N}}) where {S,T,N}
23-
collected_strides = []
24-
i = 1
25-
for (argnum, Sarg) in enumerate(S.parameters)
26-
if i > 0
27-
push!(collected_strides, ArrayInterface.StaticInt(i))
28-
else
29-
push!(collected_strides, :(datastrides[$argnum]))
30-
end
31-
if Sarg isa Integer
32-
i *= Sarg
33-
else
34-
i = -1
35-
end
36-
end
37-
return quote
38-
datastrides = strides(parent(x))
39-
return tuple($(collected_strides...))
40-
end
41-
end
42-
43-
@generated function ArrayInterface.size(x::HybridArray{S}) where {S}
44-
collected_sizes = []
45-
for (argnum, Sarg) in enumerate(S.parameters)
46-
if Sarg isa Integer
47-
push!(collected_sizes, ArrayInterface.StaticInt(Sarg))
48-
else
49-
push!(collected_sizes, :(datasize[$argnum]))
50-
end
51-
end
52-
return quote
53-
datasize = ArrayInterface.size(parent(x))
54-
return tuple($(collected_sizes...))
55-
end
56-
end
57-
58-
ArrayInterface.contiguous_axis(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = ArrayInterface.contiguous_axis(TData)
59-
ArrayInterface.contiguous_batch_size(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = ArrayInterface.contiguous_batch_size(TData)
60-
ArrayInterface.stride_rank(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = ArrayInterface.stride_rank(TData)
61-
62-
ArrayInterface.dense_dims(x::HybridArray) = ArrayInterface.dense_dims(x.data)

src/static_array_interface_compat.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
function StaticArrayInterface.strides(x::HybridArray)
2+
return StaticArrayInterface.strides(parent(x))
3+
end
4+
5+
@generated function StaticArrayInterface.strides(x::HybridArray{S,T,N,N,Array{T,N}}) where {S<:Tuple,T,N}
6+
collected_strides = []
7+
i = 1
8+
for (argnum, Sarg) in enumerate(S.parameters)
9+
if i > 0
10+
push!(collected_strides, StaticArrayInterface.StaticInt(i))
11+
else
12+
push!(collected_strides, :(datastrides[$argnum]))
13+
end
14+
if Sarg isa Integer
15+
i *= Sarg
16+
else
17+
i = -1
18+
end
19+
end
20+
return quote
21+
datastrides = strides(parent(x))
22+
return tuple($(collected_strides...))
23+
end
24+
end
25+
26+
@generated function StaticArrayInterface.size(x::HybridArray{S}) where {S}
27+
collected_sizes = []
28+
for (argnum, Sarg) in enumerate(S.parameters)
29+
if Sarg isa Integer
30+
push!(collected_sizes, StaticArrayInterface.StaticInt(Sarg))
31+
else
32+
push!(collected_sizes, :(datasize[$argnum]))
33+
end
34+
end
35+
return quote
36+
datasize = StaticArrayInterface.size(parent(x))
37+
return tuple($(collected_sizes...))
38+
end
39+
end
40+
41+
StaticArrayInterface.contiguous_axis(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = StaticArrayInterface.contiguous_axis(TData)
42+
StaticArrayInterface.contiguous_batch_size(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = StaticArrayInterface.contiguous_batch_size(TData)
43+
StaticArrayInterface.stride_rank(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = StaticArrayInterface.stride_rank(TData)
44+
45+
StaticArrayInterface.dense_dims(x::HybridArray) = StaticArrayInterface.dense_dims(x.data)

test/array_interface_compat.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11

22
using HybridArrays, ArrayInterface, Test, StaticArrays
3-
using ArrayInterface: StaticInt
4-
using Static
53

64
@testset "ArrayInterface compatibility" begin
75
M = HybridMatrix{2, StaticArrays.Dynamic()}([1 2; 4 5])
@@ -11,17 +9,4 @@ using Static
119
@test ArrayInterface.parent_type(M) === Matrix{Int}
1210
@test ArrayInterface.restructure(M, [2, 4, 6, 8]) == HybridMatrix{2, StaticArrays.Dynamic()}([2 6; 4 8])
1311
@test isa(ArrayInterface.restructure(M, [2, 4, 6, 8]), HybridMatrix{2, StaticArrays.Dynamic()})
14-
15-
M2 = HybridArray{Tuple{2, 3, StaticArrays.Dynamic(), StaticArrays.Dynamic()}}(randn(2, 3, 5, 7))
16-
@test (@inferred ArrayInterface.strides(M2)) === (StaticInt(1), StaticInt(2), StaticInt(6), 30)
17-
@test (@inferred ArrayInterface.strides(MV)) === (2, 30)
18-
@test (@inferred ArrayInterface.size(M2)) === (StaticInt(2), StaticInt(3), 5, 7)
19-
20-
@test ArrayInterface.contiguous_axis(typeof(M2)) === ArrayInterface.contiguous_axis(typeof(parent(M2)))
21-
@test ArrayInterface.contiguous_batch_size(typeof(M2)) === ArrayInterface.contiguous_batch_size(typeof(parent(M2)))
22-
@test ArrayInterface.stride_rank(typeof(M2)) === ArrayInterface.stride_rank(typeof(parent(M2)))
23-
@test ArrayInterface.contiguous_axis(typeof(M')) === ArrayInterface.contiguous_axis(typeof(parent(M)'))
24-
@test ArrayInterface.contiguous_batch_size(typeof(M')) === ArrayInterface.contiguous_batch_size(typeof(parent(M)'))
25-
@test ArrayInterface.stride_rank(typeof(M')) === ArrayInterface.stride_rank(typeof(parent(M)'))
26-
@test ArrayInterface.dense_dims(M) === (static(true), static(true))
2712
end

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ include("broadcast.jl")
162162
include("linalg.jl")
163163
include("ssubarray.jl")
164164
include("nonstandard_indices.jl")
165-
if VERSION >= v"1.2"
166-
include("array_interface_compat.jl")
167-
end
165+
166+
include("array_interface_compat.jl")
167+
include("static_array_interface_compat.jl")
168168
include("forwarddiff.jl")

test/static_array_interface_compat.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
using HybridArrays, StaticArrayInterface, Test, StaticArrays
3+
using StaticArrayInterface: StaticInt
4+
5+
@testset "StaticArrayInterface compatibility" begin
6+
M = HybridMatrix{2, StaticArrays.Dynamic()}([1 2; 4 5])
7+
MV = HybridMatrix{3, StaticArrays.Dynamic()}(view(randn(10, 10), 1:2:5, 1:3:12))
8+
9+
M2 = HybridArray{Tuple{2, 3, StaticArrays.Dynamic(), StaticArrays.Dynamic()}}(randn(2, 3, 5, 7))
10+
@test (@inferred StaticArrayInterface.strides(M2)) === (StaticInt(1), StaticInt(2), StaticInt(6), 30)
11+
@test (@inferred StaticArrayInterface.strides(MV)) === (2, 30)
12+
13+
@test StaticArrayInterface.contiguous_axis(typeof(M2)) === StaticArrayInterface.contiguous_axis(typeof(parent(M2)))
14+
@test StaticArrayInterface.contiguous_batch_size(typeof(M2)) === StaticArrayInterface.contiguous_batch_size(typeof(parent(M2)))
15+
@test StaticArrayInterface.stride_rank(typeof(M2)) === StaticArrayInterface.stride_rank(typeof(parent(M2)))
16+
@test StaticArrayInterface.contiguous_axis(typeof(M')) === StaticArrayInterface.contiguous_axis(typeof(parent(M)'))
17+
@test StaticArrayInterface.contiguous_batch_size(typeof(M')) === StaticArrayInterface.contiguous_batch_size(typeof(parent(M)'))
18+
@test StaticArrayInterface.stride_rank(typeof(M')) === StaticArrayInterface.stride_rank(typeof(parent(M)'))
19+
end

0 commit comments

Comments
 (0)