Skip to content

Commit 02b22f0

Browse files
committed
Improve readability
1 parent 843116c commit 02b22f0

File tree

1 file changed

+132
-29
lines changed

1 file changed

+132
-29
lines changed

src/nhs_grid.jl

Lines changed: 132 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -420,79 +420,184 @@ end
420420
return nothing
421421
end
422422

423+
@inline function copy_to_localmem!(local_points, local_neighbor_coords,
424+
neighbor_cell, neighbor_system_coords,
425+
neighborhood_search, particleidx)
426+
points_view = points_in_cell(neighbor_cell, neighborhood_search)
427+
n_particles_in_neighbor_cell = length(points_view)
428+
429+
# First use all threads to load the neighbors into local memory in parallel
430+
if particleidx <= n_particles_in_neighbor_cell
431+
@inbounds p = local_points[particleidx] = points_view[particleidx]
432+
for d in 1:ndims(neighborhood_search)
433+
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
434+
end
435+
end
436+
return n_particles_in_neighbor_cell
437+
end
438+
439+
# @parallel(block) for cell in cells
440+
# for neighbor_cell in neighboring_cells
441+
# @parallel(thread) for neighbor in neighbor_cell
442+
# copy_coordinates_to_localmem(neighbor)
443+
#
444+
# # Make sure all threads finished the copying
445+
# @synchronize
446+
#
447+
# @parallel(thread) for particle in cell
448+
# for neighbor in neighbor_cell
449+
# # This uses the neighbor coordinates from the local memory
450+
# compute(point, neighbor)
451+
#
452+
# # Make sure all threads finished computing before we continue with copying
453+
# @synchronize
423454
@kernel cpu=false function foreach_neighbor_localmem(f::F, system_coords, neighbor_system_coords,
424455
neighborhood_search, cells, ::Val{MAX}, search_radius) where {F, MAX}
425456
cell_ = @index(Group)
426457
cell = @inbounds Tuple(cells[cell_])
427458
particleidx = @index(Local)
428459
@assert 1 <= particleidx <= MAX
429460

461+
# Coordinate buffer in local memory
430462
local_points = @localmem Int32 MAX
431463
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
432464

433-
next_local_points = @localmem Int32 MAX
434-
next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
465+
points = points_in_cell(cell, neighborhood_search)
466+
n_particles_in_current_cell = length(points)
435467

436-
pv = points_in_cell(cell, neighborhood_search)
437-
n_particles_in_current_cell = length(pv)
468+
# Extract point coordinates if a point lies on this thread
438469
if particleidx <= n_particles_in_current_cell
439-
point = @inbounds pv[particleidx]
470+
point = @inbounds points[particleidx]
440471
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
441472
point)
442473
else
443474
point = zero(Int32)
444475
point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
445476
end
446477

447-
@inline function stage!(local_points, local_neighbor_coords, neighbor_cell)
448-
points_view = points_in_cell(neighbor_cell, neighborhood_search)
449-
n_particles_in_neighbor_cell_ = length(points_view)
478+
for neighbor_cell_ in neighboring_cells(cell, neighborhood_search)
479+
neighbor_cell = Tuple(neighbor_cell_)
480+
481+
n_particles_in_neighbor_cell = copy_to_localmem!(local_points, local_neighbor_coords,
482+
neighbor_cell, neighbor_system_coords,
483+
neighborhood_search, particleidx)
484+
485+
# Make sure all threads finished the copying
486+
@synchronize
450487

451-
# First use all threads to load the neighbors into local memory in parallel
452-
if particleidx <= n_particles_in_neighbor_cell_
453-
@inbounds p = local_points[particleidx] = points_view[particleidx]
454-
for d in 1:ndims(neighborhood_search)
455-
@inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
488+
# Now each thread works on one point again
489+
if particleidx <= n_particles_in_current_cell
490+
for local_neighbor in 1:n_particles_in_neighbor_cell
491+
@inbounds neighbor = local_points[local_neighbor]
492+
@inbounds neighbor_coords = extract_svector(local_neighbor_coords,
493+
Val(ndims(neighborhood_search)),
494+
local_neighbor)
495+
496+
pos_diff = point_coords - neighbor_coords
497+
distance2 = dot(pos_diff, pos_diff)
498+
499+
# TODO periodic
500+
501+
if distance2 <= search_radius^2
502+
distance = sqrt(distance2) # TODO: eventuell fastmath
503+
504+
# Inline to avoid loss of performance
505+
# compared to not using `foreach_point_neighbor`.
506+
@inline f(point, neighbor, pos_diff, distance)
507+
end
456508
end
457509
end
458-
return n_particles_in_neighbor_cell_
510+
511+
# Make sure all threads finished computing before we continue with copying
512+
@synchronize()
513+
end
514+
end
515+
516+
# @parallel(block) for cell in cells
517+
# @parallel(thread) for neighbor in first_neighbor_cell
518+
# copy_coordinates_to_localmem!(local_coords, neighbor)
519+
#
520+
# for neighbor_cell in neighboring_cells
521+
# @parallel(thread) for neighbor in neighbor_cell + 1
522+
# copy_coordinates_to_localmem!(next_local_coords, neighbor)
523+
#
524+
# # No synchronize needed. The following loop works on `local_coords`.
525+
#
526+
# @parallel(thread) for particle in cell
527+
# for neighbor in neighbor_cell
528+
# # This uses the neighbor coordinates from the local memory
529+
# compute(point, neighbor)
530+
#
531+
# # Make sure all threads finished computing before we switch variables
532+
# @synchronize
533+
# local_coords, next_local_coords = next_local_coords, local_coords
534+
@kernel cpu=false function foreach_neighbor_double_buffer(f::F, system_coords, neighbor_system_coords,
535+
neighborhood_search, cells, ::Val{MAX}, search_radius) where {F, MAX}
536+
cell_ = @index(Group)
537+
cell = @inbounds Tuple(cells[cell_])
538+
particleidx = @index(Local)
539+
@assert 1 <= particleidx <= MAX
540+
541+
# Coordinate buffer in local memory
542+
local_points = @localmem Int32 MAX
543+
local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
544+
545+
# Next coordinate buffer in local memory
546+
next_local_points = @localmem Int32 MAX
547+
next_local_neighbor_coords = @localmem eltype(system_coords) (ndims(neighborhood_search), MAX)
548+
549+
points = points_in_cell(cell, neighborhood_search)
550+
n_particles_in_current_cell = length(points)
551+
552+
# Extract point coordinates if a point lies on this thread
553+
if particleidx <= n_particles_in_current_cell
554+
point = @inbounds points[particleidx]
555+
point_coords = @inbounds extract_svector(system_coords, Val(ndims(neighborhood_search)),
556+
point)
557+
else
558+
point = zero(Int32)
559+
point_coords = zero(SVector{ndims(neighborhood_search), eltype(system_coords)})
459560
end
460561

461562
neighborhood = neighboring_cells(cell, neighborhood_search)
462563
# (neighbor_cell, state) = iterate(neighborhood)
463-
neighbor_cell = first(neighborhood)
564+
neighbor_cell = Tuple(first(neighborhood))
565+
566+
n_particles_in_neighbor_cell = copy_to_localmem!(local_points, local_neighbor_coords,
567+
neighbor_cell, neighbor_system_coords,
568+
neighborhood_search, particleidx)
464569

465-
n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
466570
@synchronize()
467571

468572
for neighbor_ in 1:length(neighborhood)
469-
neighbor_cell = @inbounds neighborhood[neighbor_]
573+
neighbor_cell = @inbounds Tuple(neighborhood[neighbor_])
470574

471575
# while true
472576
# next = iterate(neighborhood, state)
473577
# if next !== nothing
474-
# n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
475-
# @synchronize
578+
476579
if neighbor_ < length(neighborhood)
477-
next_neighbor_cell = neighborhood[neighbor_ + 1]
580+
next_neighbor_cell = @inbounds Tuple(neighborhood[neighbor_ + 1])
478581
# (next_neighbor_cell, state) = next
479-
next_n_particles_in_neighbor_cell = stage!(next_local_points, next_local_neighbor_coords, Tuple(next_neighbor_cell))
582+
next_n_particles_in_neighbor_cell = copy_to_localmem!(next_local_points, next_local_neighbor_coords,
583+
next_neighbor_cell, neighbor_system_coords,
584+
neighborhood_search, particleidx)
480585
end
481586

482587
# Now each thread works on one point again
483588
if particleidx <= n_particles_in_current_cell
484589
for local_neighbor in 1:n_particles_in_neighbor_cell
485590
@inbounds neighbor = local_points[local_neighbor]
486591
@inbounds neighbor_coords = extract_svector(local_neighbor_coords,
487-
Val(ndims(neighborhood_search)), local_neighbor)
592+
Val(ndims(neighborhood_search)),
593+
local_neighbor)
488594

489595
pos_diff = point_coords - neighbor_coords
490596
distance2 = dot(pos_diff, pos_diff)
491597

492598
# TODO periodic
493599

494600
if distance2 <= search_radius^2
495-
# KernelAbstractions.@print("Point $point, neighbor $neighbor with distance2 $distance2\n")
496601
distance = sqrt(distance2) # TODO: eventuell fastmath
497602

498603
# Inline to avoid loss of performance
@@ -501,17 +606,15 @@ end
501606
end
502607
end
503608
end
609+
504610
# next === nothing && break
505611
neighbor_ >= length(neighborhood) && break
506612
@synchronize()
613+
507614
# swap variables
508615
n_particles_in_neighbor_cell = next_n_particles_in_neighbor_cell
509-
temp = local_points
510-
local_points = next_local_points
511-
next_local_points = temp
512-
temp = local_neighbor_coords
513-
local_neighbor_coords = next_local_neighbor_coords
514-
next_local_neighbor_coords = temp
616+
local_points, next_local_points = next_local_points, local_points
617+
local_neighbor_coords, next_local_neighbor_coords = next_local_neighbor_coords, local_neighbor_coords
515618
end
516619
end
517620

0 commit comments

Comments
 (0)