Skip to content

Commit 5ad06ac

Browse files
committed
Faster broadcast.
1 parent 53886a6 commit 5ad06ac

File tree

2 files changed

+40
-29
lines changed

2 files changed

+40
-29
lines changed

src/custom_collections/CatVector.jl

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,36 @@ end
66
Base.eltype(vec::CatVector) = eltype(eltype(vec.vecs))
77

88
# Note: getindex and setindex are pretty naive.
9-
Base.@propagate_inbounds function Base.getindex(vec::CatVector, i::Int)
10-
@boundscheck checkbounds(vec, i)
11-
I = 1
9+
Base.@propagate_inbounds function Base.getindex(vec::CatVector, index::Int)
10+
@boundscheck checkbounds(vec, index)
11+
i = 1
12+
j = index
1213
@inbounds while true
13-
subvec = vec.vecs[I]
14+
subvec = vec.vecs[i]
1415
l = length(subvec)
15-
if i <= l
16-
return subvec[eachindex(subvec)[i]]
16+
if j <= l
17+
return subvec[eachindex(subvec)[j]]
1718
else
18-
i -= l
19-
I += 1
19+
j -= l
20+
i += 1
2021
end
2122
end
2223
error()
2324
end
2425

25-
Base.@propagate_inbounds function Base.setindex!(vec::CatVector, val, i::Int)
26-
@boundscheck checkbounds(vec, i)
27-
I = 1
26+
Base.@propagate_inbounds function Base.setindex!(vec::CatVector, val, index::Int)
27+
@boundscheck checkbounds(vec, index)
28+
i = 1
29+
j = index
2830
while true
29-
subvec = vec.vecs[I]
31+
subvec = vec.vecs[i]
3032
l = length(subvec)
31-
if i <= l
32-
subvec[eachindex(subvec)[i]] = val
33+
if j <= l
34+
subvec[eachindex(subvec)[j]] = val
3335
return val
3436
else
35-
i -= l
36-
I += 1
37+
j -= l
38+
i += 1
3739
end
3840
end
3941
error()
@@ -83,22 +85,18 @@ end
8385
return dest
8486
end
8587

86-
Base.@propagate_inbounds catvec_broadcast_getindex(vec::CatVector, i::Int, j::Int, k::Int) = vec.vecs[i][j]
87-
Base.@propagate_inbounds catvec_broadcast_getindex(x, i::Int, j::Int, k::Int) = Broadcast._broadcast_getindex(x, i)
88+
Base.@propagate_inbounds catvec_broadcast_vec(x::CatVector, k::Int) = x.vecs[k]
89+
Base.@propagate_inbounds catvec_broadcast_vec(x::Number, k::Int) = x
8890

8991
@inline function Base.copyto!(dest::CatVector, bc::Broadcast.Broadcasted{Nothing})
9092
flat = Broadcast.flatten(bc)
91-
index = 1
92-
dest_vecs = dest.vecs
93-
@boundscheck check_cat_vectors_line_up(dest, bc.args...)
94-
@inbounds for i in eachindex(dest_vecs)
95-
vec = dest_vecs[i]
96-
for j in eachindex(vec)
97-
k = axes(flat)[1][index]
98-
let f = flat.f, args = flat.args, i = i, j = j, k = k
99-
vec[j] = Broadcast._broadcast_getindex_evalf(f, map(arg -> catvec_broadcast_getindex(arg, i, j, k), args)...)
100-
end
101-
index += 1
93+
@boundscheck check_cat_vectors_line_up(dest, flat.args...)
94+
@inbounds for i in eachindex(dest.vecs)
95+
let i = i, f = flat.f, args = flat.args
96+
dest′ = catvec_broadcast_vec(dest, i)
97+
args′ = map(arg -> catvec_broadcast_vec(arg, i), args)
98+
axes′ = (eachindex(dest′),)
99+
copyto!(dest′, Broadcast.Broadcasted{Nothing}(f, args′, axes′))
102100
end
103101
end
104102
return dest

test/test_custom_collections.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,5 +208,18 @@ Base.axes(m::NonOneBasedMatrix) = ((1:m.m) .- 2, (1:m.n) .+ 1)
208208
y5 = similar(y)
209209
map!(+, y5, y, y)
210210
@test Vector(y5) == Vector(y) + Vector(y)
211+
212+
z = similar(y)
213+
rand!(z)
214+
yvec = Vector(y)
215+
zvec = Vector(z)
216+
217+
z .= muladd.(1e-3, y, z)
218+
zvec .= muladd.(1e-3, yvec, zvec)
219+
@test zvec == z
220+
allocs = let y=y, z=z
221+
@allocated z .= muladd.(1e-3, y, z)
222+
end
223+
@test allocs == 0
211224
end
212225
end

0 commit comments

Comments
 (0)