Skip to content

Commit 40663ca

Browse files
authored
Improve ArrayInterface compatibility (#28)
* improve ArrayInterface compatibility * a bit safer implementation * one more restriction * restrict a test to newer versions of Julia * strides improvements
1 parent 130fd30 commit 40663ca

7 files changed

+90
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HybridArrays"
22
uuid = "1baab800-613f-4b0a-84e4-9cd3431bfbb9"
33
authors = ["Mateusz Baran <mateuszbaran89@gmail.com>"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/HybridArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import Base: convert,
1010
getindex,
1111
dataids,
1212
promote_rule,
13+
pointer,
14+
strides,
1315
setindex!,
1416
size,
1517
length,

src/abstractarray.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11

22
Base.dataids(sa::HybridArray) = Base.dataids(sa.data)
33

4-
@inline size(sa::HybridArray{S}) where S = size(sa.data)
4+
@inline size(sa::HybridArray{S,T,N,N}) where {S,T,N} = size(sa.data)
55

66
@inline length(sa::HybridArray) = length(sa.data)
77

8+
@inline strides(sa::HybridArray{S,T,N,N}) where {S,T,N} = strides(sa.data)
9+
10+
@inline pointer(sa::HybridArray) = pointer(sa.data)
11+
812
@generated function _sized_abstract_array_axes(::Type{S}, ax::Tuple) where S<:Tuple
913
exprs = Any[]
1014
map(enumerate(S.parameters)) do (i, si)

src/array_interface_compat.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,47 @@ end
1414
function ArrayInterface.restructure(x::HybridArray{S}, y) where {S}
1515
return HybridArray{S}(reshape(convert(Array, y), size(x)...))
1616
end
17+
18+
function ArrayInterface.strides(x::HybridArray)
19+
return ArrayInterface.strides(x.data)
20+
end
21+
22+
@generated function ArrayInterface.strides(x::HybridArray{S,T,N,N,Array{T,N}}) where {S,T,N}
23+
collected_strides = []
24+
i = 1
25+
for (argnum, Sarg) in enumerate(S.parameters)
26+
if i > 0
27+
push!(collected_strides, ArrayInterface.StaticInt(i))
28+
else
29+
push!(collected_strides, :(datastrides[$argnum]))
30+
end
31+
if Sarg isa Integer
32+
i *= Sarg
33+
else
34+
i = -1
35+
end
36+
end
37+
return quote
38+
datastrides = strides(x.data)
39+
return tuple($(collected_strides...))
40+
end
41+
end
42+
43+
@generated function ArrayInterface.size(x::HybridArray{S}) where {S}
44+
collected_sizes = []
45+
for (argnum, Sarg) in enumerate(S.parameters)
46+
if Sarg isa Integer
47+
push!(collected_sizes, ArrayInterface.StaticInt(Sarg))
48+
else
49+
push!(collected_sizes, :(datasize[$argnum]))
50+
end
51+
end
52+
return quote
53+
datasize = ArrayInterface.size(x.data)
54+
return tuple($(collected_sizes...))
55+
end
56+
end
57+
58+
ArrayInterface.contiguous_axis(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = ArrayInterface.contiguous_axis(TData)
59+
ArrayInterface.contiguous_batch_size(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = ArrayInterface.contiguous_batch_size(TData)
60+
ArrayInterface.stride_rank(::Type{HybridArray{S,T,N,N,TData}}) where {S,T,N,TData} = ArrayInterface.stride_rank(TData)

src/convert.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,5 @@ end
6565
@inline function convert(::Type{HybridArray{S}}, a::TData) where {S,T,M,TData<:AbstractArray{T,M}}
6666
convert(HybridArray{S,T}, a)
6767
end
68+
69+
@inline Base.unsafe_convert(::Type{Ptr{T}}, A::HybridArray{S,T}) where {S,T} = pointer(A.data)

test/abstractarray.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,27 @@ using StaticArrays, HybridArrays, Test, LinearAlgebra
7878
@test_throws TypeError HybridArrays.new_out_size_nongen(Size{Tuple{1,2}}, 'a')
7979
end
8080

81+
@testset "strides" begin
82+
M = HybridMatrix{2, StaticArrays.Dynamic(), Int}([1 2; 3 4])
83+
84+
@test strides(M) == strides(M.data)
85+
end
86+
87+
@testset "pointer" begin
88+
M = HybridMatrix{2, StaticArrays.Dynamic(), Int}([1 2; 3 4])
89+
MT = HybridMatrix{2, StaticArrays.Dynamic(), Int}([1 2; 3 4]')
90+
91+
@test pointer(M) == pointer(M.data)
92+
if VERSION >= v"1.5"
93+
# pointer on Adjoint is not available on earilier versions of Julia
94+
@test pointer(MT) == pointer(MT.data)
95+
end
96+
end
97+
98+
@testset "unsafe_convert" begin
99+
M = HybridMatrix{2, StaticArrays.Dynamic(), Int}([1 2; 3 4])
100+
@test Base.unsafe_convert(Ptr{Int}, M) === pointer(M.data)
101+
end
102+
81103
@test HybridArrays._totally_linear() === true
82104
end

test/array_interface_compat.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
11

22
using HybridArrays, ArrayInterface, Test, StaticArrays
3+
using ArrayInterface: StaticInt
34

45
@testset "ArrayInterface compatibility" begin
56
M = HybridMatrix{2, StaticArrays.Dynamic()}([1 2; 4 5])
7+
MV = HybridMatrix{3, StaticArrays.Dynamic()}(view(randn(10, 10), 1:2:5, 1:3:12))
68
@test ArrayInterface.ismutable(M)
79
@test ArrayInterface.can_setindex(M)
810
@test ArrayInterface.parent_type(M) === Matrix{Int}
911
@test ArrayInterface.restructure(M, [2, 4, 6, 8]) == HybridMatrix{2, StaticArrays.Dynamic()}([2 6; 4 8])
1012
@test isa(ArrayInterface.restructure(M, [2, 4, 6, 8]), HybridMatrix{2, StaticArrays.Dynamic()})
13+
14+
M2 = HybridArray{Tuple{2, 3, StaticArrays.Dynamic(), StaticArrays.Dynamic()}}(randn(2, 3, 5, 7))
15+
@test (@inferred ArrayInterface.strides(M2)) === (StaticInt(1), StaticInt(2), StaticInt(6), 30)
16+
@test (@inferred ArrayInterface.strides(MV)) === (2, 30)
17+
@test (@inferred ArrayInterface.size(M2)) === (StaticInt(2), StaticInt(3), 5, 7)
18+
19+
@test ArrayInterface.contiguous_axis(typeof(M2)) === ArrayInterface.contiguous_axis(typeof(M2.data))
20+
@test ArrayInterface.contiguous_batch_size(typeof(M2)) === ArrayInterface.contiguous_batch_size(typeof(M2.data))
21+
@test ArrayInterface.stride_rank(typeof(M2)) === ArrayInterface.stride_rank(typeof(M2.data))
22+
@test ArrayInterface.contiguous_axis(typeof(M')) === ArrayInterface.contiguous_axis(typeof(M.data'))
23+
@test ArrayInterface.contiguous_batch_size(typeof(M')) === ArrayInterface.contiguous_batch_size(typeof(M.data'))
24+
@test ArrayInterface.stride_rank(typeof(M')) === ArrayInterface.stride_rank(typeof(M.data'))
1125
end

0 commit comments

Comments
 (0)