Skip to content

Commit 98d83c1

Browse files
committed
Add prototype for multi-gpu
1 parent 80cf700 commit 98d83c1

File tree

5 files changed

+38
-6
lines changed

5 files changed

+38
-6
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.4.6-dev"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
9+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
910
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1011
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -16,6 +17,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1617
[compat]
1718
Adapt = "3, 4"
1819
Atomix = "0.1, 1"
20+
CUDA = "5.4.3"
1921
GPUArraysCore = "0.1, 0.2"
2022
KernelAbstractions = "0.9"
2123
LinearAlgebra = "1"

src/PointNeighbors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using KernelAbstractions: KernelAbstractions, @kernel, @index, @localmem, @synch
1010
using LinearAlgebra: dot
1111
using Polyester: Polyester
1212
@reexport using StaticArrays: SVector
13+
using CUDA
1314

1415
include("util.jl")
1516
include("vector_of_vectors.jl")

src/gpu.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
Adapt.@adapt_structure FullGridCellList
1111
Adapt.@adapt_structure DynamicVectorOfVectors
1212

13+
# TODO: quick-and-dirty prototype hack that makes every adapted `CuArray` use unified memory.
14+
# Convert a plain `Array` into a `CuArray` backed by `CUDA.UnifiedMemory` when
# adapting to `CuArray`, keeping the element type and dimensionality of `array`.
#
# Dispatch on `Type{CuArray}` instead of `typeof(CuArray)`: `typeof(CuArray)` is
# `UnionAll`, so that signature would also be selected for every other `UnionAll`
# adaptor (e.g. `adapt(Array, x)`) and turn those arrays into `CuArray`s as well.
#
# NOTE(review): this method is still type piracy (`adapt_structure`, `CuArray` and
# `Array` are all defined outside this package) — presumably acceptable only for the
# prototype; confirm before merging.
function Adapt.adapt_structure(to::Type{CuArray}, array::Array)
    return CuArray{eltype(array), ndims(array), CUDA.UnifiedMemory}(array)
end
17+
1318
# `adapt(CuArray, ::SVector)::SVector`, but `adapt(Array, ::SVector)::Vector`.
1419
# We don't want to change the type of the `SVector` here.
1520
function Adapt.adapt_structure(to::typeof(Array), svector::SVector)

src/nhs_grid.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,22 @@ end
412412
# max_particles_per_cell = maximum(lengths)
413413
nonempty_cells = Adapt.adapt(backend, filter(index -> lengths[linear_indices[index]] > 0, cartesian_indices))
414414
ndrange = max_particles_per_cell * length(nonempty_cells)
415+
416+
n_gpus = length(CUDA.devices())
417+
ndrange_local = [div(ndrange, n_gpus) for _ in 1:n_gpus]
418+
ndrange_local[end] += ndrange % n_gpus
419+
415420
kernel = foreach_neighbor_localmem(backend, (max_particles_per_cell,))
416-
kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells, Val(max_particles_per_cell), search_radius; ndrange)
421+
@sync for i in 1:n_gpus
422+
Threads.@spawn begin
423+
CUDA.device!(i - 1)
424+
kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells, Val(max_particles_per_cell), search_radius; ndrange = ndrange_local[i])
425+
KernelAbstractions.synchronize(backend)
426+
end
427+
end
428+
# kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells, Val(max_particles_per_cell), search_radius; ndrange)
417429

418-
KernelAbstractions.synchronize(backend)
430+
# KernelAbstractions.synchronize(backend)
419431

420432
return nothing
421433
end

src/util.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,23 @@ end
144144

145145
# Call the generic kernel that is defined below, which only calls a function with
146146
# the global GPU index.
147-
generic_kernel(backend)(ndrange = ndrange) do i
148-
@inbounds @inline f(iterator[indices[i]])
147+
n_gpus = length(CUDA.devices())
148+
ndrange_local = [div(ndrange, n_gpus) for _ in 1:n_gpus]
149+
ndrange_local[end] += ndrange % n_gpus
150+
151+
@sync for i in 1:n_gpus
152+
Threads.@spawn begin
153+
CUDA.device!(i - 1)
154+
generic_kernel(backend)(ndrange = ndrange_local[i]) do j
155+
@inbounds @inline f(iterator[indices[j]])
156+
end
157+
KernelAbstractions.synchronize(backend)
158+
end
149159
end
150-
151-
KernelAbstractions.synchronize(backend)
160+
# generic_kernel(backend)(ndrange = ndrange) do i
161+
# @inbounds @inline f(iterator[indices[i]])
162+
# end
163+
# KernelAbstractions.synchronize(backend)
152164
end
153165

154166
@kernel function generic_kernel(f)

0 commit comments

Comments
 (0)