|
56 | 56 | Base.similar(vec::CatVector) = CatVector(map(similar, vec.vecs))
|
57 | 57 | Base.similar(vec::CatVector, ::Type{T}) where {T} = CatVector(map(x -> similar(x, T), vec.vecs))
|
58 | 58 |
|
59 |
| -function check_cat_vectors_line_up(x::CatVector, ys::CatVector...) |
60 |
| - for j in eachindex(ys) |
61 |
| - y = ys[j] |
62 |
| - length(x.vecs) == length(y.vecs) || throw(ArgumentError("Subvectors must line up")) |
63 |
| - for i in eachindex(x.vecs) |
64 |
| - length(x.vecs[i]) == length(y.vecs[i]) || throw(ArgumentError("Subvectors must line up")) |
65 |
| - end |
| 59 | +@inline function check_cat_vectors_line_up(x::CatVector, y::CatVector) |
| 60 | + length(x.vecs) == length(y.vecs) || throw(ArgumentError("Subvectors must line up")) |
| 61 | + for i in eachindex(x.vecs) |
| 62 | + length(x.vecs[i]) == length(y.vecs[i]) || throw(ArgumentError("Subvectors must line up")) |
66 | 63 | end
|
| 64 | + nothing |
67 | 65 | end
|
68 | 66 |
|
| 67 | +@inline check_cat_vectors_line_up(x::CatVector, y) = nothing |
| 68 | +@inline check_cat_vectors_line_up(x::CatVector, y, tail...) = (check_cat_vectors_line_up(x, y); check_cat_vectors_line_up(x, tail...)) |
| 69 | + |
69 | 70 | @inline function Base.copyto!(dest::CatVector, src::CatVector)
|
70 | 71 | @boundscheck check_cat_vectors_line_up(dest, src)
|
71 | 72 | @inbounds for i in eachindex(dest.vecs)
|
|
82 | 83 | return dest
|
83 | 84 | end
|
84 | 85 |
|
| 86 | +Base.@propagate_inbounds catvec_broadcast_getindex(vec::CatVector, i::Int, j::Int, k::Int) = vec.vecs[i][j] |
| 87 | +Base.@propagate_inbounds catvec_broadcast_getindex(x, i::Int, j::Int, k::Int) = Broadcast._broadcast_getindex(x, i) |
| 88 | + |
| 89 | +@inline function Base.copyto!(dest::CatVector, bc::Broadcast.Broadcasted{Nothing}) |
| 90 | + flat = Broadcast.flatten(bc) |
| 91 | + index = 1 |
| 92 | + dest_vecs = dest.vecs |
| 93 | + @boundscheck check_cat_vectors_line_up(dest, bc.args...) |
| 94 | + @inbounds for i in eachindex(dest_vecs) |
| 95 | + vec = dest_vecs[i] |
| 96 | + for j in eachindex(vec) |
| 97 | + k = axes(flat)[1][index] |
| 98 | + let f = flat.f, args = flat.args, i = i, j = j, k = k |
| 99 | + vec[j] = Broadcast._broadcast_getindex_evalf(f, map(arg -> catvec_broadcast_getindex(arg, i, j, k), args)...) |
| 100 | + end |
| 101 | + index += 1 |
| 102 | + end |
| 103 | + end |
| 104 | + return dest |
| 105 | +end |
0 commit comments