From 1df4f21321b2234b29e2662b9a27f548ff71c037 Mon Sep 17 00:00:00 2001 From: tharittk Date: Sun, 17 Aug 2025 04:07:34 -0500 Subject: [PATCH 01/27] Buffered Send/Recv --- pylops_mpi/DistributedArray.py | 61 +++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 979882c0..18470592 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -3,6 +3,7 @@ from typing import Any, List, Optional, Tuple, Union, NewType import numpy as np +import os from mpi4py import MPI from pylops.utils import DTypeLike, NDArray from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils @@ -21,6 +22,10 @@ NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) +if int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)): + is_cuda_aware_mpi = True +else: + is_cuda_aware_mpi = False class Partition(Enum): r"""Enum class @@ -529,34 +534,52 @@ def _allgather_subcomm(self, send_buf, recv_buf=None): return self.sub_comm.allgather(send_buf) self.sub_comm.Allgather(send_buf, recv_buf) - def _send(self, send_buf, dest, count=None, tag=None): - """ Send operation + def _send(self, send_buf, dest, count=None, tag=0): + """Send operation """ if deps.nccl_enabled and self.base_comm_nccl: if count is None: - # assuming sending the whole array count = send_buf.size nccl_send(self.base_comm_nccl, send_buf, dest, count) else: - self.base_comm.send(send_buf, dest, tag) - - def _recv(self, recv_buf=None, source=0, count=None, tag=None): - """ Receive operation - """ - # NCCL must be called with recv_buf. Size cannot be inferred from - # other arguments and thus cannot be dynamically allocated - if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None: - if recv_buf is not None: + if is_cuda_aware_mpi or self.engine == "numpy": + # Determine MPI type based on array dtype + mpi_type = MPI._typedict[send_buf.dtype.char] if count is None: - # assuming data will take a space of the whole buffer - count = recv_buf.size - nccl_recv(self.base_comm_nccl, recv_buf, source, count) - return recv_buf + count = send_buf.size + self.base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) else: - raise ValueError("Using recv with NCCL must also supply receiver buffer ") + # Uses CuPy without CUDA-aware MPI + self.base_comm.send(send_buf, dest, tag) + + + def _recv(self, recv_buf=None, source=0, count=None, tag=0): + """Receive operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if recv_buf is None: + raise ValueError("recv_buf must be supplied when using NCCL") + if count is None: + count = recv_buf.size + nccl_recv(self.base_comm_nccl, recv_buf, source, count) + return recv_buf else: - # MPI allows a receiver buffer to be optional and receives as a Python Object - return self.base_comm.recv(source=source, tag=tag) + # NumPy + MPI will benefit from buffered communication regardless of MPI installation + if is_cuda_aware_mpi or self.engine == "numpy": + ncp = get_module(self.engine) + if recv_buf is None: + if count is None: + raise ValueError("Must provide either recv_buf or count for MPI receive") + # Default to int32 works currently because add_ghost_cells() is called + # with recv_buf and is not affected by this branch. The int32 is for when + # dimension or shape-related integers are send/recv + recv_buf = ncp.zeros(count, dtype=ncp.int32) + mpi_type = MPI._typedict[recv_buf.dtype.char] + self.base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + recv_buf = self.base_comm.recv(source=source, tag=tag) + return recv_buf def _nccl_local_shapes(self, masked: bool): """Get the the list of shapes of every GPU in the communicator From 647ce658a149a48f12d8c7f938056a05cba6414e Mon Sep 17 00:00:00 2001 From: tharittk Date: Sun, 17 Aug 2025 05:50:26 -0500 Subject: [PATCH 02/27] Buffered Allreduce --- pylops_mpi/DistributedArray.py | 53 +++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 18470592..dd9fd508 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -483,11 +483,19 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) else: - if recv_buf is None: - return self.base_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf + if is_cuda_aware_mpi or self.engine == "numpy": + ncp = get_module(self.engine) + # mpi_type = MPI._typedict[send_buf.dtype.char] + recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) + self.base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return self.base_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): """Allreduce operation with subcommunicator @@ -495,11 +503,19 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) else: - if recv_buf is None: - return self.sub_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf + if is_cuda_aware_mpi or self.engine == "numpy": + ncp = get_module(self.engine) + # mpi_type = MPI._typedict[send_buf.dtype.char] + recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) + self.sub_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return self.sub_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + self.sub_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf def _allgather(self, send_buf, recv_buf=None): """Allgather operation @@ -717,26 +733,29 @@ def _compute_vector_norm(self, local_array: NDArray, recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64)) elif ord == ncp.inf: # Calculate max followed by max reduction - # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly + # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None: + if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: + # CuPy + non-CUDA-aware MPI: This will call non-buffered communication + # which return a list of object - must be copied back to a GPU memory. recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) - recv_buf = ncp.squeeze(recv_buf, axis=axis) + if self.base_comm_nccl: + recv_buf = ncp.squeeze(recv_buf, axis=axis) elif ord == -ncp.inf: # Calculate min followed by min reduction - # TODO (tharitt): see the comment above in infinity norm + # See the comment above in +infinity norm send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None: + if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN) - recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) - + if self.base_comm_nccl: + recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis)) recv_buf = ncp.power(recv_buf, 1.0 / ord) From 31068f9b65cb3429483646ec993ba151b7a6cb91 Mon Sep 17 00:00:00 2001 From: tharittk Date: Thu, 28 Aug 2025 08:21:22 -0500 Subject: [PATCH 03/27] minor clean up --- pylops_mpi/DistributedArray.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index dd9fd508..9d99fe39 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -485,7 +485,6 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): else: if is_cuda_aware_mpi or self.engine == "numpy": ncp = get_module(self.engine) - # mpi_type = MPI._typedict[send_buf.dtype.char] recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) self.base_comm.Allreduce(send_buf, recv_buf, op) return recv_buf @@ -505,7 +504,6 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): else: if is_cuda_aware_mpi or self.engine == "numpy": ncp = get_module(self.engine) - # mpi_type = MPI._typedict[send_buf.dtype.char] recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) self.sub_comm.Allreduce(send_buf, recv_buf, op) return recv_buf @@ -743,6 +741,9 @@ def _compute_vector_norm(self, local_array: NDArray, recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) + # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL + # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. + # There may be a way to unify it - may be something to do with how we allocate the recv_buf. if self.base_comm_nccl: recv_buf = ncp.squeeze(recv_buf, axis=axis) elif ord == -ncp.inf: From ca558fd70c568e21d09ff36673435a0de5b85ee2 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 7 Sep 2025 20:47:35 +0000 Subject: [PATCH 04/27] feat: WIP DistributedMix A new DistributedMix class is create with the aim of simpflify and unify all comm. calls in both DistributedArray and operators (further hiding away all implementation details). --- pylops_mpi/Distributed.py | 45 ++++++++++++++++++++ pylops_mpi/DistributedArray.py | 66 ++++------------------------- pylops_mpi/basicoperators/VStack.py | 16 ++++--- pylops_mpi/utils/deps.py | 4 ++ 4 files changed, 68 insertions(+), 63 deletions(-) create mode 100644 pylops_mpi/Distributed.py diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py new file mode 100644 index 00000000..dccaf6a6 --- /dev/null +++ b/pylops_mpi/Distributed.py @@ -0,0 +1,45 @@ +from typing import Any, NewType + +from mpi4py import MPI +from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils +from pylops_mpi.utils._mpi import mpi_allreduce +from pylops_mpi.utils import deps + +cupy_message = pylops_deps.cupy_import("the DistributedArray module") +nccl_message = deps.nccl_import("the DistributedArray module") + +if nccl_message is None and cupy_message is None: + from pylops_mpi.utils._nccl import ( + nccl_allgather, nccl_allreduce, + nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, + _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv + ) + + +class DistributedMixIn: + r"""Distributed Mixin class + + This class implements all methods associated with communication primitives + from MPI and NCCL. It is mostly charged to identifying which commuicator + to use and whether the buffered or object MPI primitives should be used + (the former in the case of NumPy arrays or CuPy arrays when a CUDA-Aware + MPI installation is available, the latter with CuPy arrays when a CUDA-Aware + MPI installation is not available). + """ + def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + """Allreduce operation + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) + else: + return mpi_allreduce(self.base_comm, send_buf, + recv_buf, self.engine, op) + + def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + """Allreduce operation with subcommunicator + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) + else: + return mpi_allreduce(self.sub_comm, send_buf, + recv_buf, self.engine, op) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 9d99fe39..6fd3ee95 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -3,12 +3,13 @@ from typing import Any, List, Optional, Tuple, Union, NewType import numpy as np -import os from mpi4py import MPI +from pylops_mpi.Distributed import DistributedMixIn from pylops.utils import DTypeLike, NDArray from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops.utils._internal import _value_or_sized_to_tuple from pylops.utils.backend import get_array_module, get_module, get_module_name +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -22,10 +23,6 @@ NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) -if int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)): - is_cuda_aware_mpi = True -else: - is_cuda_aware_mpi = False class Partition(Enum): r"""Enum class @@ -104,7 +101,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = return sub_comm -class DistributedArray: +class DistributedArray(DistributedMixIn): r"""Distributed Numpy Arrays Multidimensional NumPy-like distributed arrays. @@ -477,44 +474,6 @@ def _check_mask(self, dist_array): if not np.array_equal(self.mask, dist_array.mask): raise ValueError("Mask of both the arrays must be same") - def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) - else: - if is_cuda_aware_mpi or self.engine == "numpy": - ncp = get_module(self.engine) - recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - else: - # CuPy with non-CUDA-aware MPI - if recv_buf is None: - return self.base_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - - def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation with subcommunicator - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) - else: - if is_cuda_aware_mpi or self.engine == "numpy": - ncp = get_module(self.engine) - recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - else: - # CuPy with non-CUDA-aware MPI - if recv_buf is None: - return self.sub_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - def _allgather(self, send_buf, recv_buf=None): """Allgather operation """ @@ -556,16 +515,9 @@ def _send(self, send_buf, dest, count=None, tag=0): count = send_buf.size nccl_send(self.base_comm_nccl, send_buf, dest, count) else: - if is_cuda_aware_mpi or self.engine == "numpy": - # Determine MPI type based on array dtype - mpi_type = MPI._typedict[send_buf.dtype.char] - if count is None: - count = send_buf.size - self.base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) - else: - # Uses CuPy without CUDA-aware MPI - self.base_comm.send(send_buf, dest, tag) - + mpi_send(self.base_comm, + send_buf, dest, count, tag=tag, + engine=self.engine) def _recv(self, recv_buf=None, source=0, count=None, tag=0): """Receive operation @@ -579,7 +531,7 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0): return recv_buf else: # NumPy + MPI will benefit from buffered communication regardless of MPI installation - if is_cuda_aware_mpi or self.engine == "numpy": + if deps.cuda_aware_mpi_enabled or self.engine == "numpy": ncp = get_module(self.engine) if recv_buf is None: if count is None: @@ -734,7 +686,7 @@ def _compute_vector_norm(self, local_array: NDArray, # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: + if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) @@ -750,7 +702,7 @@ def _compute_vector_norm(self, local_array: NDArray, # Calculate min followed by min reduction # See the comment above in +infinity norm send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi: + if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 58581565..f6d5b198 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -15,6 +15,7 @@ Partition, StackedDistributedArray ) +from pylops_mpi.Distributed import DistributedMixIn from pylops_mpi.utils.decorators import reshaped from pylops_mpi.utils import deps @@ -25,7 +26,7 @@ from pylops_mpi.utils._nccl import nccl_allreduce -class MPIVStack(MPILinearOperator): +class MPIVStack(DistributedMixIn, MPILinearOperator): r"""MPI VStack Operator Create a vertical stack of a set of linear operators using MPI. Each rank must @@ -141,16 +142,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, + # TODO: consider adding base_comm, base_comm_nccl, engine to the + # input parameters of _allreduce instead of relying on self + self.base_comm, self.base_comm_nccl, self.engine = \ + x.base_comm, x.base_comm_nccl, x.engine + y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, + partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - if deps.nccl_enabled and x.base_comm_nccl: - y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM) - else: - y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) + y[:] = self._allreduce(y1, op=MPI.SUM) return y diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py index 9d983f60..c9dc4aa3 100644 --- a/pylops_mpi/utils/deps.py +++ b/pylops_mpi/utils/deps.py @@ -39,6 +39,10 @@ def nccl_import(message: Optional[str] = None) -> str: return nccl_message +cuda_aware_mpi_enabled: bool = ( + True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1) == 1) else False +) + nccl_enabled: bool = ( True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False ) From 64854bbe88061f097350d687ae3ad15e30a4e8c9 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 7 Sep 2025 20:53:18 +0000 Subject: [PATCH 05/27] feat: added _mpi file with actual mpi comm. implementations --- pylops_mpi/utils/_mpi.py | 95 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 pylops_mpi/utils/_mpi.py diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py new file mode 100644 index 00000000..2d08245e --- /dev/null +++ b/pylops_mpi/utils/_mpi.py @@ -0,0 +1,95 @@ +__all__ = [ + # "mpi_allgather", + "mpi_allreduce", + # "mpi_bcast", + # "mpi_asarray", + "mpi_send", + # "mpi_recv", +] + +from typing import Optional + +import numpy as np +from mpi4py import MPI +from pylops.utils.backend import get_module +from pylops_mpi.utils import deps + + +def mpi_allreduce(base_comm: MPI.Comm, + send_buf, recv_buf=None, + engine: Optional[str] = "numpy", + op: MPI.Op = MPI.SUM) -> np.ndarray: + """MPI_Allreduce/allreduce + + Dispatch allreduce routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The data buffer from the local GPU to be reduced. + recv_buf : :obj:`cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + op : :obj:mpi4py.MPI.Op, optional + The reduction operation to apply. Defaults to MPI.SUM. + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all GPUs. + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + ncp = get_module(engine) + recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) + base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + + +def mpi_send(base_comm: MPI.Comm, + send_buf, dest, count, tag=0, + engine: Optional[str] = "numpy", + ) -> None: + """MPI_Send/send + + Dispatch send routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The array containing data to send. + dest: :obj:`int` + The rank of the destination GPU device. + count : :obj:`int` + Number of elements to send from `send_buf`. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + # Determine MPI type based on array dtype + mpi_type = MPI._typedict[send_buf.dtype.char] + if count is None: + count = send_buf.size + base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + base_comm.send(send_buf, dest, tag) From 838ed0b98dcbb0afdab5413685bb33608c794ba5 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 7 Sep 2025 21:10:47 +0000 Subject: [PATCH 06/27] feat: moved _send to Distributed --- pylops_mpi/Distributed.py | 14 +++++++++++++- pylops_mpi/DistributedArray.py | 12 ------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index dccaf6a6..7384c40f 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -2,7 +2,7 @@ from mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils -from pylops_mpi.utils._mpi import mpi_allreduce +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -43,3 +43,15 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): else: return mpi_allreduce(self.sub_comm, send_buf, recv_buf, self.engine, op) + + def _send(self, send_buf, dest, count=None, tag=0): + """Send operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if count is None: + count = send_buf.size + nccl_send(self.base_comm_nccl, send_buf, dest, count) + else: + mpi_send(self.base_comm, + send_buf, dest, count, tag=tag, + engine=self.engine) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 6fd3ee95..cac36f6a 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -507,18 +507,6 @@ def _allgather_subcomm(self, send_buf, recv_buf=None): return self.sub_comm.allgather(send_buf) self.sub_comm.Allgather(send_buf, recv_buf) - def _send(self, send_buf, dest, count=None, tag=0): - """Send operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if count is None: - count = send_buf.size - nccl_send(self.base_comm_nccl, send_buf, dest, count) - else: - mpi_send(self.base_comm, - send_buf, dest, count, tag=tag, - engine=self.engine) - def _recv(self, recv_buf=None, source=0, count=None, tag=0): """Receive operation """ From ab97e3dbfaaca6ff69dfc600a9aecfe7d5f93a3d Mon Sep 17 00:00:00 2001 From: tharittk Date: Fri, 12 Sep 2025 01:47:41 -0500 Subject: [PATCH 07/27] mpi_recv for MixIn --- pylops_mpi/utils/_mpi.py | 42 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index 2d08245e..c635acc4 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -75,7 +75,7 @@ def mpi_send(base_comm: MPI.Comm, send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` The array containing data to send. dest: :obj:`int` - The rank of the destination GPU device. + The rank of the destination CPU/GPU device. count : :obj:`int` Number of elements to send from `send_buf`. tag : :obj:`int` @@ -93,3 +93,43 @@ def mpi_send(base_comm: MPI.Comm, else: # Uses CuPy without CUDA-aware MPI base_comm.send(send_buf, dest, tag) + +def mpi_recv(base_comm: MPI.Comm, + recv_buf=None, source=0, count=None, tag=0, + engine: Optional[str] = "numpy") -> np.ndarray: + """ MPI_Recv/recv + Dispatch receive routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional + The buffered array to receive data. + source : :obj:`int` + The rank of the sending CPU/GPU device. + count : :obj:`int` + Number of elements to receive. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + ncp = get_module(engine) + if recv_buf is None: + if count is None: + raise ValueError("Must provide either recv_buf or count for MPI receive") + # Default to int32 works currently because add_ghost_cells() is called + # with recv_buf and is not affected by this branch. The int32 is for when + # dimension or shape-related integers are send/recv + recv_buf = ncp.zeros(count, dtype=ncp.int32) + mpi_type = MPI._typedict[recv_buf.dtype.char] + base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + recv_buf = base_comm.recv(source=source, tag=tag) + return recv_buf + From dbe1f30e3ab3d1d17f4ef2d551ed18df78e84a7f Mon Sep 17 00:00:00 2001 From: tharittk Date: Fri, 12 Sep 2025 02:53:27 -0500 Subject: [PATCH 08/27] MixIn for allgather. --- pylops_mpi/Distributed.py | 55 +++++++++++++-- pylops_mpi/DistributedArray.py | 64 +----------------- pylops_mpi/utils/_mpi.py | 109 ++++++++++++++++++++++++++++-- pylops_mpi/utils/_nccl.py | 89 +----------------------- tests_nccl/test_ncclutils_nccl.py | 7 +- 5 files changed, 163 insertions(+), 161 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 7384c40f..cc86e7d4 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,8 +1,8 @@ -from typing import Any, NewType +from typing import Any, NewType, Tuple from mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils -from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -11,11 +11,9 @@ if nccl_message is None and cupy_message is None: from pylops_mpi.utils._nccl import ( nccl_allgather, nccl_allreduce, - nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, - _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv + nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv ) - class DistributedMixIn: r"""Distributed Mixin class @@ -44,6 +42,36 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): return mpi_allreduce(self.sub_comm, send_buf, recv_buf, self.engine, op) + def _allgather(self, send_buf, recv_buf=None): + """Allgather operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if isinstance(send_buf, (tuple, list, int)): + return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) + else: + send_shapes = self.base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") + raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) + return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) + else: + if isinstance(send_buf, (tuple, list, int)): + return self.base_comm.allgather(send_buf) + return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine) + + def _allgather_subcomm(self, send_buf, recv_buf=None): + """Allgather operation with subcommunicator + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + if isinstance(send_buf, (tuple, list, int)): + return nccl_allgather(self.sub_comm, send_buf, recv_buf) + else: + send_shapes = self._allgather_subcomm(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") + raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv) + return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) + else: + return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine) + def _send(self, send_buf, dest, count=None, tag=0): """Send operation """ @@ -55,3 +83,20 @@ def _send(self, send_buf, dest, count=None, tag=0): mpi_send(self.base_comm, send_buf, dest, count, tag=tag, engine=self.engine) + + def _recv(self, recv_buf=None, source=0, count=None, tag=0): + """Receive operation + """ + if deps.nccl_enabled and self.base_comm_nccl: + if recv_buf is None: + raise ValueError("recv_buf must be supplied when using NCCL") + if count is None: + count = recv_buf.size + nccl_recv(self.base_comm_nccl, recv_buf, source, count) + return recv_buf + else: + return mpi_recv(self.base_comm, + recv_buf, source, count, tag=tag, + engine=self.engine) + + diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index cac36f6a..d3cb70f0 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -9,14 +9,13 @@ from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops.utils._internal import _value_or_sized_to_tuple from pylops.utils.backend import get_array_module, get_module, get_module_name -from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv + from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -474,67 +473,6 @@ def _check_mask(self, dist_array): if not np.array_equal(self.mask, dist_array.mask): raise ValueError("Mask of both the arrays must be same") - def _allgather(self, send_buf, recv_buf=None): - """Allgather operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) - else: - send_shapes = self.base_comm.allgather(send_buf.shape) - (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes) - raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) - return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes) - else: - if recv_buf is None: - return self.base_comm.allgather(send_buf) - self.base_comm.Allgather(send_buf, recv_buf) - return recv_buf - - def _allgather_subcomm(self, send_buf, recv_buf=None): - """Allgather operation with subcommunicator - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.sub_comm, send_buf, recv_buf) - else: - send_shapes = self._allgather_subcomm(send_buf.shape) - (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes) - raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv) - return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes) - else: - if recv_buf is None: - return self.sub_comm.allgather(send_buf) - self.sub_comm.Allgather(send_buf, recv_buf) - - def _recv(self, recv_buf=None, source=0, count=None, tag=0): - """Receive operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if recv_buf is None: - raise ValueError("recv_buf must be supplied when using NCCL") - if count is None: - count = recv_buf.size - nccl_recv(self.base_comm_nccl, recv_buf, source, count) - return recv_buf - else: - # NumPy + MPI will benefit from buffered communication regardless of MPI installation - if deps.cuda_aware_mpi_enabled or self.engine == "numpy": - ncp = get_module(self.engine) - if recv_buf is None: - if count is None: - raise ValueError("Must provide either recv_buf or count for MPI receive") - # Default to int32 works currently because add_ghost_cells() is called - # with recv_buf and is not affected by this branch. The int32 is for when - # dimension or shape-related integers are send/recv - recv_buf = ncp.zeros(count, dtype=ncp.int32) - mpi_type = MPI._typedict[recv_buf.dtype.char] - self.base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) - else: - # Uses CuPy without CUDA-aware MPI - recv_buf = self.base_comm.recv(source=source, tag=tag) - return recv_buf - def _nccl_local_shapes(self, masked: bool): """Get the the list of shapes of every GPU in the communicator """ diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index c635acc4..33cfe270 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -1,19 +1,100 @@ __all__ = [ - # "mpi_allgather", + "mpi_allgather", "mpi_allreduce", # "mpi_bcast", # "mpi_asarray", "mpi_send", - # "mpi_recv", + "mpi_recv", + "_prepare_allgather_inputs", + "_unroll_allgather_recv" ] -from typing import Optional +from typing import Optional, Tuple import numpy as np from mpi4py import MPI from pylops.utils.backend import get_module from pylops_mpi.utils import deps +# TODO: return type annotation for both cupy and numpy +def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): + r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) + + Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. + Therefore, padding is required when the array is not evenly partitioned across + all the ranks. The padding is applied such that the each dimension of the sending buffers + is equal to the max size of that dimension across all ranks. + + Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size + + Parameters + ---------- + send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like + The data buffer from the local GPU to be sent for allgather. + send_buf_shapes: :obj:`list` + A list of shapes for each GPU send_buf (used to calculate padding size) + engine : :obj:`str` + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + send_buf: :obj:`cupy.ndarray` + A buffer containing the data and padded elements to be sent by this rank. + recv_buf : :obj:`cupy.ndarray` + An empty, padded buffer to gather data from all GPUs. + """ + ncp = get_module(engine) + sizes_each_dim = list(zip(*send_buf_shapes)) + send_shape = tuple(map(max, sizes_each_dim)) + pad_size = [ + (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) + ] + + send_buf = ncp.pad( + send_buf, pad_size, mode="constant", constant_values=0 + ) + + ndev = len(send_buf_shapes) + recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) + + return send_buf, recv_buf + + +def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: + r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL) + + Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays + Each GPU may send array with a different shape, so the return type has to be a list of array + instead of the concatenated array. + + Parameters + ---------- + recv_buf: :obj:`cupy.ndarray` or array-like + The data buffer returned from nccl_allgather call + padded_send_buf_shape: :obj:`tuple`:int + The size of send_buf after padding used in nccl_allgather + send_buf_shapes: :obj:`list` + A list of original shapes for each GPU send_buf prior to padding + + Returns + ------- + chunks: :obj:`list` + A list of `cupy.ndarray` from each GPU with the padded element removed + """ + ndev = len(send_buf_shapes) + # extract an individual array from each device + chunk_size = np.prod(padded_send_buf_shape) + chunks = [ + recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) + ] + + # Remove padding from each array: the padded value may appear somewhere + # in the middle of the flat array and thus the reshape and slicing for each dimension is required + for i in range(ndev): + slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) + chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] + + return chunks def mpi_allreduce(base_comm: MPI.Comm, send_buf, recv_buf=None, @@ -57,7 +138,27 @@ def mpi_allreduce(base_comm: MPI.Comm, # For MIN and MAX which require recv_buf base_comm.Allreduce(send_buf, recv_buf, op) return recv_buf - + + +def mpi_allgather(base_comm: MPI.Comm, + send_buf, recv_buf=None, + engine: Optional[str] = "numpy", + ) -> np.ndarray: + + if deps.cuda_aware_mpi_enabled or engine == "numpy": + send_shapes = base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) + recv_buffer_to_use = recv_buf if recv_buf else padded_recv + base_comm.Allgather(padded_send, recv_buffer_to_use) + return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) + + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allgather(send_buf) + base_comm.Allgather(send_buf, recv_buf) + return recv_buf + def mpi_send(base_comm: MPI.Comm, send_buf, dest, count, tag=0, diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 19c09922..0eb6cde1 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -1,6 +1,4 @@ __all__ = [ - "_prepare_nccl_allgather_inputs", - "_unroll_nccl_allgather_recv", "_nccl_sync", "initialize_nccl_comm", "nccl_split", @@ -13,12 +11,11 @@ ] from enum import IntEnum -from typing import Tuple from mpi4py import MPI import os -import numpy as np import cupy as cp import cupy.cuda.nccl as nccl +from pylops_mpi.utils._mpi import _prepare_allgather_inputs, _unroll_allgather_recv cupy_to_nccl_dtype = { "float32": nccl.NCCL_FLOAT32, @@ -69,86 +66,6 @@ def _nccl_sync(): return cp.cuda.runtime.deviceSynchronize() - -def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]: - r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) - - NCCL's allGather requires the sending buffer to have the same size for every device. - Therefore, padding is required when the array is not evenly partitioned across - all the ranks. The padding is applied such that the each dimension of the sending buffers - is equal to the max size of that dimension across all ranks. - - Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size - - Parameters - ---------- - send_buf : :obj:`cupy.ndarray` or array-like - The data buffer from the local GPU to be sent for allgather. - send_buf_shapes: :obj:`list` - A list of shapes for each GPU send_buf (used to calculate padding size) - - Returns - ------- - send_buf: :obj:`cupy.ndarray` - A buffer containing the data and padded elements to be sent by this rank. - recv_buf : :obj:`cupy.ndarray` - An empty, padded buffer to gather data from all GPUs. - """ - sizes_each_dim = list(zip(*send_buf_shapes)) - send_shape = tuple(map(max, sizes_each_dim)) - pad_size = [ - (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) - ] - - send_buf = cp.pad( - send_buf, pad_size, mode="constant", constant_values=0 - ) - - # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred - ndev = len(send_buf_shapes) - recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) - - return send_buf, recv_buf - - -def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: - """Unrolll recv_buf after NCCL allgather (nccl_allgather) - - Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays - Each GPU may send array with a different shape, so the return type has to be a list of array - instead of the concatenated array. - - Parameters - ---------- - recv_buf: :obj:`cupy.ndarray` or array-like - The data buffer returned from nccl_allgather call - padded_send_buf_shape: :obj:`tuple`:int - The size of send_buf after padding used in nccl_allgather - send_buf_shapes: :obj:`list` - A list of original shapes for each GPU send_buf prior to padding - - Returns - ------- - chunks: :obj:`list` - A list of `cupy.ndarray` from each GPU with the padded element removed - """ - - ndev = len(send_buf_shapes) - # extract an individual array from each device - chunk_size = np.prod(padded_send_buf_shape) - chunks = [ - recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) - ] - - # Remove padding from each array: the padded value may appear somewhere - # in the middle of the flat array and thus the reshape and slicing for each dimension is required - for i in range(ndev): - slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) - chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] - - return chunks - - def mpi_op_to_nccl(mpi_op) -> NcclOp: """ Map MPI reduction operation to NCCL equivalent @@ -363,9 +280,9 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray: Global array gathered from all GPUs and concatenated along `axis`. """ - send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes) + send_buf, recv_buf = _prepare_allgather_inputs(local_array, local_shapes, engine="cupy") nccl_allgather(nccl_comm, send_buf, recv_buf) - chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes) + chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, local_shapes) # combine back to single global array return cp.concatenate(chunks, axis=axis) diff --git a/tests_nccl/test_ncclutils_nccl.py b/tests_nccl/test_ncclutils_nccl.py index 21b28ca3..52502afc 100644 --- a/tests_nccl/test_ncclutils_nccl.py +++ b/tests_nccl/test_ncclutils_nccl.py @@ -8,7 +8,8 @@ from numpy.testing import assert_allclose import pytest -from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv +from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather +from pylops_mpi.utils._mpi import _prepare_allgather_inputs, _unroll_allgather_recv np.random.seed(42) @@ -83,9 +84,9 @@ def test_allgather_differentsize_withrecbuf(par): # Gathered array send_shapes = MPI.COMM_WORLD.allgather(local_array.shape) - send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, send_shapes) + send_buf, recv_buf = _prepare_allgather_inputs(local_array, send_shapes, engine="cupy") recv_buf = nccl_allgather(nccl_comm, send_buf, recv_buf) - chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, send_shapes) + chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, send_shapes) gathered_array = cp.concatenate(chunks) # Compare with global array created in rank0 From a08924bf18389a666d4e96b1eef845d76c6e2b46 Mon Sep 17 00:00:00 2001 From: tharittk Date: Fri, 12 Sep 2025 03:07:05 -0500 Subject: [PATCH 09/27] fix flake8 --- pylops_mpi/Distributed.py | 16 +++++-------- pylops_mpi/DistributedArray.py | 6 ++--- pylops_mpi/utils/_mpi.py | 41 +++++++++++++++++----------------- pylops_mpi/utils/_nccl.py | 1 + 4 files changed, 30 insertions(+), 34 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index cc86e7d4..8876bfc1 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,5 +1,3 @@ -from typing import Any, NewType, Tuple - from mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv @@ -10,10 +8,10 @@ if nccl_message is None and cupy_message is None: from pylops_mpi.utils._nccl import ( - nccl_allgather, nccl_allreduce, - nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv + nccl_allgather, nccl_allreduce, nccl_send, nccl_recv ) + class DistributedMixIn: r"""Distributed Mixin class @@ -30,7 +28,7 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) else: - return mpi_allreduce(self.base_comm, send_buf, + return mpi_allreduce(self.base_comm, send_buf, recv_buf, self.engine, op) def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): @@ -39,7 +37,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) else: - return mpi_allreduce(self.sub_comm, send_buf, + return mpi_allreduce(self.sub_comm, send_buf, recv_buf, self.engine, op) def _allgather(self, send_buf, recv_buf=None): @@ -96,7 +94,5 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0): return recv_buf else: return mpi_recv(self.base_comm, - recv_buf, source, count, tag=tag, - engine=self.engine) - - + recv_buf, source, count, tag=tag, + engine=self.engine) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index d3cb70f0..faa780ac 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -15,7 +15,7 @@ nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split + from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -613,14 +613,14 @@ def _compute_vector_norm(self, local_array: NDArray, # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64) if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: - # CuPy + non-CUDA-aware MPI: This will call non-buffered communication + # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL - # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. + # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. # There may be a way to unify it - may be something to do with how we allocate the recv_buf. if self.base_comm_nccl: recv_buf = ncp.squeeze(recv_buf, axis=axis) diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index 33cfe270..e3520c94 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -9,13 +9,14 @@ "_unroll_allgather_recv" ] -from typing import Optional, Tuple +from typing import Optional import numpy as np from mpi4py import MPI from pylops.utils.backend import get_module from pylops_mpi.utils import deps + # TODO: return type annotation for both cupy and numpy def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) @@ -33,7 +34,7 @@ def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): The data buffer from the local GPU to be sent for allgather. send_buf_shapes: :obj:`list` A list of shapes for each GPU send_buf (used to calculate padding size) - engine : :obj:`str` + engine : :obj:`str` Engine used to store array (``numpy`` or ``cupy``) Returns @@ -96,20 +97,21 @@ def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> return chunks + def mpi_allreduce(base_comm: MPI.Comm, - send_buf, recv_buf=None, + send_buf, recv_buf=None, engine: Optional[str] = "numpy", op: MPI.Op = MPI.SUM) -> np.ndarray: - """MPI_Allreduce/allreduce - - Dispatch allreduce routine based on type of input and availability of + """MPI_Allreduce/allreduce + + Dispatch allreduce routine based on type of input and availability of CUDA-Aware MPI Parameters ---------- base_comm : :obj:`MPI.Comm` Base MPI Communicator. - send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` The data buffer from the local GPU to be reduced. recv_buf : :obj:`cupy.ndarray`, optional The buffer to store the result of the reduction. If None, @@ -121,10 +123,10 @@ def mpi_allreduce(base_comm: MPI.Comm, Returns ------- - recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` A buffer containing the result of the reduction, broadcasted to all GPUs. - + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": ncp = get_module(engine) @@ -141,9 +143,8 @@ def mpi_allreduce(base_comm: MPI.Comm, def mpi_allgather(base_comm: MPI.Comm, - send_buf, recv_buf=None, - engine: Optional[str] = "numpy", - ) -> np.ndarray: + send_buf, recv_buf=None, + engine: Optional[str] = "numpy") -> np.ndarray: if deps.cuda_aware_mpi_enabled or engine == "numpy": send_shapes = base_comm.allgather(send_buf.shape) @@ -165,15 +166,15 @@ def mpi_send(base_comm: MPI.Comm, engine: Optional[str] = "numpy", ) -> None: """MPI_Send/send - - Dispatch send routine based on type of input and availability of + + Dispatch send routine based on type of input and availability of CUDA-Aware MPI Parameters ---------- base_comm : :obj:`MPI.Comm` Base MPI Communicator. - send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` The array containing data to send. dest: :obj:`int` The rank of the destination CPU/GPU device. @@ -183,7 +184,6 @@ def mpi_send(base_comm: MPI.Comm, Tag of the message to be sent. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - """ if deps.cuda_aware_mpi_enabled or engine == "numpy": # Determine MPI type based on array dtype @@ -195,11 +195,12 @@ def mpi_send(base_comm: MPI.Comm, # Uses CuPy without CUDA-aware MPI base_comm.send(send_buf, dest, tag) + def mpi_recv(base_comm: MPI.Comm, - recv_buf=None, source=0, count=None, tag=0, - engine: Optional[str] = "numpy") -> np.ndarray: + recv_buf=None, source=0, count=None, tag=0, + engine: Optional[str] = "numpy") -> np.ndarray: """ MPI_Recv/recv - Dispatch receive routine based on type of input and availability of + Dispatch receive routine based on type of input and availability of CUDA-Aware MPI Parameters @@ -216,7 +217,6 @@ def mpi_recv(base_comm: MPI.Comm, Tag of the message to be sent. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - """ if deps.cuda_aware_mpi_enabled or engine == "numpy": ncp = get_module(engine) @@ -233,4 +233,3 @@ def mpi_recv(base_comm: MPI.Comm, # Uses CuPy without CUDA-aware MPI recv_buf = base_comm.recv(source=source, tag=tag) return recv_buf - diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 0eb6cde1..5f297531 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -66,6 +66,7 @@ def _nccl_sync(): return cp.cuda.runtime.deviceSynchronize() + def mpi_op_to_nccl(mpi_op) -> NcclOp: """ Map MPI reduction operation to NCCL equivalent From b8bcd295c946967923d823693ec0447f4a8c3ef3 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:54:55 +0000 Subject: [PATCH 10/27] feat: added _bcast to DistributedMixIn and added comms as input for all methods --- pylops_mpi/Distributed.py | 54 +++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 8876bfc1..7e940b84 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,6 +1,6 @@ from mpi4py import MPI from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils -from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv from pylops_mpi.utils import deps cupy_message = pylops_deps.cupy_import("the DistributedArray module") @@ -8,7 +8,7 @@ if nccl_message is None and cupy_message is None: from pylops_mpi.utils._nccl import ( - nccl_allgather, nccl_allreduce, nccl_send, nccl_recv + nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv ) @@ -22,39 +22,45 @@ class DistributedMixIn: MPI installation is available, the latter with CuPy arrays when a CUDA-Aware MPI installation is not available). """ - def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + def _allreduce(self, base_comm, base_comm_nccl, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + engine="numpy"): """Allreduce operation """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) + if deps.nccl_enabled and base_comm_nccl is not None: + return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op) else: - return mpi_allreduce(self.base_comm, send_buf, - recv_buf, self.engine, op) + return mpi_allreduce(base_comm, send_buf, + recv_buf, engine, op) - def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): + def _allreduce_subcomm(self, sub_comm, base_comm_nccl, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + engine="numpy"): """Allreduce operation with subcommunicator """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) + if deps.nccl_enabled and base_comm_nccl is not None: + return nccl_allreduce(sub_comm, send_buf, recv_buf, op) else: - return mpi_allreduce(self.sub_comm, send_buf, - recv_buf, self.engine, op) + return mpi_allreduce(sub_comm, send_buf, + recv_buf, engine, op) - def _allgather(self, send_buf, recv_buf=None): + def _allgather(self, base_comm, base_comm_nccl, + send_buf, recv_buf=None, + engine="numpy"): """Allgather operation """ - if deps.nccl_enabled and self.base_comm_nccl: + if deps.nccl_enabled and base_comm_nccl is not None: if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) + return nccl_allgather(base_comm_nccl, send_buf, recv_buf) else: - send_shapes = self.base_comm.allgather(send_buf.shape) + send_shapes = base_comm.allgather(send_buf.shape) (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") - raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) + raw_recv = nccl_allgather(base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) else: if isinstance(send_buf, (tuple, list, int)): - return self.base_comm.allgather(send_buf) - return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine) + return base_comm.allgather(send_buf) + return mpi_allgather(base_comm, send_buf, recv_buf, engine) def _allgather_subcomm(self, send_buf, recv_buf=None): """Allgather operation with subcommunicator @@ -70,6 +76,16 @@ def _allgather_subcomm(self, send_buf, recv_buf=None): else: return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine) + def _bcast(self, local_array, index, value): + """BCast operation + """ + if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + nccl_bcast(self.base_comm_nccl, local_array, index, value) + else: + # self.local_array[index] = self.base_comm.bcast(value) + mpi_bcast(self.base_comm, self.rank, self.local_array, index, value, + engine=self.engine) + def _send(self, send_buf, dest, count=None, tag=0): """Send operation """ From f362436909db7c856e6fcde26f538455c7131879 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:56:18 +0000 Subject: [PATCH 11/27] feat: adapted all comm calls in DistributedArray to new method signatures --- pylops_mpi/DistributedArray.py | 49 +++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index faa780ac..75b66c7e 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -15,7 +15,7 @@ nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split + from pylops_mpi.utils._nccl import nccl_asarray, nccl_split from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -204,10 +204,7 @@ def __setitem__(self, index, value): the specified index positions. """ if self.partition is Partition.BROADCAST: - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - nccl_bcast(self.base_comm_nccl, self.local_array, index, value) - else: - self.local_array[index] = self.base_comm.bcast(value) + self._bcast(self.local_array, index, value) else: self.local_array[index] = value @@ -343,7 +340,9 @@ def local_shapes(self): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return self._nccl_local_shapes(False) else: - return self._allgather(self.local_shape) + return self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_shape) @property def sub_comm(self): @@ -383,7 +382,10 @@ def asarray(self, masked: bool = False): if masked: final_array = self._allgather_subcomm(self.local_array) else: - final_array = self._allgather(self.local_array) + final_array = self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_array, + engine=self.engine) return np.concatenate(final_array, axis=self.axis) @classmethod @@ -433,6 +435,7 @@ def to_dist(cls, x: NDArray, else: slices = [slice(None)] * x.ndim local_shapes = np.append([0], dist_array._allgather( + base_comm, base_comm_nccl, dist_array.local_shape[axis])) sum_shapes = np.cumsum(local_shapes) slices[axis] = slice(sum_shapes[dist_array.rank], @@ -480,7 +483,9 @@ def _nccl_local_shapes(self, masked: bool): if masked: all_tuples = self._allgather_subcomm(self.local_shape).get() else: - all_tuples = self._allgather(self.local_shape).get() + all_tuples = self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_shape).get() # NCCL returns the flat array that packs every tuple as 1-dimensional array # unpack each tuple from each rank tuple_len = len(self.local_shape) @@ -578,7 +583,9 @@ def dot(self, dist_array): y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \ if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array # Flatten the local arrays and calculate dot product - return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten())) + return self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.dot(x.local_array.flatten(), y.local_array.flatten()), + engine=self.engine) def _compute_vector_norm(self, local_array: NDArray, axis: int, ord: Optional[int] = None): @@ -606,7 +613,9 @@ def _compute_vector_norm(self, local_array: NDArray, raise ValueError(f"norm-{ord} not possible for vectors") elif ord == 0: # Count non-zero then sum reduction - recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64)) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64), + engine=self.engine) elif ord == ncp.inf: # Calculate max followed by max reduction # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly @@ -615,10 +624,14 @@ def _compute_vector_norm(self, local_array: NDArray, if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. - recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf.get(), recv_buf.get(), + op=MPI.MAX, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf, recv_buf, op=MPI.MAX, + engine=self.engine) # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. # There may be a way to unify it - may be something to do with how we allocate the recv_buf. @@ -629,14 +642,20 @@ def _compute_vector_norm(self, local_array: NDArray, # See the comment above in +infinity norm send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: - recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf.get(), recv_buf.get(), + op=MPI.MIN, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf, recv_buf, + op=MPI.MIN, engine=self.engine) if self.base_comm_nccl: recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis)) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis), + engine=self.engine) recv_buf = ncp.power(recv_buf, 1.0 / ord) return recv_buf From 693f0786dd10db78486136f5942eb7d958bb06e8 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:56:48 +0000 Subject: [PATCH 12/27] feat: adapted all comm calls in VStack to new method signatures --- pylops_mpi/basicoperators/VStack.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index f6d5b198..174e9739 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -6,7 +6,6 @@ from pylops import LinearOperator from pylops.utils import DTypeLike from pylops.utils.backend import get_module -from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi import ( MPILinearOperator, @@ -17,13 +16,6 @@ ) from pylops_mpi.Distributed import DistributedMixIn from pylops_mpi.utils.decorators import reshaped -from pylops_mpi.utils import deps - -cupy_message = pylops_deps.cupy_import("the VStack module") -nccl_message = deps.nccl_import("the VStack module") - -if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allreduce class MPIVStack(DistributedMixIn, MPILinearOperator): @@ -142,19 +134,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - # TODO: consider adding base_comm, base_comm_nccl, engine to the - # input parameters of _allreduce instead of relying on self - self.base_comm, self.base_comm_nccl, self.engine = \ - x.base_comm, x.base_comm_nccl, x.engine - y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, + y = DistributedArray(global_shape=self.shape[1], + base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, - engine=x.engine, dtype=self.dtype) + engine=x.engine, + dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - y[:] = self._allreduce(y1, op=MPI.SUM) + y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, + y1, op=MPI.SUM, engine=x.engine) return y From c852fc41bb05e2d0d961940321a75fd1bdd626f9 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:57:21 +0000 Subject: [PATCH 13/27] feat: adapted all comm calls in Fredholm1 to new method signatures --- pylops_mpi/signalprocessing/Fredholm1.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pylops_mpi/signalprocessing/Fredholm1.py b/pylops_mpi/signalprocessing/Fredholm1.py index 6ccd9d21..2969e3c9 100644 --- a/pylops_mpi/signalprocessing/Fredholm1.py +++ b/pylops_mpi/signalprocessing/Fredholm1.py @@ -128,7 +128,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: for isl in range(self.nsls[self.rank]): y1[isl] = ncp.dot(self.G[isl], x[isl]) # gather results - y[:] = ncp.vstack(y._allgather(y1)).ravel() + y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, + engine=y.engine)).ravel() return y def _rmatvec(self, x: NDArray) -> NDArray: @@ -165,5 +166,6 @@ def _rmatvec(self, x: NDArray) -> NDArray: y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj() # gather results - y[:] = ncp.vstack(y._allgather(y1)).ravel() + y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, + engine=y.engine)).ravel() return y From 78d753847aa878439eaee220763d73609a47616d Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Tue, 23 Sep 2025 20:58:13 +0000 Subject: [PATCH 14/27] feat: moved methods shared by _mpi and _nccl to _common --- pylops_mpi/utils/_common.py | 92 ++++++++++++++++++++++++++++ pylops_mpi/utils/_mpi.py | 118 ++++++++---------------------------- pylops_mpi/utils/_nccl.py | 2 +- 3 files changed, 117 insertions(+), 95 deletions(-) create mode 100644 pylops_mpi/utils/_common.py diff --git a/pylops_mpi/utils/_common.py b/pylops_mpi/utils/_common.py new file mode 100644 index 00000000..ab149b5c --- /dev/null +++ b/pylops_mpi/utils/_common.py @@ -0,0 +1,92 @@ +__all__ = [ + "_prepare_allgather_inputs", + "_unroll_allgather_recv" +] + +from typing import Optional + +import numpy as np +from mpi4py import MPI +from pylops.utils.backend import get_module +from pylops_mpi.utils import deps + + +# TODO: return type annotation for both cupy and numpy +def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): + r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) + + Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. + Therefore, padding is required when the array is not evenly partitioned across + all the ranks. The padding is applied such that the each dimension of the sending buffers + is equal to the max size of that dimension across all ranks. + + Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size + + Parameters + ---------- + send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like + The data buffer from the local GPU to be sent for allgather. + send_buf_shapes: :obj:`list` + A list of shapes for each GPU send_buf (used to calculate padding size) + engine : :obj:`str` + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + send_buf: :obj:`cupy.ndarray` + A buffer containing the data and padded elements to be sent by this rank. + recv_buf : :obj:`cupy.ndarray` + An empty, padded buffer to gather data from all GPUs. + """ + ncp = get_module(engine) + sizes_each_dim = list(zip(*send_buf_shapes)) + send_shape = tuple(map(max, sizes_each_dim)) + pad_size = [ + (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) + ] + + send_buf = ncp.pad( + send_buf, pad_size, mode="constant", constant_values=0 + ) + + ndev = len(send_buf_shapes) + recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) + + return send_buf, recv_buf + + +def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: + r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL) + + Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays + Each GPU may send array with a different shape, so the return type has to be a list of array + instead of the concatenated array. + + Parameters + ---------- + recv_buf: :obj:`cupy.ndarray` or array-like + The data buffer returned from nccl_allgather call + padded_send_buf_shape: :obj:`tuple`:int + The size of send_buf after padding used in nccl_allgather + send_buf_shapes: :obj:`list` + A list of original shapes for each GPU send_buf prior to padding + + Returns + ------- + chunks: :obj:`list` + A list of `cupy.ndarray` from each GPU with the padded element removed + """ + ndev = len(send_buf_shapes) + # extract an individual array from each device + chunk_size = np.prod(padded_send_buf_shape) + chunks = [ + recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) + ] + + # Remove padding from each array: the padded value may appear somewhere + # in the middle of the flat array and thus the reshape and slicing for each dimension is required + for i in range(ndev): + slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) + chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] + + return chunks diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index e3520c94..89304b8c 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -1,12 +1,10 @@ __all__ = [ "mpi_allgather", "mpi_allreduce", - # "mpi_bcast", + "mpi_bcast", # "mpi_asarray", "mpi_send", "mpi_recv", - "_prepare_allgather_inputs", - "_unroll_allgather_recv" ] from typing import Optional @@ -15,87 +13,26 @@ from mpi4py import MPI from pylops.utils.backend import get_module from pylops_mpi.utils import deps +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv -# TODO: return type annotation for both cupy and numpy -def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): - r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) - - Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. - Therefore, padding is required when the array is not evenly partitioned across - all the ranks. The padding is applied such that the each dimension of the sending buffers - is equal to the max size of that dimension across all ranks. - - Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size - - Parameters - ---------- - send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like - The data buffer from the local GPU to be sent for allgather. - send_buf_shapes: :obj:`list` - A list of shapes for each GPU send_buf (used to calculate padding size) - engine : :obj:`str` - Engine used to store array (``numpy`` or ``cupy``) - - Returns - ------- - send_buf: :obj:`cupy.ndarray` - A buffer containing the data and padded elements to be sent by this rank. - recv_buf : :obj:`cupy.ndarray` - An empty, padded buffer to gather data from all GPUs. - """ - ncp = get_module(engine) - sizes_each_dim = list(zip(*send_buf_shapes)) - send_shape = tuple(map(max, sizes_each_dim)) - pad_size = [ - (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) - ] - - send_buf = ncp.pad( - send_buf, pad_size, mode="constant", constant_values=0 - ) - - ndev = len(send_buf_shapes) - recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) - - return send_buf, recv_buf - - -def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: - r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL) - - Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays - Each GPU may send array with a different shape, so the return type has to be a list of array - instead of the concatenated array. - - Parameters - ---------- - recv_buf: :obj:`cupy.ndarray` or array-like - The data buffer returned from nccl_allgather call - padded_send_buf_shape: :obj:`tuple`:int - The size of send_buf after padding used in nccl_allgather - send_buf_shapes: :obj:`list` - A list of original shapes for each GPU send_buf prior to padding - - Returns - ------- - chunks: :obj:`list` - A list of `cupy.ndarray` from each GPU with the padded element removed - """ - ndev = len(send_buf_shapes) - # extract an individual array from each device - chunk_size = np.prod(padded_send_buf_shape) - chunks = [ - recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) - ] +def mpi_allgather(base_comm: MPI.Comm, + send_buf, recv_buf=None, + engine: Optional[str] = "numpy") -> np.ndarray: - # Remove padding from each array: the padded value may appear somewhere - # in the middle of the flat array and thus the reshape and slicing for each dimension is required - for i in range(ndev): - slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) - chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] + if deps.cuda_aware_mpi_enabled or engine == "numpy": + send_shapes = base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) + recv_buffer_to_use = recv_buf if recv_buf else padded_recv + base_comm.Allgather(padded_send, recv_buffer_to_use) + return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) - return chunks + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allgather(send_buf) + base_comm.Allgather(send_buf, recv_buf) + return recv_buf def mpi_allreduce(base_comm: MPI.Comm, @@ -142,23 +79,16 @@ def mpi_allreduce(base_comm: MPI.Comm, return recv_buf -def mpi_allgather(base_comm: MPI.Comm, - send_buf, recv_buf=None, - engine: Optional[str] = "numpy") -> np.ndarray: - +def mpi_bcast(base_comm: MPI.Comm, + rank, local_array, index, value, + engine: Optional[str] = "numpy") -> np.ndarray: if deps.cuda_aware_mpi_enabled or engine == "numpy": - send_shapes = base_comm.allgather(send_buf.shape) - (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) - recv_buffer_to_use = recv_buf if recv_buf else padded_recv - base_comm.Allgather(padded_send, recv_buffer_to_use) - return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) - + if rank == 0: + local_array[index] = value + base_comm.Bcast(local_array[index]) else: # CuPy with non-CUDA-aware MPI - if recv_buf is None: - return base_comm.allgather(send_buf) - base_comm.Allgather(send_buf, recv_buf) - return recv_buf + local_array[index] = base_comm.bcast(value) def mpi_send(base_comm: MPI.Comm, diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 5f297531..cac5b61c 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -15,7 +15,7 @@ import os import cupy as cp import cupy.cuda.nccl as nccl -from pylops_mpi.utils._mpi import _prepare_allgather_inputs, _unroll_allgather_recv +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv cupy_to_nccl_dtype = { "float32": nccl.NCCL_FLOAT32, From 0138e3aaa9e3374e1701d86d2d6fa2e408823bda Mon Sep 17 00:00:00 2001 From: tharittk Date: Thu, 9 Oct 2025 02:37:09 -0500 Subject: [PATCH 15/27] fix env flag precedence bug --- pylops_mpi/utils/deps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py index c9dc4aa3..f0279ceb 100644 --- a/pylops_mpi/utils/deps.py +++ b/pylops_mpi/utils/deps.py @@ -40,7 +40,7 @@ def nccl_import(message: Optional[str] = None) -> str: cuda_aware_mpi_enabled: bool = ( - True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1) == 1) else False + True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1)) == 1 else False ) nccl_enabled: bool = ( From ec883711a7e0369c496ff5c8dcbc8035da5d9bf9 Mon Sep 17 00:00:00 2001 From: tharittk Date: Thu, 9 Oct 2025 02:46:02 -0500 Subject: [PATCH 16/27] fix flake8 --- pylops_mpi/Distributed.py | 10 +++++----- pylops_mpi/DistributedArray.py | 16 ++++++++-------- pylops_mpi/basicoperators/VStack.py | 8 ++++---- pylops_mpi/utils/_common.py | 3 --- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 7e940b84..7e616a3a 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -22,8 +22,8 @@ class DistributedMixIn: MPI installation is available, the latter with CuPy arrays when a CUDA-Aware MPI installation is not available). """ - def _allreduce(self, base_comm, base_comm_nccl, - send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + def _allreduce(self, base_comm, base_comm_nccl, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation """ @@ -33,7 +33,7 @@ def _allreduce(self, base_comm, base_comm_nccl, return mpi_allreduce(base_comm, send_buf, recv_buf, engine, op) - def _allreduce_subcomm(self, sub_comm, base_comm_nccl, + def _allreduce_subcomm(self, sub_comm, base_comm_nccl, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation with subcommunicator @@ -44,7 +44,7 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl, return mpi_allreduce(sub_comm, send_buf, recv_buf, engine, op) - def _allgather(self, base_comm, base_comm_nccl, + def _allgather(self, base_comm, base_comm_nccl, send_buf, recv_buf=None, engine="numpy"): """Allgather operation @@ -85,7 +85,7 @@ def _bcast(self, local_array, index, value): # self.local_array[index] = self.base_comm.bcast(value) mpi_bcast(self.base_comm, self.rank, self.local_array, index, value, engine=self.engine) - + def _send(self, send_buf, dest, count=None, tag=0): """Send operation """ diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 75b66c7e..da7712d7 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -341,7 +341,7 @@ def local_shapes(self): return self._nccl_local_shapes(False) else: return self._allgather(self.base_comm, - self.base_comm_nccl, + self.base_comm_nccl, self.local_shape) @property @@ -383,7 +383,7 @@ def asarray(self, masked: bool = False): final_array = self._allgather_subcomm(self.local_array) else: final_array = self._allgather(self.base_comm, - self.base_comm_nccl, + self.base_comm_nccl, self.local_array, engine=self.engine) return np.concatenate(final_array, axis=self.axis) @@ -484,7 +484,7 @@ def _nccl_local_shapes(self, masked: bool): all_tuples = self._allgather_subcomm(self.local_shape).get() else: all_tuples = self._allgather(self.base_comm, - self.base_comm_nccl, + self.base_comm_nccl, self.local_shape).get() # NCCL returns the flat array that packs every tuple as 1-dimensional array # unpack each tuple from each rank @@ -625,12 +625,12 @@ def _compute_vector_norm(self, local_array: NDArray, # CuPy + non-CUDA-aware MPI: This will call non-buffered communication # which return a list of object - must be copied back to a GPU memory. recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf.get(), recv_buf.get(), + send_buf.get(), recv_buf.get(), op=MPI.MAX, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf, recv_buf, op=MPI.MAX, + send_buf, recv_buf, op=MPI.MAX, engine=self.engine) # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. @@ -643,18 +643,18 @@ def _compute_vector_norm(self, local_array: NDArray, send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf.get(), recv_buf.get(), + send_buf.get(), recv_buf.get(), op=MPI.MIN, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - send_buf, recv_buf, + send_buf, recv_buf, op=MPI.MIN, engine=self.engine) if self.base_comm_nccl: recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, - ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis), + ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis), engine=self.engine) recv_buf = ncp.power(recv_buf, 1.0 / ord) return recv_buf diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 174e9739..de66c342 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -135,8 +135,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) y = DistributedArray(global_shape=self.shape[1], - base_comm=x.base_comm, - base_comm_nccl=x.base_comm_nccl, + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) @@ -144,8 +144,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, - y1, op=MPI.SUM, engine=x.engine) + y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, + y1, op=MPI.SUM, engine=x.engine) return y diff --git a/pylops_mpi/utils/_common.py b/pylops_mpi/utils/_common.py index ab149b5c..895265df 100644 --- a/pylops_mpi/utils/_common.py +++ b/pylops_mpi/utils/_common.py @@ -3,12 +3,9 @@ "_unroll_allgather_recv" ] -from typing import Optional import numpy as np -from mpi4py import MPI from pylops.utils.backend import get_module -from pylops_mpi.utils import deps # TODO: return type annotation for both cupy and numpy From 02ba45bd872609490431c689dbf5ce16fe375abb Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Wed, 15 Oct 2025 21:50:19 +0000 Subject: [PATCH 17/27] doc: added details about cuda-aware mpi in doc --- docs/source/gpu.rst | 10 +++++++- docs/source/installation.rst | 44 +++++++++++++++++++++++++++++++----- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 9a1af651..52839069 100644 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -11,7 +11,7 @@ This library must be installed *before* PyLops-mpi is installed. .. note:: - Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. + Set the environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system, otherwise you will get an error when importing PyLops. @@ -22,6 +22,14 @@ can handle both scenarios. Note that, since most operators in PyLops-mpi are thi some of the operators in PyLops that lack a GPU implementation cannot be used also in PyLops-mpi when working with cupy arrays. +.. note:: + + By default when using ``cupy`` arrays, PyLops-MPI will try to use methods in MPI4Py that communicate memory buffers. + However, this requires a CUDA-Aware MPI installation. If your MPI installation is not CUDA-Aware, set the + environment variable ``PYLOPS_MPI_CUDA_AWARE=0`` to force PyLops-MPI to use methods in MPI4Py that communicate + general Python objects (this will incur a loss of performance!). + + Moreover, PyLops-MPI also supports the Nvidia's Collective Communication Library (NCCL) for highly-optimized collective operations, such as AllReduce, AllGather, etc. This allows PyLops-MPI users to leverage the proprietary technology like NVLink that might be available in their infrastructure for fast data communication. diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d0aafe88..e1d7faf3 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -15,7 +15,13 @@ The minimal set of dependencies for the PyLops-MPI project is: * `MPI4py `_ * `PyLops `_ -Additionally, to use the NCCL engine, the following additional +Additionally, to use the CUDA-aware MPI engine, the following additional +dependencies are required: + +* `CuPy `_ +* CUDA-aware MPI + +Similarly, to use the NCCL engine, the following additional dependencies are required: * `CuPy `_ @@ -27,12 +33,18 @@ if this is not possible, some of the dependencies must be installed prior to ins Download and Install MPI ======================== -Visit the official MPI website to download an appropriate MPI implementation for your system. -Follow the installation instructions provided by the MPI vendor. +Visit the official website of your MPI vendor of choice to download an appropriate MPI +implementation for your system: + +* `Open MPI `_ +* `MPICH `_ +* `Intel MPI `_ +* ... -* `Open MPI `_ -* `MPICH `_ -* `Intel MPI `_ +Alternatively, the conda-forge community provides ready-to-use binary packages for four MPI implementations +(see `MPI4Py documentation `_ for more +details). In this case, you can defer the installation to the stage when the conda environment for your project +is created - see below for more details. Verify MPI Installation ======================= @@ -42,6 +54,17 @@ After installing MPI, verify its installation by opening a terminal and running >> mpiexec --version +Install CUDA-Aware MPI (optional) +================================= +To be able to achieve the best performance when using PyLops-MPI with CuPy arrays, a CUDA-Aware version of +MPI must be installed. + +For `Open MPI`, the conda-forge package has built-in CUDA support, as long as a pre-installed CUDA is detected. +Run the following `commands `_ +for diagnostics. + +For the other MPI implementations, refer to their specific documentation. + Install NCCL (optional) ======================= To obtain highly-optimized performance on GPU clusters, PyLops-MPI also supports the Nvidia's collective communication calls @@ -103,6 +126,15 @@ For a ``conda`` environment, run This will create and activate an environment called ``pylops_mpi``, with all required and optional dependencies. +If you want to also install MPI as part of the creation process of the conda environment, +modify the ``environment-dev.yml`` file by adding ``openmpi``\``mpich`\``impi_rt``\``msmpi`` +just above ``mpi4py``. Note that only ``openmpi`` provides a CUDA-Aware MPI installation. + +If you want to leverage CUDA-Aware MPI but prefer to use another MPI installation, you must +either switch to a `Pip`-based installation (see below), or move ``mpi4py`` into the ``pip`` +section of the ``environment-dev.yml`` file and export the variable ``MPICC`` pointing to +the path of your CUDA-Aware MPI installation. + If you want to enable `NCCL `_ in PyLops-MPI, run this instead .. code-block:: bash From 473cd97618e64f1fa61fe958efefc5b202d205cb Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Sun, 19 Oct 2025 21:14:44 +0000 Subject: [PATCH 18/27] doc: finalized gpu doc --- docs/source/gpu.rst | 70 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 11 deletions(-) diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 52839069..bb75ae0c 100644 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -38,13 +38,35 @@ proprietary technology like NVLink that might be available in their infrastructu Set environment variable ``NCCL_PYLOPS_MPI=0`` to explicitly force PyLops-MPI to ignore the ``NCCL`` backend. However, this is optional as users may opt-out for NCCL by skip passing `cupy.cuda.nccl.NcclCommunicator` to - the :class:`pylops_mpi.DistributedArray` + the :class:`pylops_mpi.DistributedArray`. + +In summary: + +.. list-table:: + :widths: 50 25 25 + :header-rows: 1 + + * - Operation model + - Enabled with + - Disabled with + * - NumPy + MPI + - Default + - Cannot be disabled + * - CuPy + MPI + - ``PYLOPS_MPI_CUDA_AWARE=0`` + - ``PYLOPS_MPI_CUDA_AWARE=1`` (default) + * - CuPy + CUDA-Aware MPI + - ``PYLOPS_MPI_CUDA_AWARE=1`` (default) + - ``PYLOPS_MPI_CUDA_AWARE=0`` + * - CuPy + NCCL + - ``NCCL_PYLOPS_MPI=1`` (default) + - ``NCCL_PYLOPS_MPI=0`` Example ------- Finally, let's briefly look at an example. First we write a code snippet using -``numpy`` arrays which PyLops-mpi will run on your CPU: +``numpy`` arrays which PyLops-MPI will run on your CPU: .. code-block:: python @@ -128,41 +150,67 @@ one MPI process. In fact, minor communications like those dealing with array-rel The CuPy and NCCL backend is in active development, with many examples not yet in the docs. You can find many `other examples `_ from the `PyLops Notebooks repository `_. + Supports for NCCL Backend ---------------------------- -In the following, we provide a list of modules (i.e., operators and solvers) where we plan to support NCCL and the current status: +In the following, we provide a list of modules (i.e., operators and solvers) +and their current status in terms of support for the 3 different communication +backends: .. list-table:: - :widths: 50 25 + :widths: 50 25 25 25 :header-rows: 1 - * - modules - - NCCL supported + * - Operator/method + - CPU + - GPU with MPI + - GPU with NCCL * - :class:`pylops_mpi.DistributedArray` - ✅ - * - :class:`pylops_mpi.basicoperators.MPIVStack` - ✅ + - ✅ + * - :class:`pylops_mpi.basicoperators.MPIMatrixMult` + - ✅ + - 🔴 + - 🔴 * - :class:`pylops_mpi.basicoperators.MPIVStack` - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.basicoperators.MPIHStack` - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.basicoperators.MPIBlockDiag` - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.basicoperators.MPIGradient` - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.basicoperators.MPIFirstDerivative` - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.basicoperators.MPISecondDerivative` - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.basicoperators.MPILaplacian` - ✅ + - ✅ + - ✅ + * - :class:`pylops_mpi.signalprocessing.Fredhoml1` + - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.optimization.basic.cg` - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.optimization.basic.cgls` - ✅ - * - :class:`pylops_mpi.signalprocessing.Fredhoml1` - ✅ - * - Complex Numeric Data Type for NCCL - ✅ - * - ISTA Solver - - Planned ⏳ \ No newline at end of file + \ No newline at end of file From 2cdb8f7988312ac3b15f5a1a61bba28363f13173 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Sun, 19 Oct 2025 21:16:40 +0000 Subject: [PATCH 19/27] doc: added some docstrings to Distributed --- pylops_mpi/Distributed.py | 75 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 7e616a3a..c8efd82f 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -16,16 +16,41 @@ class DistributedMixIn: r"""Distributed Mixin class This class implements all methods associated with communication primitives - from MPI and NCCL. It is mostly charged to identifying which commuicator + from MPI and NCCL. It is mostly charged with identifying which commuicator to use and whether the buffered or object MPI primitives should be used (the former in the case of NumPy arrays or CuPy arrays when a CUDA-Aware MPI installation is available, the latter with CuPy arrays when a CUDA-Aware MPI installation is not available). + """ def _allreduce(self, base_comm, base_comm_nccl, - send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + send_buf, recv_buf=None, + op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + op : :obj: `MPI.Op`, optional + MPI operation to perform. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all GPUs. + """ if deps.nccl_enabled and base_comm_nccl is not None: return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op) @@ -34,9 +59,33 @@ def _allreduce(self, base_comm, base_comm_nccl, recv_buf, engine, op) def _allreduce_subcomm(self, sub_comm, base_comm_nccl, - send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, + send_buf, recv_buf=None, + op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation with subcommunicator + + Parameters + ---------- + sub_comm : :obj:`MPI.Comm` + MPI Subcommunicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + op : :obj: `MPI.Op`, optional + MPI operation to perform. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all ranks. + """ if deps.nccl_enabled and base_comm_nccl is not None: return nccl_allreduce(sub_comm, send_buf, recv_buf, op) @@ -48,6 +97,26 @@ def _allgather(self, base_comm, base_comm_nccl, send_buf, recv_buf=None, engine="numpy"): """Allgather operation + + Parameters + ---------- + sub_comm : :obj:`MPI.Comm` + MPI Subcommunicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the gathering. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the gathered data from all ranks. + """ if deps.nccl_enabled and base_comm_nccl is not None: if isinstance(send_buf, (tuple, list, int)): From 563db16dd497b49d66ad1a4a64e0adebb585a0e7 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Mon, 20 Oct 2025 21:30:21 +0000 Subject: [PATCH 20/27] minor: fix flake8 --- pylops_mpi/Distributed.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index c8efd82f..5ea3bae3 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -24,7 +24,7 @@ class DistributedMixIn: """ def _allreduce(self, base_comm, base_comm_nccl, - send_buf, recv_buf=None, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation @@ -44,13 +44,13 @@ def _allreduce(self, base_comm, base_comm_nccl, MPI operation to perform. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - + Returns ------- recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` A buffer containing the result of the reduction, broadcasted to all GPUs. - + """ if deps.nccl_enabled and base_comm_nccl is not None: return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op) @@ -59,7 +59,7 @@ def _allreduce(self, base_comm, base_comm_nccl, recv_buf, engine, op) def _allreduce_subcomm(self, sub_comm, base_comm_nccl, - send_buf, recv_buf=None, + send_buf, recv_buf=None, op: MPI.Op = MPI.SUM, engine="numpy"): """Allreduce operation with subcommunicator @@ -79,13 +79,13 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl, MPI operation to perform. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - + Returns ------- recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` A buffer containing the result of the reduction, broadcasted to all ranks. - + """ if deps.nccl_enabled and base_comm_nccl is not None: return nccl_allreduce(sub_comm, send_buf, recv_buf, op) @@ -111,12 +111,12 @@ def _allgather(self, base_comm, base_comm_nccl, a new buffer will be allocated with the appropriate shape. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - + Returns ------- recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` A buffer containing the gathered data from all ranks. - + """ if deps.nccl_enabled and base_comm_nccl is not None: if isinstance(send_buf, (tuple, list, int)): From 50c5bd262e86fc5d0e123e50e6ef8bf8321731a8 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Mon, 20 Oct 2025 21:30:52 +0000 Subject: [PATCH 21/27] bug: fix import of methods in test_ncclutils_nccl --- tests_nccl/test_ncclutils_nccl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_nccl/test_ncclutils_nccl.py b/tests_nccl/test_ncclutils_nccl.py index 52502afc..eaae0a69 100644 --- a/tests_nccl/test_ncclutils_nccl.py +++ b/tests_nccl/test_ncclutils_nccl.py @@ -9,7 +9,7 @@ import pytest from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather -from pylops_mpi.utils._mpi import _prepare_allgather_inputs, _unroll_allgather_recv +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv np.random.seed(42) From a80f00eca049e42cec572204bf5b0174f08a65d3 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Mon, 20 Oct 2025 21:32:37 +0000 Subject: [PATCH 22/27] bug: added engine to x array i test_matrixmult --- tests/test_matrixmult.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py index a596ffe9..20a11783 100644 --- a/tests/test_matrixmult.py +++ b/tests/test_matrixmult.py @@ -219,7 +219,8 @@ def test_MPIMatrixMult_summa(N, K, M, dtype_str): local_shapes=comm.allgather(X_p.shape[0] * X_p.shape[1]), partition=Partition.SCATTER, base_comm=comm, - dtype=dtype + dtype=dtype, + engine=backend, ) x_dist.local_array[:] = X_p.ravel() From 2c67755c74fa90eb20feed9ddd136ab47efcf61b Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Sat, 25 Oct 2025 21:06:35 +0000 Subject: [PATCH 23/27] feat: finalized passing parameters to all methods in Distributed --- pylops_mpi/Distributed.py | 154 +++++++++++++++++++++++++++------ pylops_mpi/DistributedArray.py | 21 +++-- 2 files changed, 142 insertions(+), 33 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 5ea3bae3..828e11ef 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -1,4 +1,7 @@ +from typing import Any, NewType, Optional, Union + from mpi4py import MPI +from pylops.utils import NDArray from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv from pylops_mpi.utils import deps @@ -10,6 +13,11 @@ from pylops_mpi.utils._nccl import ( nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv ) + from cupy.cuda.nccl import NcclCommunicator +else: + NcclCommunicator = Any + +NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) class DistributedMixIn: @@ -23,10 +31,14 @@ class DistributedMixIn: MPI installation is not available). """ - def _allreduce(self, base_comm, base_comm_nccl, - send_buf, recv_buf=None, + def _allreduce(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, op: MPI.Op = MPI.SUM, - engine="numpy"): + engine: str = "numpy", + ) -> NDArray: """Allreduce operation Parameters @@ -58,10 +70,14 @@ def _allreduce(self, base_comm, base_comm_nccl, return mpi_allreduce(base_comm, send_buf, recv_buf, engine, op) - def _allreduce_subcomm(self, sub_comm, base_comm_nccl, - send_buf, recv_buf=None, + def _allreduce_subcomm(self, + sub_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, op: MPI.Op = MPI.SUM, - engine="numpy"): + engine: str = "numpy", + ) -> NDArray: """Allreduce operation with subcommunicator Parameters @@ -93,15 +109,19 @@ def _allreduce_subcomm(self, sub_comm, base_comm_nccl, return mpi_allreduce(sub_comm, send_buf, recv_buf, engine, op) - def _allgather(self, base_comm, base_comm_nccl, - send_buf, recv_buf=None, - engine="numpy"): + def _allgather(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + ) -> NDArray: """Allgather operation Parameters ---------- - sub_comm : :obj:`MPI.Comm` - MPI Subcommunicator. + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` NCCL Communicator. send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` @@ -131,41 +151,119 @@ def _allgather(self, base_comm, base_comm_nccl, return base_comm.allgather(send_buf) return mpi_allgather(base_comm, send_buf, recv_buf, engine) - def _allgather_subcomm(self, send_buf, recv_buf=None): + def _allgather_subcomm(self, + sub_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + ) -> NDArray: """Allgather operation with subcommunicator + + Parameters + ---------- + sub_comm : :obj:`MPI.Comm` + MPI Subcommunicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the gathering. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the gathered data from all ranks. + """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): + if deps.nccl_enabled and base_comm_nccl is not None: if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.sub_comm, send_buf, recv_buf) + return nccl_allgather(sub_comm, send_buf, recv_buf) else: - send_shapes = self._allgather_subcomm(send_buf.shape) + send_shapes = sub_comm._allgather_subcomm(send_buf.shape) (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") - raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv) + raw_recv = nccl_allgather(sub_comm, padded_send, recv_buf if recv_buf else padded_recv) return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) else: - return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine) + return mpi_allgather(sub_comm, send_buf, recv_buf, engine) - def _bcast(self, local_array, index, value): + def _bcast(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + rank : int, + local_array: NDArray, + index: int, + value: Union[int, NDArray], + engine: str = "numpy", + ) -> None: """BCast operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + rank : :obj:`int` + Rank. + local_array : :obj:`numpy.ndarray` + Localy array to be broadcasted. + index : :obj:`int` or :obj:`slice` + Represents the index positions where a value needs to be assigned. + value : :obj:`int` or :obj:`numpy.ndarray` + Represents the value that will be assigned to the local array at + the specified index positions. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - nccl_bcast(self.base_comm_nccl, local_array, index, value) + if deps.nccl_enabled and base_comm_nccl is not None: + nccl_bcast(base_comm_nccl, local_array, index, value) else: - # self.local_array[index] = self.base_comm.bcast(value) - mpi_bcast(self.base_comm, self.rank, self.local_array, index, value, - engine=self.engine) + mpi_bcast(base_comm, rank, local_array, index, value, + engine=engine) - def _send(self, send_buf, dest, count=None, tag=0): + def _send(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + dest: int, + count: Optional[int] = None, + tag: int = 0, + engine: str = "numpy", + ) -> None: """Send operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The array containing data to send. + dest: :obj:`int` + The rank of the destination. + count : :obj:`int` + Number of elements to send from `send_buf`. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + """ - if deps.nccl_enabled and self.base_comm_nccl: + if deps.nccl_enabled and base_comm_nccl is not None: if count is None: count = send_buf.size - nccl_send(self.base_comm_nccl, send_buf, dest, count) + nccl_send(base_comm_nccl, send_buf, dest, count) else: - mpi_send(self.base_comm, + mpi_send(base_comm, send_buf, dest, count, tag=tag, - engine=self.engine) + engine=engine) def _recv(self, recv_buf=None, source=0, count=None, tag=0): """Receive operation diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index da7712d7..492a49b2 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -204,7 +204,9 @@ def __setitem__(self, index, value): the specified index positions. """ if self.partition is Partition.BROADCAST: - self._bcast(self.local_array, index, value) + self._bcast(self.base_comm, self.base_comm_nccl, + self.rank, self.local_array, + index, value, engine=self.engine) else: self.local_array[index] = value @@ -380,7 +382,10 @@ def asarray(self, masked: bool = False): else: # Gather all the local arrays and apply concatenation. if masked: - final_array = self._allgather_subcomm(self.local_array) + final_array = self._allgather_subcomm(self.sub_comm, + self.base_comm_nccl, + self.local_array, + engine=self.engine) else: final_array = self._allgather(self.base_comm, self.base_comm_nccl, @@ -481,7 +486,9 @@ def _nccl_local_shapes(self, masked: bool): """ # gather tuple of shapes from every rank within thee communicator and copy from GPU to CPU if masked: - all_tuples = self._allgather_subcomm(self.local_shape).get() + all_tuples = self._allgather_subcomm(self.sub_comm, + self.base_comm_nccl, + self.local_shape).get() else: all_tuples = self._allgather(self.base_comm, self.base_comm_nccl, @@ -799,7 +806,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, f"{self.local_shape[self.axis]} < {cells_front}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_front)}") - self._send(send_buf, dest=self.rank + 1, tag=1) + self._send(self.base_comm, self.base_comm_nccl, + send_buf, dest=self.rank + 1, tag=1, + engine=self.engine) if cells_back is not None: total_cells_back = self.base_comm.allgather(cells_back) + [0] # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1) @@ -814,7 +823,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, f"{self.local_shape[self.axis]} < {cells_back}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_back)}") - self._send(send_buf, dest=self.rank - 1, tag=0) + self._send(self.base_comm, self.base_comm_nccl, + send_buf, dest=self.rank - 1, tag=0, + engine=self.engine) if self.rank != self.size - 1: recv_shape = list(recv_shapes[self.rank + 1]) recv_shape[self.axis] = total_cells_back[self.rank] From e33e8f1b35c0b9e246425cbd4888e3b69dd56c85 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Sat, 25 Oct 2025 21:07:16 +0000 Subject: [PATCH 24/27] doc: added documentation and type hints to _mpi --- pylops_mpi/utils/_mpi.py | 97 +++++++++++++++++++++++++++++++++------- 1 file changed, 81 insertions(+), 16 deletions(-) diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index 89304b8c..47a2e44b 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -2,31 +2,54 @@ "mpi_allgather", "mpi_allreduce", "mpi_bcast", - # "mpi_asarray", "mpi_send", "mpi_recv", ] -from typing import Optional +from typing import Optional, Union import numpy as np from mpi4py import MPI +from pylops.utils import NDArray from pylops.utils.backend import get_module from pylops_mpi.utils import deps from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv def mpi_allgather(base_comm: MPI.Comm, - send_buf, recv_buf=None, - engine: Optional[str] = "numpy") -> np.ndarray: + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + ) -> NDArray: + """MPI_Allallgather/allallgather + Dispatch allgather routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The data buffer from the local rank to be gathered. + recv_buf : :obj:`cupy.ndarray`, optional + The buffer to store the result of the gathering. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the gathered data from all ranks. + + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": send_shapes = base_comm.allgather(send_buf.shape) (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) recv_buffer_to_use = recv_buf if recv_buf else padded_recv base_comm.Allgather(padded_send, recv_buffer_to_use) return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) - else: # CuPy with non-CUDA-aware MPI if recv_buf is None: @@ -36,9 +59,11 @@ def mpi_allgather(base_comm: MPI.Comm, def mpi_allreduce(base_comm: MPI.Comm, - send_buf, recv_buf=None, - engine: Optional[str] = "numpy", - op: MPI.Op = MPI.SUM) -> np.ndarray: + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + op: MPI.Op = MPI.SUM, + ) -> NDArray: """MPI_Allreduce/allreduce Dispatch allreduce routine based on type of input and availability of @@ -49,7 +74,7 @@ def mpi_allreduce(base_comm: MPI.Comm, base_comm : :obj:`MPI.Comm` Base MPI Communicator. send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` - The data buffer from the local GPU to be reduced. + The data buffer from the local rank to be reduced. recv_buf : :obj:`cupy.ndarray`, optional The buffer to store the result of the reduction. If None, a new buffer will be allocated with the appropriate shape. @@ -62,7 +87,7 @@ def mpi_allreduce(base_comm: MPI.Comm, ------- recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` A buffer containing the result of the reduction, broadcasted - to all GPUs. + to all ranks. """ if deps.cuda_aware_mpi_enabled or engine == "numpy": @@ -80,8 +105,34 @@ def mpi_allreduce(base_comm: MPI.Comm, def mpi_bcast(base_comm: MPI.Comm, - rank, local_array, index, value, - engine: Optional[str] = "numpy") -> np.ndarray: + rank: int, + local_array: NDArray, + index: int, + value: Union[int, NDArray], + engine: Optional[str] = "numpy", + ) -> None: + """MPI_Bcast/bcast + + Dispatch bcast routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + rank : :obj:`int` + Rank. + local_array : :obj:`numpy.ndarray` + Localy array to be broadcasted. + index : :obj:`int` or :obj:`slice` + Represents the index positions where a value needs to be assigned. + value : :obj:`int` or :obj:`numpy.ndarray` + Represents the value that will be assigned to the local array at + the specified index positions. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": if rank == 0: local_array[index] = value @@ -92,8 +143,11 @@ def mpi_bcast(base_comm: MPI.Comm, def mpi_send(base_comm: MPI.Comm, - send_buf, dest, count, tag=0, - engine: Optional[str] = "numpy", + send_buf: NDArray, + dest: int, + count: Optional[int] = None, + tag: int = 0, + engine: str = "numpy", ) -> None: """MPI_Send/send @@ -114,6 +168,7 @@ def mpi_send(base_comm: MPI.Comm, Tag of the message to be sent. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": # Determine MPI type based on array dtype @@ -127,8 +182,12 @@ def mpi_send(base_comm: MPI.Comm, def mpi_recv(base_comm: MPI.Comm, - recv_buf=None, source=0, count=None, tag=0, - engine: Optional[str] = "numpy") -> np.ndarray: + recv_buf: Optional[NDArray] = None, + source: int = 0, + count: Optional[int] = None, + tag: int = 0, + engine: Optional[str] = "numpy", + ) -> NDArray: """ MPI_Recv/recv Dispatch receive routine based on type of input and availability of CUDA-Aware MPI @@ -147,6 +206,12 @@ def mpi_recv(base_comm: MPI.Comm, Tag of the message to be sent. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The buffer containing the received data. + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": ncp = get_module(engine) From e0fd7163e31cdb2a18c3f5ae4549d7618a37a5c5 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Sat, 25 Oct 2025 21:08:06 +0000 Subject: [PATCH 25/27] minor: added TEST_CUPY_PYLOPS=0 in tests target --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index ee4b8cb2..1d824ee0 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ lint: flake8 pylops_mpi/ tests/ examples/ tutorials/ tests: - mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi + export TEST_CUPY_PYLOPS=0 && mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi # assuming NUM_PROCESSES <= number of gpus available tests_gpu: From 83f7a8bd303d14a060c018e0dd992766297fbfb8 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Sat, 25 Oct 2025 21:17:34 +0000 Subject: [PATCH 26/27] minor: fix flake8 --- pylops_mpi/Distributed.py | 38 +++++++++++++++++++++++++++++----- pylops_mpi/DistributedArray.py | 8 +++++-- pylops_mpi/utils/_mpi.py | 5 ++--- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py index 828e11ef..3f1cf068 100644 --- a/pylops_mpi/Distributed.py +++ b/pylops_mpi/Distributed.py @@ -265,17 +265,45 @@ def _send(self, send_buf, dest, count, tag=tag, engine=engine) - def _recv(self, recv_buf=None, source=0, count=None, tag=0): + def _recv(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + recv_buf=None, source=0, count=None, tag=0, + engine: str = "numpy", + ) -> NDArray: """Receive operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional + The buffered array to receive data. + source : :obj:`int` + The rank of the sending CPU/GPU device. + count : :obj:`int` + Number of elements to receive. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The buffer containing the received data. + """ - if deps.nccl_enabled and self.base_comm_nccl: + if deps.nccl_enabled and base_comm_nccl is not None: if recv_buf is None: raise ValueError("recv_buf must be supplied when using NCCL") if count is None: count = recv_buf.size - nccl_recv(self.base_comm_nccl, recv_buf, source, count) + nccl_recv(base_comm_nccl, recv_buf, source, count) return recv_buf else: - return mpi_recv(self.base_comm, + return mpi_recv(base_comm, recv_buf, source, count, tag=tag, - engine=self.engine) + engine=engine) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 492a49b2..c383e4d0 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -797,7 +797,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, # Transfer of ghost cells can be skipped if len(recv_buf) = 0 # Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory if len(recv_buf) != 0: - ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis) + ghosted_array = ncp.concatenate([self._recv(self.base_comm, self.base_comm_nccl, + recv_buf, source=self.rank - 1, tag=1, + engine=self.engine), ghosted_array], axis=self.axis) # The skip in sender is to match with what described in receiver if self.rank != self.size - 1 and len(send_buf) != 0: if cells_front > self.local_shape[self.axis]: @@ -831,7 +833,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, recv_shape[self.axis] = total_cells_back[self.rank] recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype) if len(recv_buf) != 0: - ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0), + ghosted_array = ncp.append(ghosted_array, self._recv(self.base_comm, self.base_comm_nccl, + recv_buf, source=self.rank + 1, tag=0, + engine=self.engine), axis=self.axis) return ghosted_array diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index 47a2e44b..fdc84c5e 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -8,7 +8,6 @@ from typing import Optional, Union -import numpy as np from mpi4py import MPI from pylops.utils import NDArray from pylops.utils.backend import get_module @@ -37,12 +36,12 @@ def mpi_allgather(base_comm: MPI.Comm, a new buffer will be allocated with the appropriate shape. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - + Returns ------- recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` A buffer containing the gathered data from all ranks. - + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": send_shapes = base_comm.allgather(send_buf.shape) From 02efdbb023f8dad140057a7a45283ec8119ff7a9 Mon Sep 17 00:00:00 2001 From: mrava87SW Date: Sat, 25 Oct 2025 21:23:28 +0000 Subject: [PATCH 27/27] minor: fix more flake8 --- pylops_mpi/utils/_mpi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py index fdc84c5e..d0c8c73f 100644 --- a/pylops_mpi/utils/_mpi.py +++ b/pylops_mpi/utils/_mpi.py @@ -167,7 +167,7 @@ def mpi_send(base_comm: MPI.Comm, Tag of the message to be sent. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - + """ if deps.cuda_aware_mpi_enabled or engine == "numpy": # Determine MPI type based on array dtype @@ -183,7 +183,7 @@ def mpi_send(base_comm: MPI.Comm, def mpi_recv(base_comm: MPI.Comm, recv_buf: Optional[NDArray] = None, source: int = 0, - count: Optional[int] = None, + count: Optional[int] = None, tag: int = 0, engine: Optional[str] = "numpy", ) -> NDArray: @@ -205,7 +205,7 @@ def mpi_recv(base_comm: MPI.Comm, Tag of the message to be sent. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) - + Returns ------- recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`