Skip to content

Commit 648ed7b

Browse files
authored
jwcrypto: type most of the rest of JWT and JWKSet.generate function (#13807)
1 parent a045be8 commit 648ed7b

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

stubs/jwcrypto/jwcrypto/jwk.pyi

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Callable, Sequence
22
from enum import Enum
33
from typing import Any, Literal, NamedTuple, TypeVar, overload
4-
from typing_extensions import Self, deprecated
4+
from typing_extensions import Self, TypeAlias, deprecated
55

66
from cryptography.hazmat.primitives import hashes
77
from cryptography.hazmat.primitives.asymmetric import ec, rsa
@@ -46,7 +46,8 @@ class _X448_CURVE(NamedTuple):
4646
pubkey: UnimplementedOKPCurveKey
4747
privkey: UnimplementedOKPCurveKey
4848

49-
JWKTypesRegistry: dict[str, str]
49+
_JWKKeyTypeSupported: TypeAlias = Literal["oct", "RSA", "EC", "OKP"]
50+
JWKTypesRegistry: dict[_JWKKeyTypeSupported, str]
5051

5152
class ParmType(Enum):
5253
name = "A string with a name" # pyright: ignore[reportAssignmentType]
@@ -63,8 +64,12 @@ class JWKParameter(NamedTuple):
6364
JWKValuesRegistry: dict[str, dict[str, JWKParameter]]
6465
JWKParamsRegistry: dict[str, JWKParameter]
6566
JWKEllipticCurveRegistry: dict[str, str]
66-
JWKUseRegistry: dict[str, str]
67-
JWKOperationsRegistry: dict[str, str]
67+
_JWKUseSupported: TypeAlias = Literal["sig", "enc"]
68+
JWKUseRegistry: dict[_JWKUseSupported, str]
69+
_JWKOperationSupported: TypeAlias = Literal[
70+
"sign", "verify", "encrypt", "decrypt", "wrapKey", "unwrapKey", "deriveKey", "deriveBits"
71+
]
72+
JWKOperationsRegistry: dict[_JWKOperationSupported, str]
6873
JWKpycaCurveMap: dict[str, str]
6974
IANANamedInformationHashAlgorithmRegistry: dict[
7075
str,
@@ -98,9 +103,26 @@ class InvalidJWKValue(JWException): ...
98103

99104
class JWK(dict[str, Any]):
100105
def __init__(self, **kwargs) -> None: ...
106+
# `kty` and the other keyword arguments are passed as `params` to the called generator
107+
# function. The possible arguments depend on the value of `kty`.
108+
# TODO: Add overloads for the individual `kty` values.
109+
@classmethod
110+
@overload
111+
def generate(
112+
cls,
113+
*,
114+
kty: Literal["RSA"],
115+
public_exponent: int | None = None,
116+
size: int | None = None,
117+
kid: str | None = None,
118+
alg: str | None = None,
119+
use: _JWKUseSupported | None = None,
120+
key_ops: list[_JWKOperationSupported] | None = None,
121+
) -> Self: ...
101122
@classmethod
102-
def generate(cls, **kwargs) -> Self: ...
103-
def generate_key(self, **params) -> None: ...
123+
@overload
124+
def generate(cls, *, kty: _JWKKeyTypeSupported, **kwargs) -> Self: ...
125+
def generate_key(self, *, kty: _JWKKeyTypeSupported, **kwargs) -> None: ...
104126
def import_key(self, **kwargs) -> None: ...
105127
@classmethod
106128
def from_json(cls, key) -> Self: ...

stubs/jwcrypto/jwcrypto/jwt.pyi

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from _typeshed import Incomplete
22
from collections.abc import Mapping
3-
from typing import Any
3+
from typing import Any, SupportsInt
44
from typing_extensions import deprecated
55

66
from jwcrypto.common import JWException, JWKeyNotFound
@@ -49,31 +49,31 @@ class JWT:
4949
@header.setter
5050
def header(self, h: dict[str, Any] | str) -> None: ...
5151
@property
52-
def claims(self): ...
52+
def claims(self) -> str: ...
5353
@claims.setter
54-
def claims(self, data) -> None: ...
54+
def claims(self, data: str) -> None: ...
5555
@property
5656
def token(self): ...
5757
@token.setter
5858
def token(self, t) -> None: ...
5959
@property
60-
def leeway(self): ...
60+
def leeway(self) -> int: ...
6161
@leeway.setter
62-
def leeway(self, lwy) -> None: ...
62+
def leeway(self, lwy: SupportsInt) -> None: ...
6363
@property
64-
def validity(self): ...
64+
def validity(self) -> int: ...
6565
@validity.setter
66-
def validity(self, v) -> None: ...
66+
def validity(self, v: SupportsInt) -> None: ...
6767
@property
6868
def expected_type(self): ...
6969
@expected_type.setter
7070
def expected_type(self, v) -> None: ...
7171
def norm_typ(self, val): ...
72-
def make_signed_token(self, key) -> None: ...
73-
def make_encrypted_token(self, key) -> None: ...
74-
def validate(self, key) -> None: ...
72+
def make_signed_token(self, key: JWK) -> None: ...
73+
def make_encrypted_token(self, key: JWK) -> None: ...
74+
def validate(self, key: JWK | JWKSet) -> None: ...
7575
def deserialize(self, jwt, key: Incomplete | None = None) -> None: ...
76-
def serialize(self, compact: bool = True): ...
76+
def serialize(self, compact: bool = True) -> str: ...
7777
@classmethod
7878
def from_jose_token(cls, token): ...
7979
def __eq__(self, other: object) -> bool: ...

0 commit comments

Comments
 (0)