Skip to content

Commit b43949b

Browse files
committed
Add copyto, map for CatVector.
1 parent 350896c commit b43949b

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

src/custom_collections/CatVector.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
struct CatVector{T, N, I<:Tuple{Vararg{AbstractVector{T}, N}}} <: AbstractVector{T}
2-
vecs::I
1+
struct CatVector{T, N, V<:AbstractVector{T}} <: AbstractVector{T}
2+
vecs::NTuple{N, V}
33
end
44

55
@inline Base.size(vec::CatVector) = (mapreduce(length, +, vec.vecs; init=0),)
@@ -39,7 +39,7 @@ Base.@propagate_inbounds function Base.setindex!(vec::CatVector, val, i::Int)
3939
error()
4040
end
4141

42-
Base.@propagate_inbounds function Base.copyto!(dest::AbstractVector, src::CatVector)
42+
Base.@propagate_inbounds function Base.copyto!(dest::AbstractVector{T}, src::CatVector{T}) where {T}
4343
@boundscheck length(dest) == length(src) || throw(DimensionMismatch())
4444
dest_indices = eachindex(dest)
4545
k = 1
@@ -55,3 +55,30 @@ end
5555

5656
Base.similar(vec::CatVector) = CatVector(map(similar, vec.vecs))
5757
Base.similar(vec::CatVector, ::Type{T}) where {T} = CatVector(map(x -> similar(x, T), vec.vecs))
58+
59+
function check_cat_vectors_line_up(x::CatVector, ys::CatVector...)
60+
for j in eachindex(ys)
61+
y = ys[j]
62+
length(x.vecs) == length(y.vecs) || throw(ArgumentError("Subvectors must line up"))
63+
for i in eachindex(x.vecs)
64+
length(x.vecs[i]) == length(y.vecs[i]) || throw(ArgumentError("Subvectors must line up"))
65+
end
66+
end
67+
end
68+
69+
@inline function Base.copyto!(dest::CatVector, src::CatVector)
70+
@boundscheck check_cat_vectors_line_up(dest, src)
71+
@inbounds for i in eachindex(dest.vecs)
72+
copyto!(dest.vecs[i], src.vecs[i])
73+
end
74+
return dest
75+
end
76+
77+
@inline function Base.map!(f::F, dest::CatVector, args::CatVector...) where F
78+
@boundscheck check_cat_vectors_line_up(dest, args...)
79+
@inbounds for i in eachindex(dest.vecs)
80+
map!(f, dest.vecs[i], map(arg -> arg.vecs[i], args)...)
81+
end
82+
return dest
83+
end
84+

test/test_custom_collections.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,5 +200,13 @@ Base.axes(m::NonOneBasedMatrix) = ((1:m.m) .- 2, (1:m.n) .+ 1)
200200
@test length(y3.vecs[i]) == length(y.vecs[i])
201201
@test y3.vecs[i] !== y.vecs[i]
202202
end
203+
204+
y4 = similar(y)
205+
copyto!(y4, y)
206+
@test y4 == y
207+
208+
y5 = similar(y)
209+
map!(+, y5, y, y)
210+
@test Vector(y5) == Vector(y) + Vector(y)
203211
end
204212
end

0 commit comments

Comments
 (0)