Skip to content

Commit 5a1a00b

Browse files
committed
TYP: reject bool in the ord params of vector_norm and matrix_norm
1 parent b0c88f4 commit 5a1a00b

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

array_api_compat/common/_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .._internal import get_xp
1414
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
15-
from ._typing import Array, DType, Namespace
15+
from ._typing import Array, DType, JustFloat, JustInt, Namespace
1616

1717

1818
# These are in the main NumPy namespace but not in numpy.linalg
@@ -139,7 +139,7 @@ def matrix_norm(
139139
xp: Namespace,
140140
*,
141141
keepdims: bool = False,
142-
ord: float | Literal["fro", "nuc"] | None = "fro",
142+
ord: JustInt | JustFloat | Literal["fro", "nuc"] | None = "fro",
143143
) -> Array:
144144
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
145145

@@ -155,7 +155,7 @@ def vector_norm(
155155
*,
156156
axis: int | tuple[int, ...] | None = None,
157157
keepdims: bool = False,
158-
ord: float = 2,
158+
ord: JustInt | JustFloat = 2,
159159
) -> Array:
160160
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
161161
# when axis=None and the input is 2-D, so to force a vector norm, we make

array_api_compat/torch/linalg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# These functions are in both the main and linalg namespaces
1717
from ._aliases import matmul, matrix_transpose, tensordot
1818
from ._typing import Array, DType
19+
from ..common._typing import JustInt, JustFloat
1920

2021
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
2122
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -84,8 +85,8 @@ def vector_norm(
8485
*,
8586
axis: Optional[Union[int, Tuple[int, ...]]] = None,
8687
keepdims: bool = False,
87-
# float stands for inf | -inf, which are not valid for Literal
88-
ord: Union[int, float] = 2,
88+
# JustFloat stands for inf | -inf, which are not valid for Literal
89+
ord: JustInt | JustFloat = 2,
8990
**kwargs,
9091
) -> Array:
9192
# torch.vector_norm incorrectly treats axis=() the same as axis=None
@@ -115,3 +116,6 @@ def vector_norm(
115116
_all_ignore = ['torch_linalg', 'sum']
116117

117118
del linalg_all
119+
120+
def __dir__() -> list[str]:
121+
return __all__

0 commit comments

Comments (0)