diff --git a/Makefile b/Makefile index ee4b8cb2..1d824ee0 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ lint: flake8 pylops_mpi/ tests/ examples/ tutorials/ tests: - mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi + export TEST_CUPY_PYLOPS=0 && mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi # assuming NUM_PROCESSES <= number of gpus available tests_gpu: diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst index 4ed0218d..c86f847e 100644 --- a/docs/source/gpu.rst +++ b/docs/source/gpu.rst @@ -11,7 +11,7 @@ This library must be installed *before* PyLops-mpi is installed. .. note:: - Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. + Set the environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system, otherwise you will get an error when importing PyLops. @@ -22,6 +22,14 @@ can handle both scenarios. Note that, since most operators in PyLops-mpi are thi some of the operators in PyLops that lack a GPU implementation cannot be used also in PyLops-mpi when working with cupy arrays. +.. note:: + + By default when using ``cupy`` arrays, PyLops-MPI will try to use methods in MPI4Py that communicate memory buffers. + However, this requires a CUDA-Aware MPI installation. If your MPI installation is not CUDA-Aware, set the + environment variable ``PYLOPS_MPI_CUDA_AWARE=0`` to force PyLops-MPI to use methods in MPI4Py that communicate + general Python objects (this will incur a loss of performance!). + + Moreover, PyLops-MPI also supports the Nvidia's Collective Communication Library (NCCL) for highly-optimized collective operations, such as AllReduce, AllGather, etc. This allows PyLops-MPI users to leverage the proprietary technology like NVLink that might be available in their infrastructure for fast data communication. @@ -30,13 +38,35 @@ proprietary technology like NVLink that might be available in their infrastructu Set environment variable ``NCCL_PYLOPS_MPI=0`` to explicitly force PyLops-MPI to ignore the ``NCCL`` backend. However, this is optional as users may opt-out for NCCL by skip passing `cupy.cuda.nccl.NcclCommunicator` to - the :class:`pylops_mpi.DistributedArray` + the :class:`pylops_mpi.DistributedArray`. + +In summary: + +.. list-table:: + :widths: 50 25 25 + :header-rows: 1 + + * - Operation model + - Enabled with + - Disabled with + * - NumPy + MPI + - Default + - Cannot be disabled + * - CuPy + MPI + - ``PYLOPS_MPI_CUDA_AWARE=0`` + - ``PYLOPS_MPI_CUDA_AWARE=1`` (default) + * - CuPy + CUDA-Aware MPI + - ``PYLOPS_MPI_CUDA_AWARE=1`` (default) + - ``PYLOPS_MPI_CUDA_AWARE=0`` + * - CuPy + NCCL + - ``NCCL_PYLOPS_MPI=1`` (default) + - ``NCCL_PYLOPS_MPI=0`` Example ------- Finally, let's briefly look at an example. First we write a code snippet using -``numpy`` arrays which PyLops-mpi will run on your CPU: +``numpy`` arrays which PyLops-MPI will run on your CPU: .. code-block:: python @@ -157,6 +187,8 @@ GPU+MPI, and GPU+NCCL): - ✅ - ✅ - ✅ + - ✅ + - ✅ * - :class:`pylops_mpi.basicoperators.MPISecondDerivative` - ✅ - ✅ @@ -184,4 +216,4 @@ GPU+MPI, and GPU+NCCL): * - :class:`pylops_mpi.optimization.basic.cgls` - ✅ - ✅ - - ✅ \ No newline at end of file + - ✅ diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d0aafe88..e1d7faf3 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -15,7 +15,13 @@ The minimal set of dependencies for the PyLops-MPI project is: * `MPI4py `_ * `PyLops `_ -Additionally, to use the NCCL engine, the following additional +Additionally, to use the CUDA-aware MPI engine, the following additional +dependencies are required: + +* `CuPy `_ +* CUDA-aware MPI + +Similarly, to use the NCCL engine, the following additional dependencies are required: * `CuPy `_ @@ -27,12 +33,18 @@ if this is not possible, some of the dependencies must be installed prior to ins Download and Install MPI ======================== -Visit the official MPI website to download an appropriate MPI implementation for your system. -Follow the installation instructions provided by the MPI vendor. +Visit the official website of your MPI vendor of choice to download an appropriate MPI +implementation for your system: + +* `Open MPI `_ +* `MPICH `_ +* `Intel MPI `_ +* ... -* `Open MPI `_ -* `MPICH `_ -* `Intel MPI `_ +Alternatively, the conda-forge community provides ready-to-use binary packages for four MPI implementations +(see `MPI4Py documentation `_ for more +details). In this case, you can defer the installation to the stage when the conda environment for your project +is created - see below for more details. Verify MPI Installation ======================= @@ -42,6 +54,17 @@ After installing MPI, verify its installation by opening a terminal and running >> mpiexec --version +Install CUDA-Aware MPI (optional) +================================= +To be able to achieve the best performance when using PyLops-MPI with CuPy arrays, a CUDA-Aware version of +MPI must be installed. + +For `Open MPI`, the conda-forge package has built-in CUDA support, as long as a pre-installed CUDA is detected. +Run the following `commands `_ +for diagnostics. + +For the other MPI implementations, refer to their specific documentation. + Install NCCL (optional) ======================= To obtain highly-optimized performance on GPU clusters, PyLops-MPI also supports the Nvidia's collective communication calls @@ -103,6 +126,15 @@ For a ``conda`` environment, run This will create and activate an environment called ``pylops_mpi``, with all required and optional dependencies. +If you want to also install MPI as part of the creation process of the conda environment, +modify the ``environment-dev.yml`` file by adding ``openmpi``\``mpich`\``impi_rt``\``msmpi`` +just above ``mpi4py``. Note that only ``openmpi`` provides a CUDA-Aware MPI installation. + +If you want to leverage CUDA-Aware MPI but prefer to use another MPI installation, you must +either switch to a `Pip`-based installation (see below), or move ``mpi4py`` into the ``pip`` +section of the ``environment-dev.yml`` file and export the variable ``MPICC`` pointing to +the path of your CUDA-Aware MPI installation. + If you want to enable `NCCL `_ in PyLops-MPI, run this instead .. code-block:: bash diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py new file mode 100644 index 00000000..3f1cf068 --- /dev/null +++ b/pylops_mpi/Distributed.py @@ -0,0 +1,309 @@ +from typing import Any, NewType, Optional, Union + +from mpi4py import MPI +from pylops.utils import NDArray +from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils +from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv +from pylops_mpi.utils import deps + +cupy_message = pylops_deps.cupy_import("the DistributedArray module") +nccl_message = deps.nccl_import("the DistributedArray module") + +if nccl_message is None and cupy_message is None: + from pylops_mpi.utils._nccl import ( + nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv + ) + from cupy.cuda.nccl import NcclCommunicator +else: + NcclCommunicator = Any + +NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator) + + +class DistributedMixIn: + r"""Distributed Mixin class + + This class implements all methods associated with communication primitives + from MPI and NCCL. It is mostly charged with identifying which commuicator + to use and whether the buffered or object MPI primitives should be used + (the former in the case of NumPy arrays or CuPy arrays when a CUDA-Aware + MPI installation is available, the latter with CuPy arrays when a CUDA-Aware + MPI installation is not available). + + """ + def _allreduce(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + op: MPI.Op = MPI.SUM, + engine: str = "numpy", + ) -> NDArray: + """Allreduce operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + op : :obj: `MPI.Op`, optional + MPI operation to perform. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all GPUs. + + """ + if deps.nccl_enabled and base_comm_nccl is not None: + return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op) + else: + return mpi_allreduce(base_comm, send_buf, + recv_buf, engine, op) + + def _allreduce_subcomm(self, + sub_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + op: MPI.Op = MPI.SUM, + engine: str = "numpy", + ) -> NDArray: + """Allreduce operation with subcommunicator + + Parameters + ---------- + sub_comm : :obj:`MPI.Comm` + MPI Subcommunicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + op : :obj: `MPI.Op`, optional + MPI operation to perform. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all ranks. + + """ + if deps.nccl_enabled and base_comm_nccl is not None: + return nccl_allreduce(sub_comm, send_buf, recv_buf, op) + else: + return mpi_allreduce(sub_comm, send_buf, + recv_buf, engine, op) + + def _allgather(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + ) -> NDArray: + """Allgather operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the gathering. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the gathered data from all ranks. + + """ + if deps.nccl_enabled and base_comm_nccl is not None: + if isinstance(send_buf, (tuple, list, int)): + return nccl_allgather(base_comm_nccl, send_buf, recv_buf) + else: + send_shapes = base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") + raw_recv = nccl_allgather(base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) + return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) + else: + if isinstance(send_buf, (tuple, list, int)): + return base_comm.allgather(send_buf) + return mpi_allgather(base_comm, send_buf, recv_buf, engine) + + def _allgather_subcomm(self, + sub_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + ) -> NDArray: + """Allgather operation with subcommunicator + + Parameters + ---------- + sub_comm : :obj:`MPI.Comm` + MPI Subcommunicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf: :obj: `numpy.ndarray` or `cupy.ndarray` + A buffer containing the data to be sent by this rank. + recv_buf : :obj: `numpy.ndarray` or `cupy.ndarray`, optional + The buffer to store the result of the gathering. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the gathered data from all ranks. + + """ + if deps.nccl_enabled and base_comm_nccl is not None: + if isinstance(send_buf, (tuple, list, int)): + return nccl_allgather(sub_comm, send_buf, recv_buf) + else: + send_shapes = sub_comm._allgather_subcomm(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy") + raw_recv = nccl_allgather(sub_comm, padded_send, recv_buf if recv_buf else padded_recv) + return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes) + else: + return mpi_allgather(sub_comm, send_buf, recv_buf, engine) + + def _bcast(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + rank : int, + local_array: NDArray, + index: int, + value: Union[int, NDArray], + engine: str = "numpy", + ) -> None: + """BCast operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + rank : :obj:`int` + Rank. + local_array : :obj:`numpy.ndarray` + Localy array to be broadcasted. + index : :obj:`int` or :obj:`slice` + Represents the index positions where a value needs to be assigned. + value : :obj:`int` or :obj:`numpy.ndarray` + Represents the value that will be assigned to the local array at + the specified index positions. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.nccl_enabled and base_comm_nccl is not None: + nccl_bcast(base_comm_nccl, local_array, index, value) + else: + mpi_bcast(base_comm, rank, local_array, index, value, + engine=engine) + + def _send(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + send_buf: NDArray, + dest: int, + count: Optional[int] = None, + tag: int = 0, + engine: str = "numpy", + ) -> None: + """Send operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The array containing data to send. + dest: :obj:`int` + The rank of the destination. + count : :obj:`int` + Number of elements to send from `send_buf`. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.nccl_enabled and base_comm_nccl is not None: + if count is None: + count = send_buf.size + nccl_send(base_comm_nccl, send_buf, dest, count) + else: + mpi_send(base_comm, + send_buf, dest, count, tag=tag, + engine=engine) + + def _recv(self, + base_comm: MPI.Comm, + base_comm_nccl: NcclCommunicatorType, + recv_buf=None, source=0, count=None, tag=0, + engine: str = "numpy", + ) -> NDArray: + """Receive operation + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator` + NCCL Communicator. + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional + The buffered array to receive data. + source : :obj:`int` + The rank of the sending CPU/GPU device. + count : :obj:`int` + Number of elements to receive. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The buffer containing the received data. + + """ + if deps.nccl_enabled and base_comm_nccl is not None: + if recv_buf is None: + raise ValueError("recv_buf must be supplied when using NCCL") + if count is None: + count = recv_buf.size + nccl_recv(base_comm_nccl, recv_buf, source, count) + return recv_buf + else: + return mpi_recv(base_comm, + recv_buf, source, count, tag=tag, + engine=engine) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 979882c0..c383e4d0 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -4,6 +4,7 @@ import numpy as np from mpi4py import MPI +from pylops_mpi.Distributed import DistributedMixIn from pylops.utils import DTypeLike, NDArray from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops.utils._internal import _value_or_sized_to_tuple @@ -14,7 +15,7 @@ nccl_message = deps.nccl_import("the DistributedArray module") if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv + from pylops_mpi.utils._nccl import nccl_asarray, nccl_split from cupy.cuda.nccl import NcclCommunicator else: NcclCommunicator = Any @@ -99,7 +100,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = return sub_comm -class DistributedArray: +class DistributedArray(DistributedMixIn): r"""Distributed Numpy Arrays Multidimensional NumPy-like distributed arrays. @@ -203,10 +204,9 @@ def __setitem__(self, index, value): the specified index positions. """ if self.partition is Partition.BROADCAST: - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - nccl_bcast(self.base_comm_nccl, self.local_array, index, value) - else: - self.local_array[index] = self.base_comm.bcast(value) + self._bcast(self.base_comm, self.base_comm_nccl, + self.rank, self.local_array, + index, value, engine=self.engine) else: self.local_array[index] = value @@ -342,7 +342,9 @@ def local_shapes(self): if deps.nccl_enabled and getattr(self, "base_comm_nccl"): return self._nccl_local_shapes(False) else: - return self._allgather(self.local_shape) + return self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_shape) @property def sub_comm(self): @@ -380,9 +382,15 @@ def asarray(self, masked: bool = False): else: # Gather all the local arrays and apply concatenation. if masked: - final_array = self._allgather_subcomm(self.local_array) + final_array = self._allgather_subcomm(self.sub_comm, + self.base_comm_nccl, + self.local_array, + engine=self.engine) else: - final_array = self._allgather(self.local_array) + final_array = self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_array, + engine=self.engine) return np.concatenate(final_array, axis=self.axis) @classmethod @@ -432,6 +440,7 @@ def to_dist(cls, x: NDArray, else: slices = [slice(None)] * x.ndim local_shapes = np.append([0], dist_array._allgather( + base_comm, base_comm_nccl, dist_array.local_shape[axis])) sum_shapes = np.cumsum(local_shapes) slices[axis] = slice(sum_shapes[dist_array.rank], @@ -472,100 +481,18 @@ def _check_mask(self, dist_array): if not np.array_equal(self.mask, dist_array.mask): raise ValueError("Mask of both the arrays must be same") - def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op) - else: - if recv_buf is None: - return self.base_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.base_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - - def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM): - """Allreduce operation with subcommunicator - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op) - else: - if recv_buf is None: - return self.sub_comm.allreduce(send_buf, op) - # For MIN and MAX which require recv_buf - self.sub_comm.Allreduce(send_buf, recv_buf, op) - return recv_buf - - def _allgather(self, send_buf, recv_buf=None): - """Allgather operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf) - else: - send_shapes = self.base_comm.allgather(send_buf.shape) - (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes) - raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv) - return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes) - else: - if recv_buf is None: - return self.base_comm.allgather(send_buf) - self.base_comm.Allgather(send_buf, recv_buf) - return recv_buf - - def _allgather_subcomm(self, send_buf, recv_buf=None): - """Allgather operation with subcommunicator - """ - if deps.nccl_enabled and getattr(self, "base_comm_nccl"): - if isinstance(send_buf, (tuple, list, int)): - return nccl_allgather(self.sub_comm, send_buf, recv_buf) - else: - send_shapes = self._allgather_subcomm(send_buf.shape) - (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes) - raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv) - return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes) - else: - if recv_buf is None: - return self.sub_comm.allgather(send_buf) - self.sub_comm.Allgather(send_buf, recv_buf) - - def _send(self, send_buf, dest, count=None, tag=None): - """ Send operation - """ - if deps.nccl_enabled and self.base_comm_nccl: - if count is None: - # assuming sending the whole array - count = send_buf.size - nccl_send(self.base_comm_nccl, send_buf, dest, count) - else: - self.base_comm.send(send_buf, dest, tag) - - def _recv(self, recv_buf=None, source=0, count=None, tag=None): - """ Receive operation - """ - # NCCL must be called with recv_buf. Size cannot be inferred from - # other arguments and thus cannot be dynamically allocated - if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None: - if recv_buf is not None: - if count is None: - # assuming data will take a space of the whole buffer - count = recv_buf.size - nccl_recv(self.base_comm_nccl, recv_buf, source, count) - return recv_buf - else: - raise ValueError("Using recv with NCCL must also supply receiver buffer ") - else: - # MPI allows a receiver buffer to be optional and receives as a Python Object - return self.base_comm.recv(source=source, tag=tag) - def _nccl_local_shapes(self, masked: bool): """Get the the list of shapes of every GPU in the communicator """ # gather tuple of shapes from every rank within thee communicator and copy from GPU to CPU if masked: - all_tuples = self._allgather_subcomm(self.local_shape).get() + all_tuples = self._allgather_subcomm(self.sub_comm, + self.base_comm_nccl, + self.local_shape).get() else: - all_tuples = self._allgather(self.local_shape).get() + all_tuples = self._allgather(self.base_comm, + self.base_comm_nccl, + self.local_shape).get() # NCCL returns the flat array that packs every tuple as 1-dimensional array # unpack each tuple from each rank tuple_len = len(self.local_shape) @@ -663,7 +590,9 @@ def dot(self, dist_array): y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \ if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array # Flatten the local arrays and calculate dot product - return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten())) + return self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.dot(x.local_array.flatten(), y.local_array.flatten()), + engine=self.engine) def _compute_vector_norm(self, local_array: NDArray, axis: int, ord: Optional[int] = None): @@ -691,31 +620,49 @@ def _compute_vector_norm(self, local_array: NDArray, raise ValueError(f"norm-{ord} not possible for vectors") elif ord == 0: # Count non-zero then sum reduction - recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64)) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64), + engine=self.engine) elif ord == ncp.inf: # Calculate max followed by max reduction - # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly + # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None: - recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX) + if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: + # CuPy + non-CUDA-aware MPI: This will call non-buffered communication + # which return a list of object - must be copied back to a GPU memory. + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf.get(), recv_buf.get(), + op=MPI.MAX, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX) - recv_buf = ncp.squeeze(recv_buf, axis=axis) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf, recv_buf, op=MPI.MAX, + engine=self.engine) + # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL + # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it. + # There may be a way to unify it - may be something to do with how we allocate the recv_buf. + if self.base_comm_nccl: + recv_buf = ncp.squeeze(recv_buf, axis=axis) elif ord == -ncp.inf: # Calculate min followed by min reduction - # TODO (tharitt): see the comment above in infinity norm + # See the comment above in +infinity norm send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64) - if self.engine == "cupy" and self.base_comm_nccl is None: - recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN) + if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled: + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf.get(), recv_buf.get(), + op=MPI.MIN, engine=self.engine) recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN) - recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) - + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + send_buf, recv_buf, + op=MPI.MIN, engine=self.engine) + if self.base_comm_nccl: + recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis)) else: - recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis)) + recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl, + ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis), + engine=self.engine) recv_buf = ncp.power(recv_buf, 1.0 / ord) return recv_buf @@ -850,7 +797,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, # Transfer of ghost cells can be skipped if len(recv_buf) = 0 # Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory if len(recv_buf) != 0: - ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis) + ghosted_array = ncp.concatenate([self._recv(self.base_comm, self.base_comm_nccl, + recv_buf, source=self.rank - 1, tag=1, + engine=self.engine), ghosted_array], axis=self.axis) # The skip in sender is to match with what described in receiver if self.rank != self.size - 1 and len(send_buf) != 0: if cells_front > self.local_shape[self.axis]: @@ -859,7 +808,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, f"{self.local_shape[self.axis]} < {cells_front}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_front)}") - self._send(send_buf, dest=self.rank + 1, tag=1) + self._send(self.base_comm, self.base_comm_nccl, + send_buf, dest=self.rank + 1, tag=1, + engine=self.engine) if cells_back is not None: total_cells_back = self.base_comm.allgather(cells_back) + [0] # Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1) @@ -874,13 +825,17 @@ def add_ghost_cells(self, cells_front: Optional[int] = None, f"{self.local_shape[self.axis]} < {cells_back}; " f"to achieve this use NUM_PROCESSES <= " f"{max(1, self.global_shape[self.axis] // cells_back)}") - self._send(send_buf, dest=self.rank - 1, tag=0) + self._send(self.base_comm, self.base_comm_nccl, + send_buf, dest=self.rank - 1, tag=0, + engine=self.engine) if self.rank != self.size - 1: recv_shape = list(recv_shapes[self.rank + 1]) recv_shape[self.axis] = total_cells_back[self.rank] recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype) if len(recv_buf) != 0: - ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0), + ghosted_array = ncp.append(ghosted_array, self._recv(self.base_comm, self.base_comm_nccl, + recv_buf, source=self.rank + 1, tag=0, + engine=self.engine), axis=self.axis) return ghosted_array diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 58581565..de66c342 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -6,7 +6,6 @@ from pylops import LinearOperator from pylops.utils import DTypeLike from pylops.utils.backend import get_module -from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils from pylops_mpi import ( MPILinearOperator, @@ -15,17 +14,11 @@ Partition, StackedDistributedArray ) +from pylops_mpi.Distributed import DistributedMixIn from pylops_mpi.utils.decorators import reshaped -from pylops_mpi.utils import deps -cupy_message = pylops_deps.cupy_import("the VStack module") -nccl_message = deps.nccl_import("the VStack module") -if nccl_message is None and cupy_message is None: - from pylops_mpi.utils._nccl import nccl_allreduce - - -class MPIVStack(MPILinearOperator): +class MPIVStack(DistributedMixIn, MPILinearOperator): r"""MPI VStack Operator Create a vertical stack of a set of linear operators using MPI. Each rank must @@ -141,16 +134,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST, - engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=self.shape[1], + base_comm=x.base_comm, + base_comm_nccl=x.base_comm_nccl, + partition=Partition.BROADCAST, + engine=x.engine, + dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) y1 = ncp.sum(ncp.vstack(y1), axis=0) - if deps.nccl_enabled and x.base_comm_nccl: - y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM) - else: - y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) + y[:] = self._allreduce(x.base_comm, x.base_comm_nccl, + y1, op=MPI.SUM, engine=x.engine) return y diff --git a/pylops_mpi/signalprocessing/Fredholm1.py b/pylops_mpi/signalprocessing/Fredholm1.py index 6ccd9d21..2969e3c9 100644 --- a/pylops_mpi/signalprocessing/Fredholm1.py +++ b/pylops_mpi/signalprocessing/Fredholm1.py @@ -128,7 +128,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: for isl in range(self.nsls[self.rank]): y1[isl] = ncp.dot(self.G[isl], x[isl]) # gather results - y[:] = ncp.vstack(y._allgather(y1)).ravel() + y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, + engine=y.engine)).ravel() return y def _rmatvec(self, x: NDArray) -> NDArray: @@ -165,5 +166,6 @@ def _rmatvec(self, x: NDArray) -> NDArray: y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj() # gather results - y[:] = ncp.vstack(y._allgather(y1)).ravel() + y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1, + engine=y.engine)).ravel() return y diff --git a/pylops_mpi/utils/_common.py b/pylops_mpi/utils/_common.py new file mode 100644 index 00000000..895265df --- /dev/null +++ b/pylops_mpi/utils/_common.py @@ -0,0 +1,89 @@ +__all__ = [ + "_prepare_allgather_inputs", + "_unroll_allgather_recv" +] + + +import numpy as np +from pylops.utils.backend import get_module + + +# TODO: return type annotation for both cupy and numpy +def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine): + r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) + + Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device. + Therefore, padding is required when the array is not evenly partitioned across + all the ranks. The padding is applied such that the each dimension of the sending buffers + is equal to the max size of that dimension across all ranks. + + Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size + + Parameters + ---------- + send_buf : :obj: `numpy.ndarray` or `cupy.ndarray` or array-like + The data buffer from the local GPU to be sent for allgather. + send_buf_shapes: :obj:`list` + A list of shapes for each GPU send_buf (used to calculate padding size) + engine : :obj:`str` + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + send_buf: :obj:`cupy.ndarray` + A buffer containing the data and padded elements to be sent by this rank. + recv_buf : :obj:`cupy.ndarray` + An empty, padded buffer to gather data from all GPUs. + """ + ncp = get_module(engine) + sizes_each_dim = list(zip(*send_buf_shapes)) + send_shape = tuple(map(max, sizes_each_dim)) + pad_size = [ + (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) + ] + + send_buf = ncp.pad( + send_buf, pad_size, mode="constant", constant_values=0 + ) + + ndev = len(send_buf_shapes) + recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) + + return send_buf, recv_buf + + +def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: + r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL) + + Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays + Each GPU may send array with a different shape, so the return type has to be a list of array + instead of the concatenated array. + + Parameters + ---------- + recv_buf: :obj:`cupy.ndarray` or array-like + The data buffer returned from nccl_allgather call + padded_send_buf_shape: :obj:`tuple`:int + The size of send_buf after padding used in nccl_allgather + send_buf_shapes: :obj:`list` + A list of original shapes for each GPU send_buf prior to padding + + Returns + ------- + chunks: :obj:`list` + A list of `cupy.ndarray` from each GPU with the padded element removed + """ + ndev = len(send_buf_shapes) + # extract an individual array from each device + chunk_size = np.prod(padded_send_buf_shape) + chunks = [ + recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) + ] + + # Remove padding from each array: the padded value may appear somewhere + # in the middle of the flat array and thus the reshape and slicing for each dimension is required + for i in range(ndev): + slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) + chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] + + return chunks diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py new file mode 100644 index 00000000..d0c8c73f --- /dev/null +++ b/pylops_mpi/utils/_mpi.py @@ -0,0 +1,229 @@ +__all__ = [ + "mpi_allgather", + "mpi_allreduce", + "mpi_bcast", + "mpi_send", + "mpi_recv", +] + +from typing import Optional, Union + +from mpi4py import MPI +from pylops.utils import NDArray +from pylops.utils.backend import get_module +from pylops_mpi.utils import deps +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv + + +def mpi_allgather(base_comm: MPI.Comm, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + ) -> NDArray: + """MPI_Allallgather/allallgather + + Dispatch allgather routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The data buffer from the local rank to be gathered. + recv_buf : :obj:`cupy.ndarray`, optional + The buffer to store the result of the gathering. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the gathered data from all ranks. + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + send_shapes = base_comm.allgather(send_buf.shape) + (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine) + recv_buffer_to_use = recv_buf if recv_buf else padded_recv + base_comm.Allgather(padded_send, recv_buffer_to_use) + return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes) + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allgather(send_buf) + base_comm.Allgather(send_buf, recv_buf) + return recv_buf + + +def mpi_allreduce(base_comm: MPI.Comm, + send_buf: NDArray, + recv_buf: Optional[NDArray] = None, + engine: str = "numpy", + op: MPI.Op = MPI.SUM, + ) -> NDArray: + """MPI_Allreduce/allreduce + + Dispatch allreduce routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The data buffer from the local rank to be reduced. + recv_buf : :obj:`cupy.ndarray`, optional + The buffer to store the result of the reduction. If None, + a new buffer will be allocated with the appropriate shape. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + op : :obj:mpi4py.MPI.Op, optional + The reduction operation to apply. Defaults to MPI.SUM. + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + A buffer containing the result of the reduction, broadcasted + to all ranks. + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + ncp = get_module(engine) + recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype) + base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + else: + # CuPy with non-CUDA-aware MPI + if recv_buf is None: + return base_comm.allreduce(send_buf, op) + # For MIN and MAX which require recv_buf + base_comm.Allreduce(send_buf, recv_buf, op) + return recv_buf + + +def mpi_bcast(base_comm: MPI.Comm, + rank: int, + local_array: NDArray, + index: int, + value: Union[int, NDArray], + engine: Optional[str] = "numpy", + ) -> None: + """MPI_Bcast/bcast + + Dispatch bcast routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + rank : :obj:`int` + Rank. + local_array : :obj:`numpy.ndarray` + Localy array to be broadcasted. + index : :obj:`int` or :obj:`slice` + Represents the index positions where a value needs to be assigned. + value : :obj:`int` or :obj:`numpy.ndarray` + Represents the value that will be assigned to the local array at + the specified index positions. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + if rank == 0: + local_array[index] = value + base_comm.Bcast(local_array[index]) + else: + # CuPy with non-CUDA-aware MPI + local_array[index] = base_comm.bcast(value) + + +def mpi_send(base_comm: MPI.Comm, + send_buf: NDArray, + dest: int, + count: Optional[int] = None, + tag: int = 0, + engine: str = "numpy", + ) -> None: + """MPI_Send/send + + Dispatch send routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The array containing data to send. + dest: :obj:`int` + The rank of the destination CPU/GPU device. + count : :obj:`int` + Number of elements to send from `send_buf`. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + # Determine MPI type based on array dtype + mpi_type = MPI._typedict[send_buf.dtype.char] + if count is None: + count = send_buf.size + base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + base_comm.send(send_buf, dest, tag) + + +def mpi_recv(base_comm: MPI.Comm, + recv_buf: Optional[NDArray] = None, + source: int = 0, + count: Optional[int] = None, + tag: int = 0, + engine: Optional[str] = "numpy", + ) -> NDArray: + """ MPI_Recv/recv + Dispatch receive routine based on type of input and availability of + CUDA-Aware MPI + + Parameters + ---------- + base_comm : :obj:`MPI.Comm` + Base MPI Communicator. + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional + The buffered array to receive data. + source : :obj:`int` + The rank of the sending CPU/GPU device. + count : :obj:`int` + Number of elements to receive. + tag : :obj:`int` + Tag of the message to be sent. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) + + Returns + ------- + recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` + The buffer containing the received data. + + """ + if deps.cuda_aware_mpi_enabled or engine == "numpy": + ncp = get_module(engine) + if recv_buf is None: + if count is None: + raise ValueError("Must provide either recv_buf or count for MPI receive") + # Default to int32 works currently because add_ghost_cells() is called + # with recv_buf and is not affected by this branch. The int32 is for when + # dimension or shape-related integers are send/recv + recv_buf = ncp.zeros(count, dtype=ncp.int32) + mpi_type = MPI._typedict[recv_buf.dtype.char] + base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag) + else: + # Uses CuPy without CUDA-aware MPI + recv_buf = base_comm.recv(source=source, tag=tag) + return recv_buf diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py index 19c09922..cac5b61c 100644 --- a/pylops_mpi/utils/_nccl.py +++ b/pylops_mpi/utils/_nccl.py @@ -1,6 +1,4 @@ __all__ = [ - "_prepare_nccl_allgather_inputs", - "_unroll_nccl_allgather_recv", "_nccl_sync", "initialize_nccl_comm", "nccl_split", @@ -13,12 +11,11 @@ ] from enum import IntEnum -from typing import Tuple from mpi4py import MPI import os -import numpy as np import cupy as cp import cupy.cuda.nccl as nccl +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv cupy_to_nccl_dtype = { "float32": nccl.NCCL_FLOAT32, @@ -70,85 +67,6 @@ def _nccl_sync(): cp.cuda.runtime.deviceSynchronize() -def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]: - r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather) - - NCCL's allGather requires the sending buffer to have the same size for every device. - Therefore, padding is required when the array is not evenly partitioned across - all the ranks. The padding is applied such that the each dimension of the sending buffers - is equal to the max size of that dimension across all ranks. - - Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size - - Parameters - ---------- - send_buf : :obj:`cupy.ndarray` or array-like - The data buffer from the local GPU to be sent for allgather. - send_buf_shapes: :obj:`list` - A list of shapes for each GPU send_buf (used to calculate padding size) - - Returns - ------- - send_buf: :obj:`cupy.ndarray` - A buffer containing the data and padded elements to be sent by this rank. - recv_buf : :obj:`cupy.ndarray` - An empty, padded buffer to gather data from all GPUs. - """ - sizes_each_dim = list(zip(*send_buf_shapes)) - send_shape = tuple(map(max, sizes_each_dim)) - pad_size = [ - (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape) - ] - - send_buf = cp.pad( - send_buf, pad_size, mode="constant", constant_values=0 - ) - - # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred - ndev = len(send_buf_shapes) - recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype) - - return send_buf, recv_buf - - -def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list: - """Unrolll recv_buf after NCCL allgather (nccl_allgather) - - Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays - Each GPU may send array with a different shape, so the return type has to be a list of array - instead of the concatenated array. - - Parameters - ---------- - recv_buf: :obj:`cupy.ndarray` or array-like - The data buffer returned from nccl_allgather call - padded_send_buf_shape: :obj:`tuple`:int - The size of send_buf after padding used in nccl_allgather - send_buf_shapes: :obj:`list` - A list of original shapes for each GPU send_buf prior to padding - - Returns - ------- - chunks: :obj:`list` - A list of `cupy.ndarray` from each GPU with the padded element removed - """ - - ndev = len(send_buf_shapes) - # extract an individual array from each device - chunk_size = np.prod(padded_send_buf_shape) - chunks = [ - recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev) - ] - - # Remove padding from each array: the padded value may appear somewhere - # in the middle of the flat array and thus the reshape and slicing for each dimension is required - for i in range(ndev): - slicing = tuple(slice(0, end) for end in send_buf_shapes[i]) - chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing] - - return chunks - - def mpi_op_to_nccl(mpi_op) -> NcclOp: """ Map MPI reduction operation to NCCL equivalent @@ -363,9 +281,9 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray: Global array gathered from all GPUs and concatenated along `axis`. """ - send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes) + send_buf, recv_buf = _prepare_allgather_inputs(local_array, local_shapes, engine="cupy") nccl_allgather(nccl_comm, send_buf, recv_buf) - chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes) + chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, local_shapes) # combine back to single global array return cp.concatenate(chunks, axis=axis) diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py index 9d983f60..f0279ceb 100644 --- a/pylops_mpi/utils/deps.py +++ b/pylops_mpi/utils/deps.py @@ -39,6 +39,10 @@ def nccl_import(message: Optional[str] = None) -> str: return nccl_message +cuda_aware_mpi_enabled: bool = ( + True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1)) == 1 else False +) + nccl_enabled: bool = ( True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False ) diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py index a596ffe9..20a11783 100644 --- a/tests/test_matrixmult.py +++ b/tests/test_matrixmult.py @@ -219,7 +219,8 @@ def test_MPIMatrixMult_summa(N, K, M, dtype_str): local_shapes=comm.allgather(X_p.shape[0] * X_p.shape[1]), partition=Partition.SCATTER, base_comm=comm, - dtype=dtype + dtype=dtype, + engine=backend, ) x_dist.local_array[:] = X_p.ravel() diff --git a/tests_nccl/test_ncclutils_nccl.py b/tests_nccl/test_ncclutils_nccl.py index 21b28ca3..eaae0a69 100644 --- a/tests_nccl/test_ncclutils_nccl.py +++ b/tests_nccl/test_ncclutils_nccl.py @@ -8,7 +8,8 @@ from numpy.testing import assert_allclose import pytest -from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv +from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather +from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv np.random.seed(42) @@ -83,9 +84,9 @@ def test_allgather_differentsize_withrecbuf(par): # Gathered array send_shapes = MPI.COMM_WORLD.allgather(local_array.shape) - send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, send_shapes) + send_buf, recv_buf = _prepare_allgather_inputs(local_array, send_shapes, engine="cupy") recv_buf = nccl_allgather(nccl_comm, send_buf, recv_buf) - chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, send_shapes) + chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, send_shapes) gathered_array = cp.concatenate(chunks) # Compare with global array created in rank0