Skip to content

Commit c4cf9ce

Browse files
Added stubs for all formats in kron and kronsum
1 parent 7d4b902 commit c4cf9ce

File tree

2 files changed

+112
-23
lines changed

2 files changed

+112
-23
lines changed

scipy-stubs/sparse/_construct.pyi

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ import optype.typing as opt
1111
from ._base import _spbase, sparray
1212
from ._bsr import bsr_array, bsr_matrix
1313
from ._coo import coo_array, coo_matrix
14+
from ._csc import csc_array, csc_matrix
1415
from ._csr import csr_array, csr_matrix
1516
from ._dia import dia_array, dia_matrix
17+
from ._dok import dok_array, dok_matrix
18+
from ._lil import lil_array, lil_matrix
1619
from ._matrix import spmatrix
1720
from ._typing import Numeric, SPFormat, ToShape2D, _CanStack, _CanStackAs
1821

@@ -43,7 +46,7 @@ _ShapeT = TypeVar("_ShapeT", bound=tuple[int, *tuple[int, ...]], default=tuple[A
4346
_ToArray1D: TypeAlias = Seq[_SCT] | onp.CanArrayND[_SCT]
4447
_ToArray2D: TypeAlias = Seq[Seq[_SCT | int] | onp.CanArrayND[_SCT]] | onp.CanArrayND[_SCT]
4548
_ToSpMatrix: TypeAlias = spmatrix[_SCT] | _ToArray2D[_SCT]
46-
_ToSparse: TypeAlias = _spbase[_SCT] | _ToArray2D[_SCT]
49+
_ToSparse2D: TypeAlias = _spbase[_SCT, tuple[int, int]] | _ToArray2D[_SCT]
4750

4851
_SpBase: TypeAlias = _spbase[_SCT, _ShapeT] | Any
4952
_SpMatrix: TypeAlias = spmatrix[_SCT] | Any
@@ -54,15 +57,21 @@ _SpArray1D: TypeAlias = _SpArray[_SCT, tuple[int]]
5457
_SpArray2D: TypeAlias = _SpArray[_SCT, tuple[int, int]]
5558

5659
_BSRArray: TypeAlias = bsr_array[_SCT]
57-
_CSRArray: TypeAlias = csr_array[_SCT, tuple[int, int]]
60+
_COOArray2D: TypeAlias = coo_array[_SCT, tuple[int, int]]
61+
_CSCArray: TypeAlias = csc_array[_SCT]
62+
_CSRArray2D: TypeAlias = csr_array[_SCT, tuple[int, int]]
63+
_DIAArray: TypeAlias = dia_array[_SCT]
64+
_DOKArray2D: TypeAlias = dok_array[_SCT, tuple[int, int]]
65+
_LILArray: TypeAlias = lil_array[_SCT]
5866

5967
_FmtBSR: TypeAlias = Literal["bsr"]
6068
_FmtCOO: TypeAlias = Literal["coo"]
69+
_FmtCSC: TypeAlias = Literal["csc"]
6170
_FmtCSR: TypeAlias = Literal["csr"]
6271
_FmtDIA: TypeAlias = Literal["dia"]
63-
_FmtNonBSR: TypeAlias = Literal["coo", "csc", "csr", "dia", "dok", "lil"]
72+
_FmtDOK: TypeAlias = Literal["dok"]
73+
_FmtLIL: TypeAlias = Literal["lil"]
6474
_FmtNonCOO: TypeAlias = Literal["bsr", "csc", "csr", "dia", "dok", "lil"]
65-
_FmtNonCSR: TypeAlias = Literal["bsr", "coo", "csc", "dia", "dok", "lil"]
6675
_FmtNonDIA: TypeAlias = Literal["bsr", "coo", "csc", "csr", "dok", "lil"]
6776

6877
_DataRVS: TypeAlias = Callable[[int], onp.ArrayND[Numeric]]
@@ -506,32 +515,92 @@ def eye(
506515
#
507516
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"bsr", None} = ...
508517
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtBSR | None = None) -> bsr_matrix[_SCT]: ...
509-
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
510-
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtNonBSR) -> _SpMatrix[_SCT]: ...
511-
@overload # A: sparray, B: sparse, format: {"bsr", None} = ...
512-
def kron(A: sparray[_SCT], B: _ToSparse[_SCT], format: _FmtBSR | None = None) -> _BSRArray[_SCT]: ...
513-
@overload # A: sparray, B: sparse, format: <otherwise>
514-
def kron(A: sparray[_SCT], B: _ToSparse[_SCT], format: _FmtNonBSR) -> _SpArray2D[_SCT]: ...
518+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "coo"
519+
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtCOO) -> coo_matrix[_SCT]: ...
520+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csc"
521+
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtCSC) -> csc_matrix[_SCT]: ...
522+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csr"
523+
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtCSR) -> csr_matrix[_SCT]: ...
524+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dia"
525+
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtDIA) -> dia_matrix[_SCT]: ...
526+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dok"
527+
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtDOK) -> dok_matrix[_SCT]: ...
528+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "lil"
529+
def kron(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtLIL) -> lil_matrix[_SCT]: ...
530+
@overload # A: sparray, B: 2D sparse, format: {"bsr", None} = ...
531+
def kron(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtBSR | None = None) -> _BSRArray[_SCT]: ...
532+
@overload # A: sparray, B: sparse, format: "coo"
533+
def kron(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtCOO) -> _COOArray2D[_SCT]: ...
534+
@overload # A: sparray, B: sparse, format: "csc"
535+
def kron(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtCSC) -> _CSCArray[_SCT]: ...
536+
@overload # A: sparray, B: sparse, format: "csr"
537+
def kron(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtCSR) -> _CSRArray2D[_SCT]: ...
538+
@overload # A: sparray, B: sparse, format: "dia"
539+
def kron(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtDIA) -> _DIAArray[_SCT]: ...
540+
@overload # A: sparray, B: sparse, format: "dok"
541+
def kron(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtDOK) -> _DOKArray2D[_SCT]: ...
542+
@overload # A: sparray, B: sparse, format: "lil"
543+
def kron(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtLIL) -> _LILArray[_SCT]: ...
515544
@overload # A: sparse, B: sparray, format: {"bsr", None} = ...
516-
def kron(A: _ToSparse[_SCT], B: sparray[_SCT], format: _FmtBSR | None = None) -> _BSRArray[_SCT]: ...
517-
@overload # A: sparse, B: sparray, format: <otherwise>
518-
def kron(A: _ToSparse[_SCT], B: sparray[_SCT], format: _FmtNonBSR) -> _SpArray2D[_SCT]: ...
545+
def kron(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtBSR | None = None) -> _BSRArray[_SCT]: ...
546+
@overload # A: sparray, B: sparse, format: "coo"
547+
def kron(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtCOO) -> _COOArray2D[_SCT]: ...
548+
@overload # A: sparray, B: sparse, format: "csc"
549+
def kron(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtCSC) -> _CSCArray[_SCT]: ...
550+
@overload # A: sparray, B: sparse, format: "csr"
551+
def kron(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtCSR) -> _CSRArray2D[_SCT]: ...
552+
@overload # A: sparray, B: sparse, format: "dia"
553+
def kron(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtDIA) -> _DIAArray[_SCT]: ...
554+
@overload # A: sparray, B: sparse, format: "dok"
555+
def kron(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtDOK) -> _DOKArray2D[_SCT]: ...
556+
@overload # A: sparray, B: sparse, format: "lil"
557+
def kron(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtLIL) -> _LILArray[_SCT]: ...
519558
@overload # A: unknown array-like, B: unknown array-like (catch-all)
520559
def kron(A: onp.ToComplex2D, B: onp.ToComplex2D, format: SPFormat | None = None) -> _SpBase2D[Incomplete]: ...
521560

522561
# NOTE: The `overload-overlap` mypy errors are false positives.
523562
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"csr", None} = ...
524563
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtCSR | None = None) -> csr_matrix[_SCT]: ...
525-
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
526-
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtNonCSR) -> _SpMatrix[_SCT]: ...
564+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "bsr"
565+
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtBSR) -> bsr_matrix[_SCT]: ...
566+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "coo"
567+
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtCOO) -> coo_matrix[_SCT]: ...
568+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csc"
569+
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtCSC) -> csc_matrix[_SCT]: ...
570+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dia"
571+
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtDIA) -> dia_matrix[_SCT]: ...
572+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dok"
573+
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtDOK) -> dok_matrix[_SCT]: ...
574+
@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "lil"
575+
def kronsum(A: _ToSpMatrix[_SCT], B: _ToSpMatrix[_SCT], format: _FmtLIL) -> lil_matrix[_SCT]: ...
527576
@overload # A: sparray, B: sparse, format: {"csr", None} = ...
528-
def kronsum(A: sparray[_SCT], B: _ToSparse[_SCT], format: _FmtCSR | None = None) -> _CSRArray[_SCT]: ...
529-
@overload # A: sparray, B: sparse, format: <otherwise>
530-
def kronsum(A: sparray[_SCT], B: _ToSparse[_SCT], format: _FmtNonCSR) -> _SpArray2D[_SCT]: ...
577+
def kronsum(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtCSR | None = None) -> _CSRArray2D[_SCT]: ...
578+
@overload # A: sparray, B: sparse, format: "bsr"
579+
def kronsum(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtBSR) -> _BSRArray[_SCT]: ...
580+
@overload # A: sparray, B: sparse, format: "coo"
581+
def kronsum(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtCOO) -> _COOArray2D[_SCT]: ...
582+
@overload # A: sparray, B: sparse, format: "csc"
583+
def kronsum(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtCSC) -> _CSCArray[_SCT]: ...
584+
@overload # A: sparray, B: sparse, format: "dia"
585+
def kronsum(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtDIA) -> _DIAArray[_SCT]: ...
586+
@overload # A: sparray, B: sparse, format: "dok"
587+
def kronsum(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtDOK) -> _DOKArray2D[_SCT]: ...
588+
@overload # A: sparray, B: sparse, format: "lil"
589+
def kronsum(A: sparray[_SCT, tuple[int, int]], B: _ToSparse2D[_SCT], format: _FmtLIL) -> _LILArray[_SCT]: ...
531590
@overload # A: sparse, B: sparray, format: {"csr", None} = ...
532-
def kronsum(A: _ToSparse[_SCT], B: sparray[_SCT], format: _FmtCSR | None = None) -> _CSRArray[_SCT]: ...
533-
@overload # A: sparse, B: sparray, format: <otherwise>
534-
def kronsum(A: _ToSparse[_SCT], B: sparray[_SCT], format: _FmtNonCSR) -> _SpArray2D[_SCT]: ...
591+
def kronsum(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtCSR | None = None) -> _CSRArray2D[_SCT]: ...
592+
@overload # A: sparse, B: sparray, format: "bsr"
593+
def kronsum(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtBSR) -> _BSRArray[_SCT]: ...
594+
@overload # A: sparse, B: sparray, format: "coo"
595+
def kronsum(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtCOO) -> _COOArray2D[_SCT]: ...
596+
@overload # A: sparse, B: sparray, format: "csc"
597+
def kronsum(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtCSC) -> _CSCArray[_SCT]: ...
598+
@overload # A: sparse, B: sparray, format: "dia"
599+
def kronsum(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtDIA) -> _DIAArray[_SCT]: ...
600+
@overload # A: sparse, B: sparray, format: "dok"
601+
def kronsum(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtDOK) -> _DOKArray2D[_SCT]: ...
602+
@overload # A: sparse, B: sparray, format: "lil"
603+
def kronsum(A: _ToSparse2D[_SCT], B: sparray[_SCT, tuple[int, int]], format: _FmtLIL) -> _LILArray[_SCT]: ...
535604
@overload # A: unknown array-like, B: unknown array-like (catch-all)
536605
def kronsum(A: onp.ToComplex2D, B: onp.ToComplex2D, format: SPFormat | None = None) -> _SpBase2D[Incomplete]: ...
537606

@@ -571,8 +640,6 @@ def vstack(blocks: Seq[_CanStackAs[Any, _T]], format: None = None, *, dtype: npt
571640
@overload # TODO(jorenham): Support for `format=...`
572641
def vstack(blocks: Seq[_spbase], format: SPFormat, dtype: npt.DTypeLike | None = None) -> Incomplete: ...
573642

574-
_COOArray2D: TypeAlias = coo_array[_SCT, tuple[int, int]]
575-
576643
# TODO(jorenham): Use `_CanStack` here, which requires a way to map matrix types to array types.
577644
@overload # blocks: <known dtype>, format: <default>, dtype: <default>
578645
def block_array(blocks: _ToBlocks[_SCT], *, format: _FmtCOO | None = None, dtype: None = None) -> _COOArray2D[_SCT]: ...

tests/sparse/test_construct.pyi

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,35 @@ assert_type(sparse.kron(any_mat, any_mat), sparse.bsr_matrix[ScalarType])
6565
assert_type(sparse.kron(any_mat, any_arr), sparse.bsr_array[ScalarType])
6666
assert_type(sparse.kron(any_arr, any_mat), sparse.bsr_array[ScalarType])
6767
assert_type(sparse.kron(any_arr, any_arr), sparse.bsr_array[ScalarType])
68+
assert_type(sparse.kron(dense_2d, any_arr), sparse.bsr_array[ScalarType])
69+
assert_type(sparse.kron(any_arr, dense_2d), sparse.bsr_array[ScalarType])
70+
assert_type(sparse.kron(any_arr, any_arr, format="bsr"), sparse.bsr_array[ScalarType])
71+
assert_type(sparse.kron(any_arr, any_arr, format="coo"), sparse.coo_array[ScalarType, tuple[int, int]])
72+
assert_type(sparse.kron(any_arr, any_arr, format="csc"), sparse.csc_array[ScalarType])
73+
assert_type(sparse.kron(any_arr, any_arr, format="csr"), sparse.csr_array[ScalarType, tuple[int, int]])
74+
assert_type(sparse.kron(any_arr, any_arr, format="dia"), sparse.dia_array[ScalarType])
75+
assert_type(sparse.kron(any_arr, any_arr, format="dok"), sparse.dok_array[ScalarType, tuple[int, int]])
76+
assert_type(sparse.kron(any_arr, any_arr, format="lil"), sparse.lil_array[ScalarType])
77+
assert_type(sparse.kron(any_arr, dense_2d, format="lil"), sparse.lil_array[ScalarType])
78+
assert_type(sparse.kron(dense_2d, any_arr, format="lil"), sparse.lil_array[ScalarType])
6879
# kronsum
6980
assert_type(sparse.kronsum(any_mat, any_mat), sparse.csr_matrix[ScalarType])
7081
assert_type(sparse.kronsum(any_mat, any_arr), sparse.csr_array[ScalarType])
7182
assert_type(sparse.kronsum(any_arr, any_mat), sparse.csr_array[ScalarType])
7283
assert_type(sparse.kronsum(any_arr, any_arr), sparse.csr_array[ScalarType])
7384
assert_type(sparse.kronsum(any_mat, [[1, 2], [3, 4]]), sparse.csr_matrix[ScalarType])
7485
assert_type(sparse.kronsum(any_arr, [[1, 2], [3, 4]]), sparse.csr_array[ScalarType])
86+
assert_type(sparse.kronsum(dense_2d, any_arr), sparse.csr_array[ScalarType])
87+
assert_type(sparse.kronsum(any_arr, dense_2d), sparse.csr_array[ScalarType])
88+
assert_type(sparse.kronsum(any_arr, any_arr, format="bsr"), sparse.bsr_array[ScalarType])
89+
assert_type(sparse.kronsum(any_arr, any_arr, format="coo"), sparse.coo_array[ScalarType, tuple[int, int]])
90+
assert_type(sparse.kronsum(any_arr, any_arr, format="csc"), sparse.csc_array[ScalarType])
91+
assert_type(sparse.kronsum(any_arr, any_arr, format="csr"), sparse.csr_array[ScalarType, tuple[int, int]])
92+
assert_type(sparse.kronsum(any_arr, any_arr, format="dia"), sparse.dia_array[ScalarType])
93+
assert_type(sparse.kronsum(any_arr, any_arr, format="dok"), sparse.dok_array[ScalarType, tuple[int, int]])
94+
assert_type(sparse.kronsum(any_arr, any_arr, format="lil"), sparse.lil_array[ScalarType])
95+
assert_type(sparse.kronsum(any_arr, dense_2d, format="lil"), sparse.lil_array[ScalarType])
96+
assert_type(sparse.kronsum(dense_2d, any_arr, format="lil"), sparse.lil_array[ScalarType])
7597

7698
###
7799
# hstack (same as vstack)

0 commit comments

Comments
 (0)