diff --git a/Makefile b/Makefile
index ee4b8cb2..1d824ee0 100644
--- a/Makefile
+++ b/Makefile
@@ -47,7 +47,7 @@ lint:
flake8 pylops_mpi/ tests/ examples/ tutorials/
tests:
- mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi
+ export TEST_CUPY_PYLOPS=0 && mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi
# assuming NUM_PROCESSES <= number of gpus available
tests_gpu:
diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst
index 4ed0218d..c86f847e 100644
--- a/docs/source/gpu.rst
+++ b/docs/source/gpu.rst
@@ -11,7 +11,7 @@ This library must be installed *before* PyLops-mpi is installed.
.. note::
- Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend.
+ Set the environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend.
This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system,
otherwise you will get an error when importing PyLops.
@@ -22,6 +22,14 @@ can handle both scenarios. Note that, since most operators in PyLops-mpi are thi
some of the operators in PyLops that lack a GPU implementation cannot be used also in PyLops-mpi when working with
cupy arrays.
+.. note::
+
+ By default, when using ``cupy`` arrays, PyLops-MPI will try to use methods in MPI4Py that communicate memory buffers.
+ However, this requires a CUDA-Aware MPI installation. If your MPI installation is not CUDA-Aware, set the
+ environment variable ``PYLOPS_MPI_CUDA_AWARE=0`` to force PyLops-MPI to use methods in MPI4Py that communicate
+ general Python objects (this will incur a loss of performance!).
+
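+For instance, a minimal sketch of forcing the object-based communication path (the flag is
+read when ``pylops_mpi`` is first imported, so it must be set beforehand, either in the shell
+or at the very top of your script) is:
+
+.. code-block:: python
+
+    import os
+
+    # must be set before pylops_mpi is imported
+    os.environ["PYLOPS_MPI_CUDA_AWARE"] = "0"
+
+    import pylops_mpi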
+
Moreover, PyLops-MPI also supports the Nvidia's Collective Communication Library (NCCL) for highly-optimized
collective operations, such as AllReduce, AllGather, etc. This allows PyLops-MPI users to leverage the
proprietary technology like NVLink that might be available in their infrastructure for fast data communication.
@@ -30,13 +38,35 @@ proprietary technology like NVLink that might be available in their infrastructu
Set environment variable ``NCCL_PYLOPS_MPI=0`` to explicitly force PyLops-MPI to ignore the ``NCCL`` backend.
However, this is optional as users may opt-out for NCCL by skip passing `cupy.cuda.nccl.NcclCommunicator` to
- the :class:`pylops_mpi.DistributedArray`
+ the :class:`pylops_mpi.DistributedArray`.
+
+In summary:
+
+.. list-table::
+ :widths: 50 25 25
+ :header-rows: 1
+
+ * - Operation model
+ - Enabled with
+ - Disabled with
+ * - NumPy + MPI
+ - Default
+ - Cannot be disabled
+ * - CuPy + MPI
+ - ``PYLOPS_MPI_CUDA_AWARE=0``
+ - ``PYLOPS_MPI_CUDA_AWARE=1`` (default)
+ * - CuPy + CUDA-Aware MPI
+ - ``PYLOPS_MPI_CUDA_AWARE=1`` (default)
+ - ``PYLOPS_MPI_CUDA_AWARE=0``
+ * - CuPy + NCCL
+ - ``NCCL_PYLOPS_MPI=1`` (default)
+ - ``NCCL_PYLOPS_MPI=0``
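+
+To check which communication path has been picked up at import time, one can inspect the flags
+exposed in :mod:`pylops_mpi.utils.deps` (a quick sanity check, assuming PyLops-MPI and its
+optional dependencies are installed):
+
+.. code-block:: python
+
+    from pylops_mpi.utils import deps
+
+    print(deps.cuda_aware_mpi_enabled)  # False when PYLOPS_MPI_CUDA_AWARE=0
+    print(deps.nccl_enabled)            # False when NCCL_PYLOPS_MPI=0 or NCCL is unavailable
+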
Example
-------
Finally, let's briefly look at an example. First we write a code snippet using
-``numpy`` arrays which PyLops-mpi will run on your CPU:
+``numpy`` arrays which PyLops-MPI will run on your CPU:
.. code-block:: python
@@ -157,6 +187,8 @@ GPU+MPI, and GPU+NCCL):
- ✅
- ✅
- ✅
+ - ✅
+ - ✅
* - :class:`pylops_mpi.basicoperators.MPISecondDerivative`
- ✅
- ✅
@@ -184,4 +216,4 @@ GPU+MPI, and GPU+NCCL):
* - :class:`pylops_mpi.optimization.basic.cgls`
- ✅
- ✅
- - ✅
\ No newline at end of file
+ - ✅
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index d0aafe88..e1d7faf3 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -15,7 +15,13 @@ The minimal set of dependencies for the PyLops-MPI project is:
* `MPI4py `_
* `PyLops `_
-Additionally, to use the NCCL engine, the following additional
+Additionally, to use the CUDA-aware MPI engine, the following additional
+dependencies are required:
+
+* `CuPy `_
+* CUDA-aware MPI
+
+Similarly, to use the NCCL engine, the following additional
dependencies are required:
* `CuPy `_
@@ -27,12 +33,18 @@ if this is not possible, some of the dependencies must be installed prior to ins
Download and Install MPI
========================
-Visit the official MPI website to download an appropriate MPI implementation for your system.
-Follow the installation instructions provided by the MPI vendor.
+Visit the official website of your MPI vendor of choice to download an appropriate MPI
+implementation for your system:
+
+* `Open MPI `_
+* `MPICH `_
+* `Intel MPI `_
+* ...
-* `Open MPI `_
-* `MPICH `_
-* `Intel MPI `_
+Alternatively, the conda-forge community provides ready-to-use binary packages for four MPI implementations
+(see the `MPI4Py documentation `_ for more
+details). In this case, you can defer the MPI installation to when the conda environment for your
+project is created - see below.
Verify MPI Installation
=======================
@@ -42,6 +54,17 @@ After installing MPI, verify its installation by opening a terminal and running
>> mpiexec --version
+Install CUDA-Aware MPI (optional)
+=================================
+To achieve the best performance when using PyLops-MPI with CuPy arrays, a CUDA-Aware version of
+MPI must be installed.
+
+For `Open MPI`, the conda-forge package has built-in CUDA support, as long as a pre-installed CUDA is detected.
+Run these `commands `_
+to check whether CUDA support is enabled.
+
+For the other MPI implementations, refer to their specific documentation.
+
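+As a quick end-to-end check (a sketch, assuming ``mpi4py`` and ``cupy`` are already installed),
+you can pass a small CuPy array to a buffered MPI call; on a non-CUDA-Aware build this typically
+errors out or crashes, in which case you should set ``PYLOPS_MPI_CUDA_AWARE=0``:
+
+.. code-block:: python
+
+    # save as check_cuda_aware.py (illustrative name) and run with: mpiexec -n 2 python check_cuda_aware.py
+    import cupy as cp
+    from mpi4py import MPI
+
+    comm = MPI.COMM_WORLD
+    send = cp.ones(4, dtype="float32")
+    recv = cp.empty_like(send)
+    comm.Allreduce(send, recv, op=MPI.SUM)  # buffered call operating directly on GPU memory
+    print(comm.Get_rank(), recv)
+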
Install NCCL (optional)
=======================
To obtain highly-optimized performance on GPU clusters, PyLops-MPI also supports the Nvidia's collective communication calls
@@ -103,6 +126,15 @@ For a ``conda`` environment, run
This will create and activate an environment called ``pylops_mpi``, with all
required and optional dependencies.
+If you also want to install MPI as part of the creation process of the conda environment,
+modify the ``environment-dev.yml`` file by adding one of ``openmpi``, ``mpich``, ``impi_rt``, or ``msmpi``
+just above ``mpi4py``. Note that only ``openmpi`` provides a CUDA-Aware MPI installation.
+
+If you want to leverage CUDA-Aware MPI but prefer to use another MPI installation, you must
+either switch to a `Pip`-based installation (see below), or move ``mpi4py`` into the ``pip``
+section of the ``environment-dev.yml`` file and export the ``MPICC`` environment variable
+pointing to the path of your CUDA-Aware MPI installation.
+
If you want to enable `NCCL `_ in PyLops-MPI, run this instead
.. code-block:: bash
diff --git a/pylops_mpi/Distributed.py b/pylops_mpi/Distributed.py
new file mode 100644
index 00000000..3f1cf068
--- /dev/null
+++ b/pylops_mpi/Distributed.py
@@ -0,0 +1,309 @@
+from typing import Any, NewType, Optional, Union
+
+from mpi4py import MPI
+from pylops.utils import NDArray
+from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
+from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_bcast, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
+from pylops_mpi.utils import deps
+
+cupy_message = pylops_deps.cupy_import("the DistributedArray module")
+nccl_message = deps.nccl_import("the DistributedArray module")
+
+if nccl_message is None and cupy_message is None:
+ from pylops_mpi.utils._nccl import (
+ nccl_allgather, nccl_allreduce, nccl_bcast, nccl_send, nccl_recv
+ )
+ from cupy.cuda.nccl import NcclCommunicator
+else:
+ NcclCommunicator = Any
+
+NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator)
+
+
+class DistributedMixIn:
+ r"""Distributed Mixin class
+
+ This class implements all methods associated with communication primitives
+ from MPI and NCCL. It is mainly in charge of identifying which communicator
+ to use and whether the buffered or object MPI primitives should be used
+ (the former in the case of NumPy arrays, or CuPy arrays when a CUDA-Aware
+ MPI installation is available; the latter with CuPy arrays when a CUDA-Aware
+ MPI installation is not available).
+
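+ For instance, a subclass exposing ``base_comm``, ``base_comm_nccl`` and ``engine``
+ attributes would typically reduce a local result as follows (sketch)::
+
+     total = self._allreduce(self.base_comm, self.base_comm_nccl,
+                             local_sum, op=MPI.SUM, engine=self.engine)
+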
+ """
+ def _allreduce(self,
+ base_comm: MPI.Comm,
+ base_comm_nccl: NcclCommunicatorType,
+ send_buf: NDArray,
+ recv_buf: Optional[NDArray] = None,
+ op: MPI.Op = MPI.SUM,
+ engine: str = "numpy",
+ ) -> NDArray:
+ """Allreduce operation
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+ NCCL Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the data to be sent by this rank.
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+ The buffer to store the result of the reduction. If None,
+ a new buffer will be allocated with the appropriate shape.
+ op : :obj:`MPI.Op`, optional
+ MPI operation to perform.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the result of the reduction, broadcast
+ to all ranks.
+
+ """
+ if deps.nccl_enabled and base_comm_nccl is not None:
+ return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op)
+ else:
+ return mpi_allreduce(base_comm, send_buf,
+ recv_buf, engine, op)
+
+ def _allreduce_subcomm(self,
+ sub_comm: MPI.Comm,
+ base_comm_nccl: NcclCommunicatorType,
+ send_buf: NDArray,
+ recv_buf: Optional[NDArray] = None,
+ op: MPI.Op = MPI.SUM,
+ engine: str = "numpy",
+ ) -> NDArray:
+ """Allreduce operation with subcommunicator
+
+ Parameters
+ ----------
+ sub_comm : :obj:`MPI.Comm`
+ MPI Subcommunicator.
+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+ NCCL Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the data to be sent by this rank.
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+ The buffer to store the result of the reduction. If None,
+ a new buffer will be allocated with the appropriate shape.
+ op : :obj:`MPI.Op`, optional
+ MPI operation to perform.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the result of the reduction, broadcasted
+ to all ranks.
+
+ """
+ if deps.nccl_enabled and base_comm_nccl is not None:
+ return nccl_allreduce(sub_comm, send_buf, recv_buf, op)
+ else:
+ return mpi_allreduce(sub_comm, send_buf,
+ recv_buf, engine, op)
+
+ def _allgather(self,
+ base_comm: MPI.Comm,
+ base_comm_nccl: NcclCommunicatorType,
+ send_buf: NDArray,
+ recv_buf: Optional[NDArray] = None,
+ engine: str = "numpy",
+ ) -> NDArray:
+ """Allgather operation
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+ NCCL Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the data to be sent by this rank.
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+ The buffer to store the result of the gathering. If None,
+ a new buffer will be allocated with the appropriate shape.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the gathered data from all ranks.
+
+ """
+ if deps.nccl_enabled and base_comm_nccl is not None:
+ if isinstance(send_buf, (tuple, list, int)):
+ return nccl_allgather(base_comm_nccl, send_buf, recv_buf)
+ else:
+ send_shapes = base_comm.allgather(send_buf.shape)
+ (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
+ raw_recv = nccl_allgather(base_comm_nccl, padded_send, recv_buf if recv_buf is not None else padded_recv)
+ return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
+ else:
+ if isinstance(send_buf, (tuple, list, int)):
+ return base_comm.allgather(send_buf)
+ return mpi_allgather(base_comm, send_buf, recv_buf, engine)
+
+ def _allgather_subcomm(self,
+ sub_comm: MPI.Comm,
+ base_comm_nccl: NcclCommunicatorType,
+ send_buf: NDArray,
+ recv_buf: Optional[NDArray] = None,
+ engine: str = "numpy",
+ ) -> NDArray:
+ """Allgather operation with subcommunicator
+
+ Parameters
+ ----------
+ sub_comm : :obj:`MPI.Comm`
+ MPI Subcommunicator.
+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+ NCCL Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the data to be sent by this rank.
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+ The buffer to store the result of the gathering. If None,
+ a new buffer will be allocated with the appropriate shape.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the gathered data from all ranks.
+
+ """
+ if deps.nccl_enabled and base_comm_nccl is not None:
+ if isinstance(send_buf, (tuple, list, int)):
+ return nccl_allgather(sub_comm, send_buf, recv_buf)
+ else:
+ send_shapes = self._allgather_subcomm(sub_comm, base_comm_nccl, send_buf.shape)
+ (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
+ raw_recv = nccl_allgather(sub_comm, padded_send, recv_buf if recv_buf is not None else padded_recv)
+ return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
+ else:
+ return mpi_allgather(sub_comm, send_buf, recv_buf, engine)
+
+ def _bcast(self,
+ base_comm: MPI.Comm,
+ base_comm_nccl: NcclCommunicatorType,
+ rank : int,
+ local_array: NDArray,
+ index: int,
+ value: Union[int, NDArray],
+ engine: str = "numpy",
+ ) -> None:
+ """BCast operation
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+ NCCL Communicator.
+ rank : :obj:`int`
+ Rank.
+ local_array : :obj:`numpy.ndarray`
+ Local array to be broadcast.
+ index : :obj:`int` or :obj:`slice`
+ Represents the index positions where a value needs to be assigned.
+ value : :obj:`int` or :obj:`numpy.ndarray`
+ Represents the value that will be assigned to the local array at
+ the specified index positions.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ """
+ if deps.nccl_enabled and base_comm_nccl is not None:
+ nccl_bcast(base_comm_nccl, local_array, index, value)
+ else:
+ mpi_bcast(base_comm, rank, local_array, index, value,
+ engine=engine)
+
+ def _send(self,
+ base_comm: MPI.Comm,
+ base_comm_nccl: NcclCommunicatorType,
+ send_buf: NDArray,
+ dest: int,
+ count: Optional[int] = None,
+ tag: int = 0,
+ engine: str = "numpy",
+ ) -> None:
+ """Send operation
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+ NCCL Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ The array containing data to send.
+ dest: :obj:`int`
+ The rank of the destination.
+ count : :obj:`int`
+ Number of elements to send from `send_buf`.
+ tag : :obj:`int`
+ Tag of the message to be sent.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ """
+ if deps.nccl_enabled and base_comm_nccl is not None:
+ if count is None:
+ count = send_buf.size
+ nccl_send(base_comm_nccl, send_buf, dest, count)
+ else:
+ mpi_send(base_comm,
+ send_buf, dest, count, tag=tag,
+ engine=engine)
+
+ def _recv(self,
+ base_comm: MPI.Comm,
+ base_comm_nccl: NcclCommunicatorType,
+ recv_buf: Optional[NDArray] = None, source: int = 0, count: Optional[int] = None, tag: int = 0,
+ engine: str = "numpy",
+ ) -> NDArray:
+ """Receive operation
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`
+ NCCL Communicator.
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+ The buffered array to receive data.
+ source : :obj:`int`
+ The rank of the sending CPU/GPU device.
+ count : :obj:`int`
+ Number of elements to receive.
+ tag : :obj:`int`
+ Tag of the message to be sent.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ The buffer containing the received data.
+
+ """
+ if deps.nccl_enabled and base_comm_nccl is not None:
+ if recv_buf is None:
+ raise ValueError("recv_buf must be supplied when using NCCL")
+ if count is None:
+ count = recv_buf.size
+ nccl_recv(base_comm_nccl, recv_buf, source, count)
+ return recv_buf
+ else:
+ return mpi_recv(base_comm,
+ recv_buf, source, count, tag=tag,
+ engine=engine)
diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py
index 979882c0..c383e4d0 100644
--- a/pylops_mpi/DistributedArray.py
+++ b/pylops_mpi/DistributedArray.py
@@ -4,6 +4,7 @@
import numpy as np
from mpi4py import MPI
+from pylops_mpi.Distributed import DistributedMixIn
from pylops.utils import DTypeLike, NDArray
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
from pylops.utils._internal import _value_or_sized_to_tuple
@@ -14,7 +15,7 @@
nccl_message = deps.nccl_import("the DistributedArray module")
if nccl_message is None and cupy_message is None:
- from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
+ from pylops_mpi.utils._nccl import nccl_asarray, nccl_split
from cupy.cuda.nccl import NcclCommunicator
else:
NcclCommunicator = Any
@@ -99,7 +100,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] =
return sub_comm
-class DistributedArray:
+class DistributedArray(DistributedMixIn):
r"""Distributed Numpy Arrays
Multidimensional NumPy-like distributed arrays.
@@ -203,10 +204,9 @@ def __setitem__(self, index, value):
the specified index positions.
"""
if self.partition is Partition.BROADCAST:
- if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
- nccl_bcast(self.base_comm_nccl, self.local_array, index, value)
- else:
- self.local_array[index] = self.base_comm.bcast(value)
+ self._bcast(self.base_comm, self.base_comm_nccl,
+ self.rank, self.local_array,
+ index, value, engine=self.engine)
else:
self.local_array[index] = value
@@ -342,7 +342,9 @@ def local_shapes(self):
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
return self._nccl_local_shapes(False)
else:
- return self._allgather(self.local_shape)
+ return self._allgather(self.base_comm,
+ self.base_comm_nccl,
+ self.local_shape)
@property
def sub_comm(self):
@@ -380,9 +382,15 @@ def asarray(self, masked: bool = False):
else:
# Gather all the local arrays and apply concatenation.
if masked:
- final_array = self._allgather_subcomm(self.local_array)
+ final_array = self._allgather_subcomm(self.sub_comm,
+ self.base_comm_nccl,
+ self.local_array,
+ engine=self.engine)
else:
- final_array = self._allgather(self.local_array)
+ final_array = self._allgather(self.base_comm,
+ self.base_comm_nccl,
+ self.local_array,
+ engine=self.engine)
return np.concatenate(final_array, axis=self.axis)
@classmethod
@@ -432,6 +440,7 @@ def to_dist(cls, x: NDArray,
else:
slices = [slice(None)] * x.ndim
local_shapes = np.append([0], dist_array._allgather(
+ base_comm, base_comm_nccl,
dist_array.local_shape[axis]))
sum_shapes = np.cumsum(local_shapes)
slices[axis] = slice(sum_shapes[dist_array.rank],
@@ -472,100 +481,18 @@ def _check_mask(self, dist_array):
if not np.array_equal(self.mask, dist_array.mask):
raise ValueError("Mask of both the arrays must be same")
- def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
- """Allreduce operation
- """
- if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
- return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
- else:
- if recv_buf is None:
- return self.base_comm.allreduce(send_buf, op)
- # For MIN and MAX which require recv_buf
- self.base_comm.Allreduce(send_buf, recv_buf, op)
- return recv_buf
-
- def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
- """Allreduce operation with subcommunicator
- """
- if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
- return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
- else:
- if recv_buf is None:
- return self.sub_comm.allreduce(send_buf, op)
- # For MIN and MAX which require recv_buf
- self.sub_comm.Allreduce(send_buf, recv_buf, op)
- return recv_buf
-
- def _allgather(self, send_buf, recv_buf=None):
- """Allgather operation
- """
- if deps.nccl_enabled and self.base_comm_nccl:
- if isinstance(send_buf, (tuple, list, int)):
- return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
- else:
- send_shapes = self.base_comm.allgather(send_buf.shape)
- (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
- raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
- return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
- else:
- if recv_buf is None:
- return self.base_comm.allgather(send_buf)
- self.base_comm.Allgather(send_buf, recv_buf)
- return recv_buf
-
- def _allgather_subcomm(self, send_buf, recv_buf=None):
- """Allgather operation with subcommunicator
- """
- if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
- if isinstance(send_buf, (tuple, list, int)):
- return nccl_allgather(self.sub_comm, send_buf, recv_buf)
- else:
- send_shapes = self._allgather_subcomm(send_buf.shape)
- (padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
- raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
- return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
- else:
- if recv_buf is None:
- return self.sub_comm.allgather(send_buf)
- self.sub_comm.Allgather(send_buf, recv_buf)
-
- def _send(self, send_buf, dest, count=None, tag=None):
- """ Send operation
- """
- if deps.nccl_enabled and self.base_comm_nccl:
- if count is None:
- # assuming sending the whole array
- count = send_buf.size
- nccl_send(self.base_comm_nccl, send_buf, dest, count)
- else:
- self.base_comm.send(send_buf, dest, tag)
-
- def _recv(self, recv_buf=None, source=0, count=None, tag=None):
- """ Receive operation
- """
- # NCCL must be called with recv_buf. Size cannot be inferred from
- # other arguments and thus cannot be dynamically allocated
- if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None:
- if recv_buf is not None:
- if count is None:
- # assuming data will take a space of the whole buffer
- count = recv_buf.size
- nccl_recv(self.base_comm_nccl, recv_buf, source, count)
- return recv_buf
- else:
- raise ValueError("Using recv with NCCL must also supply receiver buffer ")
- else:
- # MPI allows a receiver buffer to be optional and receives as a Python Object
- return self.base_comm.recv(source=source, tag=tag)
-
def _nccl_local_shapes(self, masked: bool):
"""Get the the list of shapes of every GPU in the communicator
"""
# gather tuple of shapes from every rank within thee communicator and copy from GPU to CPU
if masked:
- all_tuples = self._allgather_subcomm(self.local_shape).get()
+ all_tuples = self._allgather_subcomm(self.sub_comm,
+ self.base_comm_nccl,
+ self.local_shape).get()
else:
- all_tuples = self._allgather(self.local_shape).get()
+ all_tuples = self._allgather(self.base_comm,
+ self.base_comm_nccl,
+ self.local_shape).get()
# NCCL returns the flat array that packs every tuple as 1-dimensional array
# unpack each tuple from each rank
tuple_len = len(self.local_shape)
@@ -663,7 +590,9 @@ def dot(self, dist_array):
y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
# Flatten the local arrays and calculate dot product
- return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
+ return self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+ ncp.dot(x.local_array.flatten(), y.local_array.flatten()),
+ engine=self.engine)
def _compute_vector_norm(self, local_array: NDArray,
axis: int, ord: Optional[int] = None):
@@ -691,31 +620,49 @@ def _compute_vector_norm(self, local_array: NDArray,
raise ValueError(f"norm-{ord} not possible for vectors")
elif ord == 0:
# Count non-zero then sum reduction
- recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64))
+ recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+ ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64),
+ engine=self.engine)
elif ord == ncp.inf:
# Calculate max followed by max reduction
- # TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly
+ # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
# with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
- if self.engine == "cupy" and self.base_comm_nccl is None:
- recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
+ if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
+ # CuPy + non-CUDA-aware MPI: this will call non-buffered communication,
+ # which returns a list of objects that must be copied back to GPU memory.
+ recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+ send_buf.get(), recv_buf.get(),
+ op=MPI.MAX, engine=self.engine)
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
else:
- recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
- recv_buf = ncp.squeeze(recv_buf, axis=axis)
+ recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+ send_buf, recv_buf, op=MPI.MAX,
+ engine=self.engine)
+ # TODO (tharitt): in the current implementation there seems to be a semantic difference between buffered MPI
+ # and NCCL: the (1, size) shape is collapsed to (size, ) with buffered MPI, while NCCL retains it.
+ # There may be a way to unify the two, possibly related to how recv_buf is allocated.
+ if self.base_comm_nccl:
+ recv_buf = ncp.squeeze(recv_buf, axis=axis)
elif ord == -ncp.inf:
# Calculate min followed by min reduction
- # TODO (tharitt): see the comment above in infinity norm
+ # See the comment above in +infinity norm
send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
- if self.engine == "cupy" and self.base_comm_nccl is None:
- recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
+ if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
+ recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+ send_buf.get(), recv_buf.get(),
+ op=MPI.MIN, engine=self.engine)
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
else:
- recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN)
- recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
-
+ recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+ send_buf, recv_buf,
+ op=MPI.MIN, engine=self.engine)
+ if self.base_comm_nccl:
+ recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
else:
- recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis))
+ recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+ ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis),
+ engine=self.engine)
recv_buf = ncp.power(recv_buf, 1.0 / ord)
return recv_buf
@@ -850,7 +797,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
# Transfer of ghost cells can be skipped if len(recv_buf) = 0
# Additionally, NCCL will hang if the buffer size is 0 so this optimization is somewhat mandatory
if len(recv_buf) != 0:
- ghosted_array = ncp.concatenate([self._recv(recv_buf, source=self.rank - 1, tag=1), ghosted_array], axis=self.axis)
+ ghosted_array = ncp.concatenate([self._recv(self.base_comm, self.base_comm_nccl,
+ recv_buf, source=self.rank - 1, tag=1,
+ engine=self.engine), ghosted_array], axis=self.axis)
# The skip in sender is to match with what described in receiver
if self.rank != self.size - 1 and len(send_buf) != 0:
if cells_front > self.local_shape[self.axis]:
@@ -859,7 +808,9 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
f"{self.local_shape[self.axis]} < {cells_front}; "
f"to achieve this use NUM_PROCESSES <= "
f"{max(1, self.global_shape[self.axis] // cells_front)}")
- self._send(send_buf, dest=self.rank + 1, tag=1)
+ self._send(self.base_comm, self.base_comm_nccl,
+ send_buf, dest=self.rank + 1, tag=1,
+ engine=self.engine)
if cells_back is not None:
total_cells_back = self.base_comm.allgather(cells_back) + [0]
# Read cells_back which needs to be sent to rank - 1(cells_back for rank - 1)
@@ -874,13 +825,17 @@ def add_ghost_cells(self, cells_front: Optional[int] = None,
f"{self.local_shape[self.axis]} < {cells_back}; "
f"to achieve this use NUM_PROCESSES <= "
f"{max(1, self.global_shape[self.axis] // cells_back)}")
- self._send(send_buf, dest=self.rank - 1, tag=0)
+ self._send(self.base_comm, self.base_comm_nccl,
+ send_buf, dest=self.rank - 1, tag=0,
+ engine=self.engine)
if self.rank != self.size - 1:
recv_shape = list(recv_shapes[self.rank + 1])
recv_shape[self.axis] = total_cells_back[self.rank]
recv_buf = ncp.zeros(recv_shape, dtype=ghosted_array.dtype)
if len(recv_buf) != 0:
- ghosted_array = ncp.append(ghosted_array, self._recv(recv_buf, source=self.rank + 1, tag=0),
+ ghosted_array = ncp.append(ghosted_array, self._recv(self.base_comm, self.base_comm_nccl,
+ recv_buf, source=self.rank + 1, tag=0,
+ engine=self.engine),
axis=self.axis)
return ghosted_array
diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py
index 58581565..de66c342 100644
--- a/pylops_mpi/basicoperators/VStack.py
+++ b/pylops_mpi/basicoperators/VStack.py
@@ -6,7 +6,6 @@
from pylops import LinearOperator
from pylops.utils import DTypeLike
from pylops.utils.backend import get_module
-from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
from pylops_mpi import (
MPILinearOperator,
@@ -15,17 +14,11 @@
Partition,
StackedDistributedArray
)
+from pylops_mpi.Distributed import DistributedMixIn
from pylops_mpi.utils.decorators import reshaped
-from pylops_mpi.utils import deps
-cupy_message = pylops_deps.cupy_import("the VStack module")
-nccl_message = deps.nccl_import("the VStack module")
-if nccl_message is None and cupy_message is None:
- from pylops_mpi.utils._nccl import nccl_allreduce
-
-
-class MPIVStack(MPILinearOperator):
+class MPIVStack(DistributedMixIn, MPILinearOperator):
r"""MPI VStack Operator
Create a vertical stack of a set of linear operators using MPI. Each rank must
@@ -141,16 +134,18 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
@reshaped(forward=False, stacking=True)
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
- y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
- engine=x.engine, dtype=self.dtype)
+ y = DistributedArray(global_shape=self.shape[1],
+ base_comm=x.base_comm,
+ base_comm_nccl=x.base_comm_nccl,
+ partition=Partition.BROADCAST,
+ engine=x.engine,
+ dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
y1 = ncp.sum(ncp.vstack(y1), axis=0)
- if deps.nccl_enabled and x.base_comm_nccl:
- y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM)
- else:
- y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
+ y[:] = self._allreduce(x.base_comm, x.base_comm_nccl,
+ y1, op=MPI.SUM, engine=x.engine)
return y
diff --git a/pylops_mpi/signalprocessing/Fredholm1.py b/pylops_mpi/signalprocessing/Fredholm1.py
index 6ccd9d21..2969e3c9 100644
--- a/pylops_mpi/signalprocessing/Fredholm1.py
+++ b/pylops_mpi/signalprocessing/Fredholm1.py
@@ -128,7 +128,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
for isl in range(self.nsls[self.rank]):
y1[isl] = ncp.dot(self.G[isl], x[isl])
# gather results
- y[:] = ncp.vstack(y._allgather(y1)).ravel()
+ y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1,
+ engine=y.engine)).ravel()
return y
def _rmatvec(self, x: NDArray) -> NDArray:
@@ -165,5 +166,6 @@ def _rmatvec(self, x: NDArray) -> NDArray:
y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()
# gather results
- y[:] = ncp.vstack(y._allgather(y1)).ravel()
+ y[:] = ncp.vstack(y._allgather(y.base_comm, y.base_comm_nccl, y1,
+ engine=y.engine)).ravel()
return y
diff --git a/pylops_mpi/utils/_common.py b/pylops_mpi/utils/_common.py
new file mode 100644
index 00000000..895265df
--- /dev/null
+++ b/pylops_mpi/utils/_common.py
@@ -0,0 +1,89 @@
+__all__ = [
+ "_prepare_allgather_inputs",
+ "_unroll_allgather_recv"
+]
+
+
+import numpy as np
+from pylops.utils.backend import get_module
+
+
+# TODO: return type annotation for both cupy and numpy
+def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
+ r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
+
+ Buffered Allgather (MPI and NCCL) requires the sending buffer to have the same size for every device.
+ Therefore, padding is required when the array is not evenly partitioned across
+ all the ranks. The padding is applied such that the each dimension of the sending buffers
+ is equal to the max size of that dimension across all ranks.
+
+ Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size
+
+ Parameters
+ ----------
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or array-like
+ The data buffer from the local rank to be sent for allgather.
+ send_buf_shapes : :obj:`list`
+ A list of shapes of every rank's send_buf (used to calculate the padding size)
+ engine : :obj:`str`
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the data and padded elements to be sent by this rank.
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ An empty, padded buffer to gather data from all ranks.
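+
+ Examples
+ --------
+ A single-rank illustration of the padding behaviour (here the local buffer already has
+ the largest shape, so only the receive buffer is enlarged):
+
+ >>> import numpy as np
+ >>> send_buf, recv_buf = _prepare_allgather_inputs(np.arange(3), [(3,), (2,)], engine="numpy")
+ >>> send_buf.shape, recv_buf.shape
+ ((3,), (6,))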
+ """
+ ncp = get_module(engine)
+ sizes_each_dim = list(zip(*send_buf_shapes))
+ send_shape = tuple(map(max, sizes_each_dim))
+ pad_size = [
+ (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
+ ]
+
+ send_buf = ncp.pad(
+ send_buf, pad_size, mode="constant", constant_values=0
+ )
+
+ ndev = len(send_buf_shapes)
+ recv_buf = ncp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
+
+ return send_buf, recv_buf
+
+
+def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
+ r"""Unrolll recv_buf after Buffered Allgather (MPI and NCCL)
+
+ Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays
+ Each GPU may send array with a different shape, so the return type has to be a list of array
+ instead of the concatenated array.
+
+ Parameters
+ ----------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or array-like
+ The data buffer returned from the allgather call
+ padded_send_buf_shape : :obj:`tuple` of :obj:`int`
+ The shape of send_buf after padding, as used in the allgather
+ send_buf_shapes : :obj:`list`
+ A list of original shapes of every rank's send_buf prior to padding
+
+ Returns
+ -------
+ chunks : :obj:`list`
+ A list of arrays, one per rank, with the padded elements removed
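+
+ Examples
+ --------
+ Undoing the padding for two ranks, where the second rank contributed only two elements:
+
+ >>> import numpy as np
+ >>> _unroll_allgather_recv(np.array([0, 1, 2, 3, 4, 0]), (3,), [(3,), (2,)])
+ [array([0, 1, 2]), array([3, 4])]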
+ """
+ ndev = len(send_buf_shapes)
+ # extract an individual array from each device
+ chunk_size = np.prod(padded_send_buf_shape)
+ chunks = [
+ recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
+ ]
+
+ # Remove padding from each array: the padded value may appear somewhere
+ # in the middle of the flat array and thus the reshape and slicing for each dimension is required
+ for i in range(ndev):
+ slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
+ chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
+
+ return chunks
diff --git a/pylops_mpi/utils/_mpi.py b/pylops_mpi/utils/_mpi.py
new file mode 100644
index 00000000..d0c8c73f
--- /dev/null
+++ b/pylops_mpi/utils/_mpi.py
@@ -0,0 +1,229 @@
+__all__ = [
+ "mpi_allgather",
+ "mpi_allreduce",
+ "mpi_bcast",
+ "mpi_send",
+ "mpi_recv",
+]
+
+from typing import Optional, Union
+
+from mpi4py import MPI
+from pylops.utils import NDArray
+from pylops.utils.backend import get_module
+from pylops_mpi.utils import deps
+from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv
+
+
+def mpi_allgather(base_comm: MPI.Comm,
+ send_buf: NDArray,
+ recv_buf: Optional[NDArray] = None,
+ engine: str = "numpy",
+ ) -> NDArray:
+ """MPI_Allallgather/allallgather
+
+ Dispatch allgather routine based on type of input and availability of
+ CUDA-Aware MPI
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ The data buffer from the local rank to be gathered.
+ recv_buf : :obj:`cupy.ndarray`, optional
+ The buffer to store the result of the gathering. If None,
+ a new buffer will be allocated with the appropriate shape.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the gathered data from all ranks.
+
+ """
+ if deps.cuda_aware_mpi_enabled or engine == "numpy":
+ send_shapes = base_comm.allgather(send_buf.shape)
+ (padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine=engine)
+ recv_buffer_to_use = recv_buf if recv_buf is not None else padded_recv
+ base_comm.Allgather(padded_send, recv_buffer_to_use)
+ return _unroll_allgather_recv(recv_buffer_to_use, padded_send.shape, send_shapes)
+ else:
+ # CuPy with non-CUDA-aware MPI
+ if recv_buf is None:
+ return base_comm.allgather(send_buf)
+ base_comm.Allgather(send_buf, recv_buf)
+ return recv_buf
+
+
+def mpi_allreduce(base_comm: MPI.Comm,
+ send_buf: NDArray,
+ recv_buf: Optional[NDArray] = None,
+ engine: str = "numpy",
+ op: MPI.Op = MPI.SUM,
+ ) -> NDArray:
+ """MPI_Allreduce/allreduce
+
+ Dispatch allreduce routine based on type of input and availability of
+ CUDA-Aware MPI
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ The data buffer from the local rank to be reduced.
+ recv_buf : :obj:`cupy.ndarray`, optional
+ The buffer to store the result of the reduction. If None,
+ a new buffer will be allocated with the appropriate shape.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+ op : :obj:`mpi4py.MPI.Op`, optional
+ The reduction operation to apply. Defaults to MPI.SUM.
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ A buffer containing the result of the reduction, broadcasted
+ to all ranks.
+
+ """
+ if deps.cuda_aware_mpi_enabled or engine == "numpy":
+ ncp = get_module(engine)
+ recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
+ base_comm.Allreduce(send_buf, recv_buf, op)
+ return recv_buf
+ else:
+ # CuPy with non-CUDA-aware MPI
+ if recv_buf is None:
+ return base_comm.allreduce(send_buf, op)
+ # For MIN and MAX which require recv_buf
+ base_comm.Allreduce(send_buf, recv_buf, op)
+ return recv_buf
+
+
+def mpi_bcast(base_comm: MPI.Comm,
+ rank: int,
+ local_array: NDArray,
+ index: int,
+ value: Union[int, NDArray],
+ engine: Optional[str] = "numpy",
+ ) -> None:
+ """MPI_Bcast/bcast
+
+ Dispatch bcast routine based on type of input and availability of
+ CUDA-Aware MPI
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ rank : :obj:`int`
+ Rank.
+ local_array : :obj:`numpy.ndarray`
+ Local array to be broadcast.
+ index : :obj:`int` or :obj:`slice`
+ Represents the index positions where a value needs to be assigned.
+ value : :obj:`int` or :obj:`numpy.ndarray`
+ Represents the value that will be assigned to the local array at
+ the specified index positions.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ """
+ if deps.cuda_aware_mpi_enabled or engine == "numpy":
+ if rank == 0:
+ local_array[index] = value
+ base_comm.Bcast(local_array[index])
+ else:
+ # CuPy with non-CUDA-aware MPI
+ local_array[index] = base_comm.bcast(value)
+
+
+def mpi_send(base_comm: MPI.Comm,
+ send_buf: NDArray,
+ dest: int,
+ count: Optional[int] = None,
+ tag: int = 0,
+ engine: str = "numpy",
+ ) -> None:
+ """MPI_Send/send
+
+ Dispatch send routine based on type of input and availability of
+ CUDA-Aware MPI
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ The array containing data to send.
+ dest: :obj:`int`
+ The rank of the destination CPU/GPU device.
+ count : :obj:`int`
+ Number of elements to send from `send_buf`.
+ tag : :obj:`int`
+ Tag of the message to be sent.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ """
+ if deps.cuda_aware_mpi_enabled or engine == "numpy":
+ # Determine MPI type based on array dtype
+ mpi_type = MPI._typedict[send_buf.dtype.char]
+ if count is None:
+ count = send_buf.size
+ base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag)
+ else:
+ # Uses CuPy without CUDA-aware MPI
+ base_comm.send(send_buf, dest, tag)
+
+
+def mpi_recv(base_comm: MPI.Comm,
+ recv_buf: Optional[NDArray] = None,
+ source: int = 0,
+ count: Optional[int] = None,
+ tag: int = 0,
+ engine: Optional[str] = "numpy",
+ ) -> NDArray:
+ """ MPI_Recv/recv
+ Dispatch receive routine based on type of input and availability of
+ CUDA-Aware MPI
+
+ Parameters
+ ----------
+ base_comm : :obj:`MPI.Comm`
+ Base MPI Communicator.
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`, optional
+ The buffered array to receive data.
+ source : :obj:`int`
+ The rank of the sending CPU/GPU device.
+ count : :obj:`int`
+ Number of elements to receive.
+ tag : :obj:`int`
+ Tag of the message to be sent.
+ engine : :obj:`str`, optional
+ Engine used to store array (``numpy`` or ``cupy``)
+
+ Returns
+ -------
+ recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+ The buffer containing the received data.
+
+ """
+ if deps.cuda_aware_mpi_enabled or engine == "numpy":
+ ncp = get_module(engine)
+ if recv_buf is None:
+ if count is None:
+ raise ValueError("Must provide either recv_buf or count for MPI receive")
+ # Defaulting to int32 works currently because add_ghost_cells() is called
+ # with recv_buf and is thus not affected by this branch. The int32 is used when
+ # dimension- or shape-related integers are sent/received
+ recv_buf = ncp.zeros(count, dtype=ncp.int32)
+ mpi_type = MPI._typedict[recv_buf.dtype.char]
+ base_comm.Recv([recv_buf, recv_buf.size, mpi_type], source=source, tag=tag)
+ else:
+ # Uses CuPy without CUDA-aware MPI
+ recv_buf = base_comm.recv(source=source, tag=tag)
+ return recv_buf
diff --git a/pylops_mpi/utils/_nccl.py b/pylops_mpi/utils/_nccl.py
index 19c09922..cac5b61c 100644
--- a/pylops_mpi/utils/_nccl.py
+++ b/pylops_mpi/utils/_nccl.py
@@ -1,6 +1,4 @@
__all__ = [
- "_prepare_nccl_allgather_inputs",
- "_unroll_nccl_allgather_recv",
"_nccl_sync",
"initialize_nccl_comm",
"nccl_split",
@@ -13,12 +11,11 @@
]
from enum import IntEnum
-from typing import Tuple
from mpi4py import MPI
import os
-import numpy as np
import cupy as cp
import cupy.cuda.nccl as nccl
+from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv
cupy_to_nccl_dtype = {
"float32": nccl.NCCL_FLOAT32,
@@ -70,85 +67,6 @@ def _nccl_sync():
cp.cuda.runtime.deviceSynchronize()
-def _prepare_nccl_allgather_inputs(send_buf, send_buf_shapes) -> Tuple[cp.ndarray, cp.ndarray]:
- r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
-
- NCCL's allGather requires the sending buffer to have the same size for every device.
- Therefore, padding is required when the array is not evenly partitioned across
- all the ranks. The padding is applied such that the each dimension of the sending buffers
- is equal to the max size of that dimension across all ranks.
-
- Similarly, each receiver buffer (recv_buf) is created with size equal to :math:n_rank \cdot send_buf.size
-
- Parameters
- ----------
- send_buf : :obj:`cupy.ndarray` or array-like
- The data buffer from the local GPU to be sent for allgather.
- send_buf_shapes: :obj:`list`
- A list of shapes for each GPU send_buf (used to calculate padding size)
-
- Returns
- -------
- send_buf: :obj:`cupy.ndarray`
- A buffer containing the data and padded elements to be sent by this rank.
- recv_buf : :obj:`cupy.ndarray`
- An empty, padded buffer to gather data from all GPUs.
- """
- sizes_each_dim = list(zip(*send_buf_shapes))
- send_shape = tuple(map(max, sizes_each_dim))
- pad_size = [
- (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, send_buf.shape)
- ]
-
- send_buf = cp.pad(
- send_buf, pad_size, mode="constant", constant_values=0
- )
-
- # NCCL recommends to use one MPI Process per GPU and so size of receiving buffer can be inferred
- ndev = len(send_buf_shapes)
- recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
-
- return send_buf, recv_buf
-
-
-def _unroll_nccl_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) -> list:
- """Unrolll recv_buf after NCCL allgather (nccl_allgather)
-
- Remove the padded elements in recv_buff, extract an individual array from each device and return them as a list of arrays
- Each GPU may send array with a different shape, so the return type has to be a list of array
- instead of the concatenated array.
-
- Parameters
- ----------
- recv_buf: :obj:`cupy.ndarray` or array-like
- The data buffer returned from nccl_allgather call
- padded_send_buf_shape: :obj:`tuple`:int
- The size of send_buf after padding used in nccl_allgather
- send_buf_shapes: :obj:`list`
- A list of original shapes for each GPU send_buf prior to padding
-
- Returns
- -------
- chunks: :obj:`list`
- A list of `cupy.ndarray` from each GPU with the padded element removed
- """
-
- ndev = len(send_buf_shapes)
- # extract an individual array from each device
- chunk_size = np.prod(padded_send_buf_shape)
- chunks = [
- recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
- ]
-
- # Remove padding from each array: the padded value may appear somewhere
- # in the middle of the flat array and thus the reshape and slicing for each dimension is required
- for i in range(ndev):
- slicing = tuple(slice(0, end) for end in send_buf_shapes[i])
- chunks[i] = chunks[i].reshape(padded_send_buf_shape)[slicing]
-
- return chunks
-
-
def mpi_op_to_nccl(mpi_op) -> NcclOp:
""" Map MPI reduction operation to NCCL equivalent
@@ -363,9 +281,9 @@ def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
Global array gathered from all GPUs and concatenated along `axis`.
"""
- send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, local_shapes)
+ send_buf, recv_buf = _prepare_allgather_inputs(local_array, local_shapes, engine="cupy")
nccl_allgather(nccl_comm, send_buf, recv_buf)
- chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, local_shapes)
+ chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, local_shapes)
# combine back to single global array
return cp.concatenate(chunks, axis=axis)
diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py
index 9d983f60..f0279ceb 100644
--- a/pylops_mpi/utils/deps.py
+++ b/pylops_mpi/utils/deps.py
@@ -39,6 +39,10 @@ def nccl_import(message: Optional[str] = None) -> str:
return nccl_message
+cuda_aware_mpi_enabled: bool = (
+ True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1)) == 1 else False
+)
+
nccl_enabled: bool = (
True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False
)
diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py
index a596ffe9..20a11783 100644
--- a/tests/test_matrixmult.py
+++ b/tests/test_matrixmult.py
@@ -219,7 +219,8 @@ def test_MPIMatrixMult_summa(N, K, M, dtype_str):
local_shapes=comm.allgather(X_p.shape[0] * X_p.shape[1]),
partition=Partition.SCATTER,
base_comm=comm,
- dtype=dtype
+ dtype=dtype,
+ engine=backend,
)
x_dist.local_array[:] = X_p.ravel()
diff --git a/tests_nccl/test_ncclutils_nccl.py b/tests_nccl/test_ncclutils_nccl.py
index 21b28ca3..eaae0a69 100644
--- a/tests_nccl/test_ncclutils_nccl.py
+++ b/tests_nccl/test_ncclutils_nccl.py
@@ -8,7 +8,8 @@
from numpy.testing import assert_allclose
import pytest
-from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
+from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allgather
+from pylops_mpi.utils._common import _prepare_allgather_inputs, _unroll_allgather_recv
np.random.seed(42)
@@ -83,9 +84,9 @@ def test_allgather_differentsize_withrecbuf(par):
# Gathered array
send_shapes = MPI.COMM_WORLD.allgather(local_array.shape)
- send_buf, recv_buf = _prepare_nccl_allgather_inputs(local_array, send_shapes)
+ send_buf, recv_buf = _prepare_allgather_inputs(local_array, send_shapes, engine="cupy")
recv_buf = nccl_allgather(nccl_comm, send_buf, recv_buf)
- chunks = _unroll_nccl_allgather_recv(recv_buf, send_buf.shape, send_shapes)
+ chunks = _unroll_allgather_recv(recv_buf, send_buf.shape, send_shapes)
gathered_array = cp.concatenate(chunks)
# Compare with global array created in rank0