Skip to content

Commit 5b025d4

Browse files
committed
Fix localmem kernel
1 parent 7d3053a commit 5b025d4

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/nhs_grid.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,14 +414,16 @@ end
414414
ndrange = max_particles_per_cell * length(nonempty_cells)
415415

416416
n_gpus = length(CUDA.devices())
417-
ndrange_local = [div(ndrange, n_gpus) for _ in 1:n_gpus]
418-
ndrange_local[end] += ndrange % n_gpus
417+
cells_split = Iterators.partition(nonempty_cells, ceil(Int, length(nonempty_cells) / n_gpus))
418+
@assert length(cells_split) == n_gpus
419419

420420
kernel = foreach_neighbor_localmem(backend, (max_particles_per_cell,))
421-
@sync for i in 1:n_gpus
421+
@sync for (i, nonempty_cells_) in enumerate(cells_split)
422422
Threads.@spawn begin
423423
CUDA.device!(i - 1)
424-
kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells, Val(max_particles_per_cell), search_radius; ndrange = ndrange_local[i])
424+
kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells_,
425+
Val(max_particles_per_cell), search_radius;
426+
ndrange = length(nonempty_cells_) * max_particles_per_cell)
425427
KernelAbstractions.synchronize(backend)
426428
end
427429
end

0 commit comments

Comments
 (0)