Skip to content

Conversation

tharittk
Copy link
Collaborator

@tharittk tharittk commented Aug 28, 2025

Objective

This PR has two main goals:

  • moving all MPI/NCCL communication calls from within DistributedArray and various linear operators to a single common place, namely the DistributedMixIn class, whose methods are used by both DistributedArray and the linear operators
  • implement support for mpi4py buffered communications to be used in spite of object communications in NumPy+MPI (better performance, and supported with any version of MPI/mpi4py) and CuPy+Cuda-aware MPI scenarios (whilst still falling back to object communications in CuPy+Non-Cuda-aware MPI scenario). Note that to allow users to force object communications when dealing with CuPy, a new environment variable PYLOPS_MPI_CUDA_AWARE is introduced (defaults to 1 but can be set to 0 to force object communications)

CUDA-Aware MPI

In order to have a CUDA-aware mpi4py installation mpi4py must be build against a CUDA-aware MPI installation. Since conda installations of mpi4py do not ship with a CUDA-aware MPI, it is therefore required to use pip for installing mpi4py. In the case for NCSA Delta, I create a new conda environment and do
module load openmpi/5.0.5+cuda
then
MPICC=/path/to/mpicc pip install --no-cache-dir --force-reinstall mpi4py
(where --force-reinstall is only needed because we install already mpi4py as part of the conda environment creation.

And to run the test (assuming you're in the compute node already):

module load openmpi/5.0.5+cuda
export PYLOPS_MPI_CUDA_AWARE=1
echo "TESTING **WITH** CUDA_AWARE"

echo "TEST NUMPY MPI"
export TEST_CUPY_PYLOPS=0
mpirun -n 2 pytest tests/ --with-mpi

echo "TEST CUPY MPI"
export TEST_CUPY_PYLOPS=1
mpirun -n 2 pytest tests/ --with-mpi

echo " TEST NCCL "
mpirun -n 2 pytest tests_nccl/ --with-mpi

To Do

  • So far the mpi_allgather method uses the _prepare_allgather_inputs method to prepare inputs such that they are all of the same size (via zero-padding). Whilst this is strictly needed for NCCL, we should instead consider leveraging MPI `AllGatherv' instead to avoid extra padding and cutting of arrays - Use in AllGatherv in mpi_allgather #169
  • Modify building process in Makefile and environment/requirement files: I suggest to have some targets for cuda-aware MPI where we put mpi4py in the pip section of the environment file and we ask users to set MPICC upfront (can be documented in the installation.rst section)
  • Modify MatrixMult to remove any direct call to mpi4py communication methods - Use new unified communication methods in MatrixMult #170

tharittk and others added 5 commits August 17, 2025 04:07
A new DistributedMix class is create with the aim of simpflify and unify
all comm. calls in both DistributedArray and operators (further hiding
away all implementation details).
@mrava87
Copy link
Contributor

mrava87 commented Sep 7, 2025

@tharittk great start!

Regarding the setup, I completely agree with the need to change the installation process for CUDA-Aware MPI. I have personally so far mostly relied on conda to install mpi as part of the installation of mpi4py, but it seems like this cannot be done to get CUDA-Aware MPI (see https://urldefense.com/v3/https://chatgpt.com/share/68bdf141-0658-800d-9c6c-e85aa4ab6d87;!!BgN1JKhRo9Eh4Q!SnZ79GzfYSo75i0MB4v9O_mBEnH1UA5IVYuisb-NWb0p9kRXKab9gydJlsLTleI51ozFLiVK8FDInCoRknrulElJpw$); so whilst the module load ... part would change (one may have the same luck that you have to get a pre-installed MPI with CUDA support or may need to install themselves), the second part should be universal, so we may want to add some Makefile targets for this setup 😄

Regarding the code, as I briefly mentioned offline, whilst I think this is the right way to go:

  • buffer comms for NumPy
  • have the PYLOPS_MPI_CUDA_AWARE env variable for CuPy to allow using object comms for non CUDA-Aware MPI + CuPy

i am starting to feel that the number of branches in code is growing and it is about time to put it all in one place... what I am mostly concerned is that this kind of branches will not only be present in DistributedArray but they will start to permeate into operators. I had a first go at it, only with the allgather method to give you and idea and discuss together if you think this is a good approach before we implement it for all the other comm method. The approach I took is two-fold:

  • create a _mpi subpackage (similar to _ncll) where all MPI methods are implemented with the various branches - what so far you had in the else branch in the _allreduce method in DistributedArray
  • create a mixin class DistributedMixIn (in Distributed file) where we can basically move all comm methods that are currently in DistributedArray. However, by doing so, also operators can inherit this class and access those methods - I used VStack as an example.

@astroC86 we have also talked a bit about this in the context of your MatrixMult operator. Pinging you so you can folllow this space, and hopefully once this PR is merged the bar for the implementation of operators that support all backends (Numpy+MPI, Cupy+MPI, Cupy+NCCL) will be lowered as one would just need to know what communication pattern they want to use and call the one from the mixin class without worrying about the subtleties of the different backends

@mrava87 mrava87 mentioned this pull request Sep 9, 2025
@mrava87 mrava87 changed the title Buffered communication for CUDA-Aware MPI Feat: restructuring of communication methods (and buffered communication for CUDA-Aware MPI) Sep 23, 2025
@mrava87
Copy link
Contributor

mrava87 commented Sep 23, 2025

@tharittk I worked a bit more on this, but there is still quite a bit to do (added to the TODO list in the PR main comment)....

Also i am not really sure why some tests are failing on some specific combinations of python/mpi/nranks but not on others... have not investigated yet...

@tharittk tharittk marked this pull request as ready for review October 9, 2025 08:14
Copy link
Contributor

@mrava87 mrava87 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tharittk I left a few comments, Distributed is partially unfinished as we need to make sure all methods get passed the same inputs and don't rely on self.* so that they can be used by both DistributedArray and operators.

Also running tests locally I pass NumPy+MPI and CuPY+MPI but for CuPy+NCCL I get a seg fault at tests_nccl/test_solver_nccl.py::test_cgls_broadcastdata_nccl[par0] Fatal Python error: Segmentation fault. Same for you?

return base_comm.allgather(send_buf)
return mpi_allgather(base_comm, send_buf, recv_buf, engine)

def _allgather_subcomm(self, send_buf, recv_buf=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still needs to be modified like _allgather to avoid using self inside..

else:
return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine)

def _bcast(self, local_array, index, value):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

mpi_bcast(self.base_comm, self.rank, self.local_array, index, value,
engine=self.engine)

def _send(self, send_buf, dest, count=None, tag=0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

send_buf, dest, count, tag=tag,
engine=self.engine)

def _recv(self, recv_buf=None, source=0, count=None, tag=0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

from pylops.utils.backend import get_module


# TODO: return type annotation for both cupy and numpy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be handled and removed.

@mrava87
Copy link
Contributor

mrava87 commented Oct 19, 2025

@rohanbabbar04 I remember we discussed long time ago about this and you were actually the first to suggest using mixins... feel free to take a look and provide any feedback 😄

@rohanbabbar04
Copy link
Collaborator

rohanbabbar04 commented Oct 20, 2025

Thanks @tharittk and @mrava87
I will take a look into this tomorrow. 🙂

@tharittk
Copy link
Collaborator Author

@tharittk I left a few comments, Distributed is partially unfinished as we need to make sure all methods get passed the same inputs and don't rely on self.* so that they can be used by both DistributedArray and operators.

Also running tests locally I pass NumPy+MPI and CuPY+MPI but for CuPy+NCCL I get a seg fault at tests_nccl/test_solver_nccl.py::test_cgls_broadcastdata_nccl[par0] Fatal Python error: Segmentation fault. Same for you?

I don't have the problem with the CuPy + NCCL - I still got 309 test passed.

This is my seqeuence of command:
$ conda activate cuda-mpi # env that was built with cuda-aware mpi
$ module load openmpi/5.0.5+cuda # NCSA module load
$ export TEST_CUPY_PYLOPS=1
$ export PYLOPS_MPI_CUDA_AWARE=1
$ mpirun -n 2 pytest tests_nccl/ --with-mpi

I switch to mpiexec and it is still doing ok
$ mpiexec -n 2 pytest tests_nccl/ --with-mpi

@mrava87
Copy link
Contributor

mrava87 commented Oct 20, 2025

mpiexec -n 2 pytest tests_nccl/ --with-mpi

MMh interesting... I installed nccl in my newer openmpi env and I also don't get that error anymore... but I get a new one due to https://github.com/tharittk/pylops-mpi/blob/a317a884efc556419eac0b5652b67207edb3eb97/tests_nccl/test_ncclutils_nccl.py#L12... surely you must get it to, as those methods have been moved to _common?

I fixed that 😄

So I can now run locally the following with success:

make tests
make tests_nccl
export PYLOPS_MPI_CUDA_AWARE=0; make tests_gpu

Seems like that the installation I thought had Cuda-aware MPI was not and things worked as I had PYLOPS_MPI_CUDA_AWARE=0 set... since you have Cuda-aware MPI can you please test the entire suite?

make tests
make tests_nccl
export PYLOPS_MPI_CUDA_AWARE=0; make tests_gpu
export PYLOPS_MPI_CUDA_AWARE=1; make tests_gpu

Apart from this (which we definitely need to try to put on a CI (it is just too many things to do locally now...), once you can handle my code comments above we should be almost ready to merge 🚀

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants