Skip to content

Commit 53886a6

Browse files
committed
Add broadcast for CatVector. Works, but it's pretty slow.
1 parent b43949b commit 53886a6

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

src/custom_collections/CatVector.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,17 @@ end
5656
Base.similar(vec::CatVector) = CatVector(map(similar, vec.vecs))
5757
Base.similar(vec::CatVector, ::Type{T}) where {T} = CatVector(map(x -> similar(x, T), vec.vecs))
5858

59-
function check_cat_vectors_line_up(x::CatVector, ys::CatVector...)
60-
for j in eachindex(ys)
61-
y = ys[j]
62-
length(x.vecs) == length(y.vecs) || throw(ArgumentError("Subvectors must line up"))
63-
for i in eachindex(x.vecs)
64-
length(x.vecs[i]) == length(y.vecs[i]) || throw(ArgumentError("Subvectors must line up"))
65-
end
59+
@inline function check_cat_vectors_line_up(x::CatVector, y::CatVector)
60+
length(x.vecs) == length(y.vecs) || throw(ArgumentError("Subvectors must line up"))
61+
for i in eachindex(x.vecs)
62+
length(x.vecs[i]) == length(y.vecs[i]) || throw(ArgumentError("Subvectors must line up"))
6663
end
64+
nothing
6765
end
6866

67+
@inline check_cat_vectors_line_up(x::CatVector, y) = nothing
68+
@inline check_cat_vectors_line_up(x::CatVector, y, tail...) = (check_cat_vectors_line_up(x, y); check_cat_vectors_line_up(x, tail...))
69+
6970
@inline function Base.copyto!(dest::CatVector, src::CatVector)
7071
@boundscheck check_cat_vectors_line_up(dest, src)
7172
@inbounds for i in eachindex(dest.vecs)
@@ -82,3 +83,23 @@ end
8283
return dest
8384
end
8485

86+
Base.@propagate_inbounds catvec_broadcast_getindex(vec::CatVector, i::Int, j::Int, k::Int) = vec.vecs[i][j]
87+
Base.@propagate_inbounds catvec_broadcast_getindex(x, i::Int, j::Int, k::Int) = Broadcast._broadcast_getindex(x, i)
88+
89+
@inline function Base.copyto!(dest::CatVector, bc::Broadcast.Broadcasted{Nothing})
90+
flat = Broadcast.flatten(bc)
91+
index = 1
92+
dest_vecs = dest.vecs
93+
@boundscheck check_cat_vectors_line_up(dest, bc.args...)
94+
@inbounds for i in eachindex(dest_vecs)
95+
vec = dest_vecs[i]
96+
for j in eachindex(vec)
97+
k = axes(flat)[1][index]
98+
let f = flat.f, args = flat.args, i = i, j = j, k = k
99+
vec[j] = Broadcast._broadcast_getindex_evalf(f, map(arg -> catvec_broadcast_getindex(arg, i, j, k), args)...)
100+
end
101+
index += 1
102+
end
103+
end
104+
return dest
105+
end

0 commit comments

Comments
 (0)