
Commit d54e827

Added SUMMA tests and fixed dtype problem

1 parent 0956e7b

File tree

3 files changed: +189, -66 lines


examples/plot_matrixmult.py

Lines changed: 3 additions & 3 deletions
@@ -28,6 +28,7 @@

 import pylops_mpi
 from pylops_mpi import Partition
+from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, MPIMatrixMult

 plt.close("all")

@@ -88,8 +89,7 @@
 # than the row or columm ranks.

 base_comm = MPI.COMM_WORLD
-comm, rank, row_id, col_id, is_active = \
-    pylops_mpi.MPIMatrixMult.active_grid_comm(base_comm, N, M)
+comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
 print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
 if not is_active: exit(0)


@@ -147,7 +147,7 @@
 ################################################################################
 # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
 # operator and the input matrix :math:`\mathbf{X}`
-Aop = pylops_mpi.MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
+Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32", kind="block")

 col_lens = comm.allgather(my_own_cols)
 total_cols = np.sum(col_lens)
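
For reference, a minimal sketch of how the updated example is driven after this change: the grid setup and inactive-rank exit are taken from the diff above, while the sizes are illustrative and the construction of the local block A_p (and the operator call, shown commented out) is left to the rest of plot_matrixmult.py.

# Hedged sketch of the example flow after this commit; N and M are illustrative
# and A_p (the local block of A) is assumed to be built as in the remainder of
# plot_matrixmult.py, which is not reproduced here.
from mpi4py import MPI

from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, MPIMatrixMult

N, M = 8, 6  # illustrative global sizes
base_comm = MPI.COMM_WORLD

# Build the active process grid; surplus ranks are flagged inactive and exit
comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
if not is_active:
    exit(0)

# A_p: local block of A owned by this rank (construction omitted here)
# Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32", kind="block")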

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 38 additions & 16 deletions
@@ -74,7 +74,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
 def local_block_spit(global_shape: Tuple[int, int],
                      rank: int,
                      comm: MPI.Comm) -> Tuple[slice, slice]:
-    """
+    r"""
     Compute the local sub‐block of a 2D global array for a process in a square process grid.

     Parameters
@@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int],


 def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm):
-    """
+    r"""
     Gather distributed local blocks from 2D block distributed matrix distributed
     amongst a square process grid into the full global array.

@@ -351,19 +351,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-
+        output_dtype = np.result_type(self.dtype, x.dtype)
         y = DistributedArray(
             global_shape=(self.N * self.dimsd[1]),
             local_shapes=[(self.N * c) for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )

         my_own_cols = self._rank_col_lens[self.rank]
         x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
-        X_local = x_arr.astype(self.dtype)
+        X_local = x_arr.astype(output_dtype)
         Y_local = ncp.vstack(
             self._row_comm.allgather(
                 ncp.matmul(self.A, X_local)
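
The change above makes the forward output dtype follow NumPy's promotion between the operator dtype and the input dtype rather than always forcing self.dtype. A quick, standalone illustration of the np.result_type behaviour being relied on (values chosen only for illustration):

# np.result_type promotion used for the forward output dtype: a float32
# operator applied to complex input now yields complex output instead of
# silently discarding the imaginary part.
import numpy as np

print(np.result_type(np.float32, np.float32))    # float32
print(np.result_type(np.float32, np.complex64))  # complex64
print(np.result_type(np.float64, np.complex64))  # complex128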
@@ -377,16 +377,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

+        # - If A is real: A^H = A^T,
+        #   so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex,
+        #   so result will be complex regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves input type complexity
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # But still need to check type promotion for precision
+            output_dtype = np.result_type(self.dtype, output_dtype)
+
         y = DistributedArray(
             global_shape=(self.K * self.dimsd[1]),
             local_shapes=[self.K * c for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )

-        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
+        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype)
         X_tile = x_arr[self._row_start:self._row_end, :]
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.matmul(A_local, X_tile)
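
The comments added in this hunk state the adjoint dtype rule; the following standalone restatement (the helper name adjoint_output_dtype is hypothetical, not part of the pylops-mpi API) shows the same decision on plain NumPy arrays:

# Standalone restatement of the adjoint dtype rule added above (helper name
# `adjoint_output_dtype` is hypothetical, not part of pylops-mpi).
import numpy as np

def adjoint_output_dtype(A, x):
    if np.iscomplexobj(A):
        # Complex A: A^H is complex, so the result is complex regardless of x
        return np.result_type(A.dtype, x.dtype)
    # Real A: A^T @ x keeps the complexity of x; still promote for precision
    out = x.dtype if np.iscomplexobj(x) else A.dtype
    return np.result_type(A.dtype, out)

A_real = np.ones((3, 2), dtype=np.float32)
A_cplx = np.ones((3, 2), dtype=np.complex64)
x_real = np.ones(3, dtype=np.float32)
x_cplx = np.ones(3, dtype=np.complex64)

print(adjoint_output_dtype(A_real, x_real))  # float32
print(adjoint_output_dtype(A_real, x_cplx))  # complex64
print(adjoint_output_dtype(A_cplx, x_real))  # complex64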
@@ -536,7 +548,6 @@ def __init__(
         self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)

         self.A = A.astype(np.dtype(dtype))
-        if saveAt: self.At = A.T.conj()

         self.N = self._col_comm.allreduce(A.shape[0])
         self.K = self._row_comm.allreduce(A.shape[1])
@@ -569,6 +580,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")

+        output_dtype = np.result_type(self.dtype, x.dtype)
         # Calculate local shapes for block distribution
         bn = self._N_padded // self._P_prime  # block size in N dimension
         bm = self._M_padded // self._P_prime  # block size in M dimension
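
For orientation, a small worked example of the block sizes computed here, under the assumption (not shown in this hunk) that the padded dimensions round each global size up to the next multiple of the process-grid side P_prime:

# Illustrative block-size arithmetic; the ceil-to-multiple padding is an
# assumption about how _N_padded/_M_padded are computed elsewhere in the file.
import math

P_prime = 3    # side of the square process grid
N, M = 10, 7   # illustrative global sizes

N_padded = math.ceil(N / P_prime) * P_prime  # 12
M_padded = math.ceil(M / P_prime) * P_prime  # 9

bn = N_padded // P_prime  # 4: block size in N dimension
bm = M_padded // P_prime  # 3: block size in M dimension
print(bn, bm)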
@@ -582,9 +594,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
             mask=x.mask,
             local_shapes=local_shapes,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
-            base_comm=self.base_comm
-        )
+            dtype=output_dtype,
+            base_comm=self.base_comm)

         # Calculate expected padded dimensions for x
         bk = self._K_padded // self._P_prime  # block size in K dimension
@@ -603,13 +614,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if pad_k > 0 or pad_m > 0:
             x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')

-        Y_local = np.zeros((self.A.shape[0], bm))
+        Y_local = np.zeros((self.A.shape[0], bm), dtype=output_dtype)

         for k in range(self._P_prime):
             Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A)
             Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
-            self._row_comm.bcast(Atemp, root=k)
-            self._col_comm.bcast(Xtemp, root=k)
+            self._row_comm.Bcast(Atemp, root=k)
+            self._col_comm.Bcast(Xtemp, root=k)
             Y_local += ncp.dot(Atemp, Xtemp)

         Y_local_unpadded = Y_local[:local_n, :local_m]
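
The switch from bcast to Bcast matters because mpi4py's lowercase bcast pickles a Python object and returns it (so the result must be reassigned), whereas uppercase Bcast fills a pre-allocated buffer in place, which is what the Atemp/Xtemp loop above relies on. A minimal standalone illustration:

# mpi4py: Bcast (capitalized) is the buffer-based, in-place broadcast for
# NumPy arrays; bcast (lowercase) pickles an object and returns it.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

buf = np.arange(4, dtype=np.float64) if rank == 0 else np.empty(4, dtype=np.float64)
comm.Bcast(buf, root=0)  # buf now holds root's data on every rank

obj = comm.bcast("hello" if rank == 0 else None, root=0)  # returned value must be kept
print(rank, buf, obj)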
@@ -631,13 +642,24 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm

         local_shapes = self.base_comm.allgather(local_k * local_m)
+        # - If A is real: A^H = A^T,
+        #   so result_type(real_A, x.dtype) = x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex,
+        #   so result will be complex regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves input type complexity
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # But still need to check type promotion for precision
+            output_dtype = np.result_type(self.dtype, output_dtype)

         y = DistributedArray(
             global_shape=(self.K * self.M),
             mask=x.mask,
             local_shapes=local_shapes,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )

@@ -659,7 +681,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
             x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')

         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
-        Y_local = np.zeros((self.A.shape[1], bm))
+        Y_local = np.zeros((self.A.shape[1], bm), dtype=output_dtype)

         for k in range(self._P_prime):
             requests = []
