Skip to content

Commit 7d3053a

Browse files
committed
Fix ndrange splitting
1 parent 98d83c1 commit 7d3053a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/util.jl

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -145,14 +145,14 @@ end
145145
# Call the generic kernel that is defined below, which only calls a function with
146146
# the global GPU index.
147147
n_gpus = length(CUDA.devices())
148-
ndrange_local = [div(ndrange, n_gpus) for _ in 1:n_gpus]
149-
ndrange_local[end] += ndrange % n_gpus
148+
indices_split = Iterators.partition(indices, ceil(Int, length(indices) / n_gpus))
149+
@assert length(indices_split) == n_gpus
150150

151-
@sync for i in 1:n_gpus
151+
@sync for (i, indices_) in enumerate(indices_split)
152152
Threads.@spawn begin
153153
CUDA.device!(i - 1)
154-
generic_kernel(backend)(ndrange = ndrange_local[i]) do j
155-
@inbounds @inline f(iterator[indices[j]])
154+
generic_kernel(backend)(ndrange = length(indices_)) do j
155+
@inbounds @inline f(iterator[indices_[j]])
156156
end
157157
KernelAbstractions.synchronize(backend)
158158
end

0 commit comments

Comments
 (0)