@@ -74,7 +74,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
 def local_block_spit(global_shape: Tuple[int, int],
                      rank: int,
                      comm: MPI.Comm) -> Tuple[slice, slice]:
-    """
+    r"""
     Compute the local sub-block of a 2D global array for a process in a square process grid.
 
     Parameters
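
For readers unfamiliar with the layout, the helper's job is to map a rank on a `sqrt(P) x sqrt(P)` grid to a pair of slices into the global array. A minimal, MPI-free sketch of that mapping (the row-major rank placement and ceil-divided block sizes are assumptions for illustration, not necessarily the library's exact scheme):

```python
import math
from typing import Tuple

def block_slices(global_shape: Tuple[int, int], rank: int, nprocs: int) -> Tuple[slice, slice]:
    # Assumes a square process grid with ranks placed row-major.
    p = math.isqrt(nprocs)
    assert p * p == nprocs, "number of processes must be a perfect square"
    row_id, col_id = divmod(rank, p)
    br = -(-global_shape[0] // p)  # ceil-divide: block height
    bc = -(-global_shape[1] // p)  # ceil-divide: block width
    rs = slice(row_id * br, min((row_id + 1) * br, global_shape[0]))
    cs = slice(col_id * bc, min((col_id + 1) * bc, global_shape[1]))
    return rs, cs

# Rank 3 of 4 on a 5x6 matrix owns the bottom-right 2x3 block
print(block_slices((5, 6), 3, 4))  # (slice(3, 5, None), slice(3, 6, None))
```
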
@@ -122,7 +122,7 @@ def local_block_spit(global_shape: Tuple[int, int],
 
 
 def block_gather(x: DistributedArray, new_shape: Tuple[int, int], orig_shape: Tuple[int, int], comm: MPI.Comm):
-    """
+    r"""
     Gather the local blocks of a 2D block-distributed matrix, spread across a
     square process grid, into the full global array.
 
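
`block_gather` is the inverse operation: each rank's block is stitched back into the global array according to its grid coordinates. A serial round-trip sketch of the reassembly logic (names and the row-major grid layout are assumed, as above):

```python
import math
import numpy as np

def stitch_blocks(blocks, global_shape):
    """Reassemble local blocks (listed in rank order on a row-major square grid)."""
    p = math.isqrt(len(blocks))
    br = -(-global_shape[0] // p)  # block height (ceil-divided)
    bc = -(-global_shape[1] // p)  # block width (ceil-divided)
    out = np.zeros(global_shape, dtype=np.result_type(*[b.dtype for b in blocks]))
    for rank, blk in enumerate(blocks):
        r, c = divmod(rank, p)
        out[r * br:r * br + blk.shape[0], c * bc:c * bc + blk.shape[1]] = blk
    return out

# Round-trip check on one process: split a 5x6 array into 4 blocks, stitch it back
a = np.arange(30.0).reshape(5, 6)
blocks = [a[r * 3:(r + 1) * 3, c * 3:(c + 1) * 3] for r in range(2) for c in range(2)]
assert np.array_equal(stitch_blocks(blocks, (5, 6)), a)
```
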
@@ -351,19 +351,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
-
+        output_dtype = np.result_type(self.dtype, x.dtype)
         y = DistributedArray(
             global_shape=(self.N * self.dimsd[1]),
             local_shapes=[(self.N * c) for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
         my_own_cols = self._rank_col_lens[self.rank]
         x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
-        X_local = x_arr.astype(self.dtype)
+        X_local = x_arr.astype(output_dtype)
         Y_local = ncp.vstack(
             self._row_comm.allgather(
                 ncp.matmul(self.A, X_local)
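
The `output_dtype = np.result_type(self.dtype, x.dtype)` line is what prevents complex inputs from being silently downcast: NumPy's promotion rules pick the smallest dtype that can hold both operands. A quick standalone illustration:

```python
import numpy as np

# result_type picks the smallest dtype that safely holds both operands
print(np.result_type(np.float32, np.complex64))   # complex64
print(np.result_type(np.float64, np.complex64))   # complex128
print(np.result_type(np.float32, np.float64))     # float64

# The old behaviour, x.astype(self.dtype) with a real operator dtype,
# would have dropped the imaginary part (NumPy emits a ComplexWarning)
x = np.array([1 + 2j], dtype=np.complex64)
print(x.astype(np.float32))  # [1.]
```
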
@@ -377,16 +377,28 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
 
+        # - If A is real: A^H = A^T, so result_type(real_A, x.dtype) is
+        #   x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex, so the result is complex
+        #   regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves the complexity of the input
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # but may still be promoted to a higher precision
+            output_dtype = np.result_type(self.dtype, output_dtype)
+
         y = DistributedArray(
             global_shape=(self.K * self.dimsd[1]),
             local_shapes=[self.K * c for c in self._rank_col_lens],
             mask=x.mask,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
-        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
+        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(output_dtype)
         X_tile = x_arr[self._row_start:self._row_end, :]
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.matmul(A_local, X_tile)
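
The branching above can be sanity-checked in plain NumPy: the adjoint of a real operator is just its transpose, so applying it to complex data must stay complex, and the promoted dtype matches `result_type`. A small check (the array shapes here are arbitrary):

```python
import numpy as np

A = np.random.default_rng(0).standard_normal((4, 3)).astype(np.float32)  # real operator
x = (np.arange(4) + 1j * np.arange(4)).astype(np.complex64)              # complex data

y = A.T.conj() @ x                       # for real A, A^H is just A^T
print(y.dtype)                           # complex64: the result stays complex
print(np.result_type(A.dtype, x.dtype))  # complex64, matching the branch above
```
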
@@ -536,7 +548,6 @@ def __init__(
         self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
 
         self.A = A.astype(np.dtype(dtype))
-        if saveAt: self.At = A.T.conj()
 
         self.N = self._col_comm.allreduce(A.shape[0])
         self.K = self._row_comm.allreduce(A.shape[1])
@@ -569,6 +580,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
 
+        output_dtype = np.result_type(self.dtype, x.dtype)
         # Calculate local shapes for block distribution
         bn = self._N_padded // self._P_prime  # block size in N dimension
         bm = self._M_padded // self._P_prime  # block size in M dimension
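
The `_N_padded`-style attributes are not shown in this diff; the sketch below assumes they mean "the global size rounded up to a multiple of the grid dimension", so that every rank can allocate identical buffers and only the last row/column of the grid trims back to the true extent:

```python
import math

def padded_block(n: int, p_prime: int, idx: int):
    """Uniform buffer size after padding n up to a multiple of p_prime,
    and the unpadded extent actually owned by block index idx."""
    n_padded = math.ceil(n / p_prime) * p_prime
    bn = n_padded // p_prime                                      # every rank's buffer size
    local = bn if idx != p_prime - 1 else n - (p_prime - 1) * bn  # last block may be short
    return bn, local

print(padded_block(10, 3, 0))  # (4, 4)
print(padded_block(10, 3, 2))  # (4, 2): the last block holds only the remainder
```
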
@@ -582,9 +594,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
             mask=x.mask,
             local_shapes=local_shapes,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
-            base_comm=self.base_comm
-        )
+            dtype=output_dtype,
+            base_comm=self.base_comm)
 
         # Calculate expected padded dimensions for x
         bk = self._K_padded // self._P_prime  # block size in K dimension
@@ -603,13 +614,13 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if pad_k > 0 or pad_m > 0:
             x_block = np.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')
 
-        Y_local = np.zeros((self.A.shape[0], bm))
+        Y_local = np.zeros((self.A.shape[0], bm), dtype=output_dtype)
 
         for k in range(self._P_prime):
             Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A)
             Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
-            self._row_comm.bcast(Atemp, root=k)
-            self._col_comm.bcast(Xtemp, root=k)
+            self._row_comm.Bcast(Atemp, root=k)
+            self._col_comm.Bcast(Xtemp, root=k)
             Y_local += ncp.dot(Atemp, Xtemp)
 
         Y_local_unpadded = Y_local[:local_n, :local_m]
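
The `bcast` to `Bcast` switch is more than a spelling fix. mpi4py's lowercase `bcast` pickles a Python object and *returns* the broadcast copy, so `self._row_comm.bcast(Atemp, root=k)` left the preallocated `np.empty_like` buffers untouched on non-root ranks; uppercase `Bcast` fills a contiguous buffer in place (and avoids pickling). A minimal demonstration of the two flavours, runnable with `mpirun -n 2 python demo.py`:

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
buf = np.arange(4.0) if comm.rank == 0 else np.empty(4)

# Buffer-based: fills `buf` in place on every rank, no pickling
comm.Bcast(buf, root=0)

# Pickle-based: the broadcast object must be captured from the return value
obj = comm.bcast("hello" if comm.rank == 0 else None, root=0)

print(comm.rank, buf, obj)  # every rank now sees rank 0's data
```
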
@@ -631,13 +642,24 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         local_m = bm if self._col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
 
         local_shapes = self.base_comm.allgather(local_k * local_m)
+        # - If A is real: A^H = A^T, so result_type(real_A, x.dtype) is
+        #   x.dtype (if x is complex) or real (if x is real)
+        # - If A is complex: A^H is complex, so the result is complex
+        #   regardless of x
+        if np.iscomplexobj(self.A):
+            output_dtype = np.result_type(self.dtype, x.dtype)
+        else:
+            # Real matrix: A^T @ x preserves the complexity of the input
+            output_dtype = x.dtype if np.iscomplexobj(x.local_array) else self.dtype
+            # but may still be promoted to a higher precision
+            output_dtype = np.result_type(self.dtype, output_dtype)
 
         y = DistributedArray(
             global_shape=(self.K * self.M),
             mask=x.mask,
             local_shapes=local_shapes,
             partition=Partition.SCATTER,
-            dtype=self.dtype,
+            dtype=output_dtype,
             base_comm=self.base_comm
         )
 
@@ -659,7 +681,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
             x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
 
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
-        Y_local = np.zeros((self.A.shape[1], bm))
+        Y_local = np.zeros((self.A.shape[1], bm), dtype=output_dtype)
 
         for k in range(self._P_prime):
             requests = []
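
`A_local = self.At if hasattr(self, "At") else self.A.T.conj()` is a memory-for-time trade: when the operator caches the conjugate transpose up front (the `saveAt` path seen earlier), it roughly doubles the operator's storage but skips recomputing `A.T.conj()` on every adjoint application; otherwise the conjugate is rebuilt per call. For real `A` the cost is negligible anyway, since `.conj()` on a real array returns the array itself. A sketch of the equivalence:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((512, 256)) + 1j * rng.standard_normal((512, 256))

At_cached = A.T.conj()   # pay the conjugate copy once: ~2x operator memory
x = np.ones(512, dtype=complex)

# Cached and on-the-fly adjoints agree; the cached path skips a conj per call
assert np.allclose(At_cached @ x, A.T.conj() @ x)

# For real arrays, conj() is free: it returns the array itself
B = np.ones((2, 2))
assert B.conj() is B
```
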