@@ -11,8 +11,11 @@ import optype.typing as opt
11
11
from ._base import _spbase , sparray
12
12
from ._bsr import bsr_array , bsr_matrix
13
13
from ._coo import coo_array , coo_matrix
14
+ from ._csc import csc_array , csc_matrix
14
15
from ._csr import csr_array , csr_matrix
15
16
from ._dia import dia_array , dia_matrix
17
+ from ._dok import dok_array , dok_matrix
18
+ from ._lil import lil_array , lil_matrix
16
19
from ._matrix import spmatrix
17
20
from ._typing import Numeric , SPFormat , ToShape2D , _CanStack , _CanStackAs
18
21
@@ -43,7 +46,7 @@ _ShapeT = TypeVar("_ShapeT", bound=tuple[int, *tuple[int, ...]], default=tuple[A
43
46
_ToArray1D : TypeAlias = Seq [_SCT ] | onp .CanArrayND [_SCT ]
44
47
_ToArray2D : TypeAlias = Seq [Seq [_SCT | int ] | onp .CanArrayND [_SCT ]] | onp .CanArrayND [_SCT ]
45
48
_ToSpMatrix : TypeAlias = spmatrix [_SCT ] | _ToArray2D [_SCT ]
46
- _ToSparse : TypeAlias = _spbase [_SCT ] | _ToArray2D [_SCT ]
49
+ _ToSparse2D : TypeAlias = _spbase [_SCT , tuple [ int , int ] ] | _ToArray2D [_SCT ]
47
50
48
51
_SpBase : TypeAlias = _spbase [_SCT , _ShapeT ] | Any
49
52
_SpMatrix : TypeAlias = spmatrix [_SCT ] | Any
@@ -54,15 +57,21 @@ _SpArray1D: TypeAlias = _SpArray[_SCT, tuple[int]]
54
57
_SpArray2D : TypeAlias = _SpArray [_SCT , tuple [int , int ]]
55
58
56
59
_BSRArray : TypeAlias = bsr_array [_SCT ]
57
- _CSRArray : TypeAlias = csr_array [_SCT , tuple [int , int ]]
60
+ _COOArray2D : TypeAlias = coo_array [_SCT , tuple [int , int ]]
61
+ _CSCArray : TypeAlias = csc_array [_SCT ]
62
+ _CSRArray2D : TypeAlias = csr_array [_SCT , tuple [int , int ]]
63
+ _DIAArray : TypeAlias = dia_array [_SCT ]
64
+ _DOKArray2D : TypeAlias = dok_array [_SCT , tuple [int , int ]]
65
+ _LILArray : TypeAlias = lil_array [_SCT ]
58
66
59
67
_FmtBSR : TypeAlias = Literal ["bsr" ]
60
68
_FmtCOO : TypeAlias = Literal ["coo" ]
69
+ _FmtCSC : TypeAlias = Literal ["csc" ]
61
70
_FmtCSR : TypeAlias = Literal ["csr" ]
62
71
_FmtDIA : TypeAlias = Literal ["dia" ]
63
- _FmtNonBSR : TypeAlias = Literal ["coo" , "csc" , "csr" , "dia" , "dok" , "lil" ]
72
+ _FmtDOK : TypeAlias = Literal ["dok" ]
73
+ _FmtLIL : TypeAlias = Literal ["lil" ]
64
74
_FmtNonCOO : TypeAlias = Literal ["bsr" , "csc" , "csr" , "dia" , "dok" , "lil" ]
65
- _FmtNonCSR : TypeAlias = Literal ["bsr" , "coo" , "csc" , "dia" , "dok" , "lil" ]
66
75
_FmtNonDIA : TypeAlias = Literal ["bsr" , "coo" , "csc" , "csr" , "dok" , "lil" ]
67
76
68
77
_DataRVS : TypeAlias = Callable [[int ], onp .ArrayND [Numeric ]]
@@ -506,32 +515,92 @@ def eye(
506
515
#
507
516
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"bsr", None} = ...
508
517
def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtBSR | None = None ) -> bsr_matrix [_SCT ]: ...
509
- @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
510
- def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtNonBSR ) -> _SpMatrix [_SCT ]: ...
511
- @overload # A: sparray, B: sparse, format: {"bsr", None} = ...
512
- def kron (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
513
- @overload # A: sparray, B: sparse, format: <otherwise>
514
- def kron (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtNonBSR ) -> _SpArray2D [_SCT ]: ...
518
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "coo"
519
+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCOO ) -> coo_matrix [_SCT ]: ...
520
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csc"
521
+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSC ) -> csc_matrix [_SCT ]: ...
522
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csr"
523
+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSR ) -> csr_matrix [_SCT ]: ...
524
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dia"
525
+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDIA ) -> dia_matrix [_SCT ]: ...
526
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dok"
527
+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDOK ) -> dok_matrix [_SCT ]: ...
528
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "lil"
529
+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtLIL ) -> lil_matrix [_SCT ]: ...
530
+ @overload # A: sparray, B: 2D sparse, format: {"bsr", None} = ...
531
+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
532
+ @overload # A: sparray, B: sparse, format: "coo"
533
+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
534
+ @overload # A: sparray, B: sparse, format: "csc"
535
+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
536
+ @overload # A: sparray, B: sparse, format: "csr"
537
+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSR ) -> _CSRArray2D [_SCT ]: ...
538
+ @overload # A: sparray, B: sparse, format: "dia"
539
+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
540
+ @overload # A: sparray, B: sparse, format: "dok"
541
+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
542
+ @overload # A: sparray, B: sparse, format: "lil"
543
+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
515
544
@overload # A: sparse, B: sparray, format: {"bsr", None} = ...
516
- def kron (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
517
- @overload # A: sparse, B: sparray, format: <otherwise>
518
- def kron (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtNonBSR ) -> _SpArray2D [_SCT ]: ...
545
+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
546
+ @overload # A: sparray, B: sparse, format: "coo"
547
+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
548
+ @overload # A: sparray, B: sparse, format: "csc"
549
+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
550
+ @overload # A: sparray, B: sparse, format: "csr"
551
+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSR ) -> _CSRArray2D [_SCT ]: ...
552
+ @overload # A: sparray, B: sparse, format: "dia"
553
+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
554
+ @overload # A: sparray, B: sparse, format: "dok"
555
+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
556
+ @overload # A: sparray, B: sparse, format: "lil"
557
+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
519
558
@overload # A: unknown array-like, B: unknown array-like (catch-all)
520
559
def kron (A : onp .ToComplex2D , B : onp .ToComplex2D , format : SPFormat | None = None ) -> _SpBase2D [Incomplete ]: ...
521
560
522
561
# NOTE: The `overload-overlap` mypy errors are false positives.
523
562
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"csr", None} = ...
524
563
def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSR | None = None ) -> csr_matrix [_SCT ]: ...
525
- @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
526
- def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtNonCSR ) -> _SpMatrix [_SCT ]: ...
564
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "bsr"
565
+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtBSR ) -> bsr_matrix [_SCT ]: ...
566
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "coo"
567
+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCOO ) -> coo_matrix [_SCT ]: ...
568
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csc"
569
+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSC ) -> csc_matrix [_SCT ]: ...
570
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dia"
571
+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDIA ) -> dia_matrix [_SCT ]: ...
572
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dok"
573
+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDOK ) -> dok_matrix [_SCT ]: ...
574
+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "lil"
575
+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtLIL ) -> lil_matrix [_SCT ]: ...
527
576
@overload # A: sparray, B: sparse, format: {"csr", None} = ...
528
- def kronsum (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtCSR | None = None ) -> _CSRArray [_SCT ]: ...
529
- @overload # A: sparray, B: sparse, format: <otherwise>
530
- def kronsum (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtNonCSR ) -> _SpArray2D [_SCT ]: ...
577
+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSR | None = None ) -> _CSRArray2D [_SCT ]: ...
578
+ @overload # A: sparray, B: sparse, format: "bsr"
579
+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtBSR ) -> _BSRArray [_SCT ]: ...
580
+ @overload # A: sparray, B: sparse, format: "coo"
581
+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
582
+ @overload # A: sparray, B: sparse, format: "csc"
583
+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
584
+ @overload # A: sparray, B: sparse, format: "dia"
585
+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
586
+ @overload # A: sparray, B: sparse, format: "dok"
587
+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
588
+ @overload # A: sparray, B: sparse, format: "lil"
589
+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
531
590
@overload # A: sparse, B: sparray, format: {"csr", None} = ...
532
- def kronsum (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtCSR | None = None ) -> _CSRArray [_SCT ]: ...
533
- @overload # A: sparse, B: sparray, format: <otherwise>
534
- def kronsum (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtNonCSR ) -> _SpArray2D [_SCT ]: ...
591
+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSR | None = None ) -> _CSRArray2D [_SCT ]: ...
592
+ @overload # A: sparse, B: sparray, format: "bsr"
593
+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtBSR ) -> _BSRArray [_SCT ]: ...
594
+ @overload # A: sparse, B: sparray, format: "coo"
595
+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
596
+ @overload # A: sparse, B: sparray, format: "csc"
597
+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
598
+ @overload # A: sparse, B: sparray, format: "dia"
599
+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
600
+ @overload # A: sparse, B: sparray, format: "dok"
601
+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
602
+ @overload # A: sparse, B: sparray, format: "lil"
603
+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
535
604
@overload # A: unknown array-like, B: unknown array-like (catch-all)
536
605
def kronsum (A : onp .ToComplex2D , B : onp .ToComplex2D , format : SPFormat | None = None ) -> _SpBase2D [Incomplete ]: ...
537
606
@@ -571,8 +640,6 @@ def vstack(blocks: Seq[_CanStackAs[Any, _T]], format: None = None, *, dtype: npt
571
640
@overload # TODO(jorenham): Support for `format=...`
572
641
def vstack (blocks : Seq [_spbase ], format : SPFormat , dtype : npt .DTypeLike | None = None ) -> Incomplete : ...
573
642
574
- _COOArray2D : TypeAlias = coo_array [_SCT , tuple [int , int ]]
575
-
576
643
# TODO(jorenham): Use `_CanStack` here, which requires a way to map matrix types to array types.
577
644
@overload # blocks: <known dtype>, format: <default>, dtype: <default>
578
645
def block_array (blocks : _ToBlocks [_SCT ], * , format : _FmtCOO | None = None , dtype : None = None ) -> _COOArray2D [_SCT ]: ...
0 commit comments