Skip to content

Commit a7d3567

Browse files
committed
Broadcast improvements.
1 parent ac27a70 commit a7d3567

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

src/custom_collections/CatVector.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,22 @@ end
9595
return dest
9696
end
9797

98-
Base.@propagate_inbounds catvec_broadcast_vec(x::CatVector, k::Int) = x.vecs[k]
99-
Base.@propagate_inbounds catvec_broadcast_vec(x::Number, k::Int) = x
98+
Base.@propagate_inbounds catvec_broadcast_vec(arg::CatVector, range::UnitRange, k::Int) = arg.vecs[k]
99+
Base.@propagate_inbounds catvec_broadcast_vec(arg::AbstractVector, range::UnitRange, k::Int) = view(arg, range)
100+
Base.@propagate_inbounds catvec_broadcast_vec(arg::Number, range::UnitRange, k::Int) = arg
100101

101102
@inline function Base.copyto!(dest::CatVector, bc::Broadcast.Broadcasted{Nothing})
102103
flat = Broadcast.flatten(bc)
103104
@boundscheck check_cat_vectors_line_up(dest, flat.args...)
105+
offset = 1
104106
@inbounds for i in eachindex(dest.vecs)
105107
let i = i, f = flat.f, args = flat.args
106-
dest′ = catvec_broadcast_vec(dest, i)
107-
args′ = map(arg -> catvec_broadcast_vec(arg, i), args)
108+
dest′ = dest.vecs[i]
109+
range = offset : offset + length(dest′) - 1
110+
args′ = map(arg -> catvec_broadcast_vec(arg, range, i), args)
108111
axes′ = (eachindex(dest′),)
109112
copyto!(dest′, Broadcast.Broadcasted{Nothing}(f, args′, axes′))
113+
offset = last(range) + 1
110114
end
111115
end
112116
return dest

test/test_custom_collections.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,13 @@ Base.axes(m::NonOneBasedMatrix) = ((1:m.m) .- 2, (1:m.n) .+ 1)
175175
for i in eachindex(x)
176176
@test x[i] == y[i]
177177
end
178-
179-
x .= 0
180-
for i in eachindex(y)
181-
x[i] = y[i]
182-
end
183178
@test x == y
184179

180+
y .= 0
181+
rand!(x)
182+
y .= x .+ y .+ 1
183+
@test x .+ 1 == y
184+
185185
allocs = let x=x, vecs=vecs
186186
@allocated copyto!(x, RigidBodyDynamics.CatVector(vecs))
187187
end

0 commit comments

Comments
 (0)