@@ -420,79 +420,184 @@ end
420420 return nothing
421421end
422422
423+ @inline function copy_to_localmem! (local_points, local_neighbor_coords,
424+ neighbor_cell, neighbor_system_coords,
425+ neighborhood_search, particleidx)
426+ points_view = points_in_cell (neighbor_cell, neighborhood_search)
427+ n_particles_in_neighbor_cell = length (points_view)
428+
429+ # First use all threads to load the neighbors into local memory in parallel
430+ if particleidx <= n_particles_in_neighbor_cell
431+ @inbounds p = local_points[particleidx] = points_view[particleidx]
432+ for d in 1 : ndims (neighborhood_search)
433+ @inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
434+ end
435+ end
436+ return n_particles_in_neighbor_cell
437+ end
438+
439+ # @parallel(block) for cell in cells
440+ # for neighbor_cell in neighboring_cells
441+ # @parallel(thread) for neighbor in neighbor_cell
442+ # copy_coordinates_to_localmem(neighbor)
443+ #
444+ # # Make sure all threads finished the copying
445+ # @synchronize
446+ #
447+ # @parallel(thread) for particle in cell
448+ # for neighbor in neighbor_cell
449+ # # This uses the neighbor coordinates from the local memory
450+ # compute(particle, neighbor)
451+ #
452+ # # Make sure all threads finished computing before we continue with copying
453+ # @synchronize
423454@kernel cpu= false function foreach_neighbor_localmem (f:: F , system_coords, neighbor_system_coords,
424455 neighborhood_search, cells, :: Val{MAX} , search_radius) where {F, MAX}
425456 cell_ = @index (Group)
426457 cell = @inbounds Tuple (cells[cell_])
427458 particleidx = @index (Local)
428459 @assert 1 <= particleidx <= MAX
429460
461+ # Coordinate buffer in local memory
430462 local_points = @localmem Int32 MAX
431463 local_neighbor_coords = @localmem eltype (system_coords) (ndims (neighborhood_search), MAX)
432464
433- next_local_points = @localmem Int32 MAX
434- next_local_neighbor_coords = @localmem eltype (system_coords) ( ndims (neighborhood_search), MAX )
465+ points = points_in_cell (cell, neighborhood_search)
466+ n_particles_in_current_cell = length (points )
435467
436- pv = points_in_cell (cell, neighborhood_search)
437- n_particles_in_current_cell = length (pv)
468+ # Extract point coordinates if a point lies on this thread
438469 if particleidx <= n_particles_in_current_cell
439- point = @inbounds pv [particleidx]
470+ point = @inbounds points [particleidx]
440471 point_coords = @inbounds extract_svector (system_coords, Val (ndims (neighborhood_search)),
441472 point)
442473 else
443474 point = zero (Int32)
444475 point_coords = zero (SVector{ndims (neighborhood_search), eltype (system_coords)})
445476 end
446477
447- @inline function stage! (local_points, local_neighbor_coords, neighbor_cell)
448- points_view = points_in_cell (neighbor_cell, neighborhood_search)
449- n_particles_in_neighbor_cell_ = length (points_view)
478+ for neighbor_cell_ in neighboring_cells (cell, neighborhood_search)
479+ neighbor_cell = Tuple (neighbor_cell_)
480+
481+ n_particles_in_neighbor_cell = copy_to_localmem! (local_points, local_neighbor_coords,
482+ neighbor_cell, neighbor_system_coords,
483+ neighborhood_search, particleidx)
484+
485+ # Make sure all threads finished the copying
486+ @synchronize
450487
451- # First use all threads to load the neighbors into local memory in parallel
452- if particleidx <= n_particles_in_neighbor_cell_
453- @inbounds p = local_points[particleidx] = points_view[particleidx]
454- for d in 1 : ndims (neighborhood_search)
455- @inbounds local_neighbor_coords[d, particleidx] = neighbor_system_coords[d, p]
488+ # Now each thread works on one point again
489+ if particleidx <= n_particles_in_current_cell
490+ for local_neighbor in 1 : n_particles_in_neighbor_cell
491+ @inbounds neighbor = local_points[local_neighbor]
492+ @inbounds neighbor_coords = extract_svector (local_neighbor_coords,
493+ Val (ndims (neighborhood_search)),
494+ local_neighbor)
495+
496+ pos_diff = point_coords - neighbor_coords
497+ distance2 = dot (pos_diff, pos_diff)
498+
499+ # TODO periodic
500+
501+ if distance2 <= search_radius^ 2
502+ distance = sqrt (distance2) # TODO: possibly use fastmath
503+
504+ # Inline to avoid loss of performance
505+ # compared to not using `foreach_point_neighbor`.
506+ @inline f (point, neighbor, pos_diff, distance)
507+ end
456508 end
457509 end
458- return n_particles_in_neighbor_cell_
510+
511+ # Make sure all threads finished computing before we continue with copying
512+ @synchronize ()
513+ end
514+ end
515+
516+ # @parallel(block) for cell in cells
517+ # @parallel(thread) for neighbor in first_neighbor_cell
518+ # copy_coordinates_to_localmem!(local_coords, neighbor)
519+ #
520+ # for neighbor_cell in neighboring_cells
521+ # @parallel(thread) for neighbor in next_neighbor_cell (if there is one)
522+ # copy_coordinates_to_localmem!(next_local_coords, neighbor)
523+ #
524+ # # No synchronize needed. The following loop works on `local_coords`.
525+ #
526+ # @parallel(thread) for particle in cell
527+ # for neighbor in neighbor_cell
528+ # # This uses the neighbor coordinates from the local memory
529+ # compute(particle, neighbor)
530+ #
531+ # # Make sure all threads finished computing before we switch variables
532+ # @synchronize
533+ # local_coords, next_local_coords = next_local_coords, local_coords
534+ @kernel cpu= false function foreach_neighbor_double_buffer (f:: F , system_coords, neighbor_system_coords,
535+ neighborhood_search, cells, :: Val{MAX} , search_radius) where {F, MAX}
536+ cell_ = @index (Group)
537+ cell = @inbounds Tuple (cells[cell_])
538+ particleidx = @index (Local)
539+ @assert 1 <= particleidx <= MAX
540+
541+ # Coordinate buffer in local memory
542+ local_points = @localmem Int32 MAX
543+ local_neighbor_coords = @localmem eltype (system_coords) (ndims (neighborhood_search), MAX)
544+
545+ # Next coordinate buffer in local memory
546+ next_local_points = @localmem Int32 MAX
547+ next_local_neighbor_coords = @localmem eltype (system_coords) (ndims (neighborhood_search), MAX)
548+
549+ points = points_in_cell (cell, neighborhood_search)
550+ n_particles_in_current_cell = length (points)
551+
552+ # Extract point coordinates if a point lies on this thread
553+ if particleidx <= n_particles_in_current_cell
554+ point = @inbounds points[particleidx]
555+ point_coords = @inbounds extract_svector (system_coords, Val (ndims (neighborhood_search)),
556+ point)
557+ else
558+ point = zero (Int32)
559+ point_coords = zero (SVector{ndims (neighborhood_search), eltype (system_coords)})
459560 end
460561
461562 neighborhood = neighboring_cells (cell, neighborhood_search)
462563 # (neighbor_cell, state) = iterate(neighborhood)
463- neighbor_cell = first (neighborhood)
564+ neighbor_cell = Tuple (first (neighborhood))
565+
566+ n_particles_in_neighbor_cell = copy_to_localmem! (local_points, local_neighbor_coords,
567+ neighbor_cell, neighbor_system_coords,
568+ neighborhood_search, particleidx)
464569
465- n_particles_in_neighbor_cell = stage! (local_points, local_neighbor_coords, Tuple (neighbor_cell))
466570 @synchronize ()
467571
468572 for neighbor_ in 1 : length (neighborhood)
469- neighbor_cell = @inbounds neighborhood[neighbor_]
573+ neighbor_cell = @inbounds Tuple ( neighborhood[neighbor_])
470574
471575 # while true
472576 # next = iterate(neighborhood, state)
473577 # if next !== nothing
474- # n_particles_in_neighbor_cell = stage!(local_points, local_neighbor_coords, Tuple(neighbor_cell))
475- # @synchronize
578+
476579 if neighbor_ < length (neighborhood)
477- next_neighbor_cell = neighborhood[neighbor_ + 1 ]
580+ next_neighbor_cell = @inbounds Tuple ( neighborhood[neighbor_ + 1 ])
478581 # (next_neighbor_cell, state) = next
479- next_n_particles_in_neighbor_cell = stage! (next_local_points, next_local_neighbor_coords, Tuple (next_neighbor_cell))
582+ next_n_particles_in_neighbor_cell = copy_to_localmem! (next_local_points, next_local_neighbor_coords,
583+ next_neighbor_cell, neighbor_system_coords,
584+ neighborhood_search, particleidx)
480585 end
481586
482587 # Now each thread works on one point again
483588 if particleidx <= n_particles_in_current_cell
484589 for local_neighbor in 1 : n_particles_in_neighbor_cell
485590 @inbounds neighbor = local_points[local_neighbor]
486591 @inbounds neighbor_coords = extract_svector (local_neighbor_coords,
487- Val (ndims (neighborhood_search)), local_neighbor)
592+ Val (ndims (neighborhood_search)),
593+ local_neighbor)
488594
489595 pos_diff = point_coords - neighbor_coords
490596 distance2 = dot (pos_diff, pos_diff)
491597
492598 # TODO periodic
493599
494600 if distance2 <= search_radius^ 2
495- # KernelAbstractions.@print("Point $point, neighbor $neighbor with distance2 $distance2\n")
496601 distance = sqrt (distance2) # TODO: possibly use fastmath
497602
498603 # Inline to avoid loss of performance
@@ -501,17 +606,15 @@ end
501606 end
502607 end
503608 end
609+
504610 # next === nothing && break
505611 neighbor_ >= length (neighborhood) && break
506612 @synchronize ()
613+
507614 # swap variables
508615 n_particles_in_neighbor_cell = next_n_particles_in_neighbor_cell
509- temp = local_points
510- local_points = next_local_points
511- next_local_points = temp
512- temp = local_neighbor_coords
513- local_neighbor_coords = next_local_neighbor_coords
514- next_local_neighbor_coords = temp
616+ local_points, next_local_points = next_local_points, local_points
617+ local_neighbor_coords, next_local_neighbor_coords = next_local_neighbor_coords, local_neighbor_coords
515618 end
516619end
517620
0 commit comments