7
7
import sys
8
8
import warnings
9
9
from types import TracebackType
10
- from typing import Any , ClassVar , Dict , List , Literal , Optional , Type , TypedDict , TypeVar , Union , overload
10
+ from typing import Any , ClassVar , Literal , Optional , TypedDict , TypeVar , Union , overload
11
11
12
12
import httpx
13
+ import jwt
13
14
import pydantic
14
15
from oauthlib .oauth2 import WebApplicationClient
15
16
from starlette .exceptions import HTTPException
33
34
P = ParamSpec ("P" )
34
35
35
36
37
+ def _decode_id_token (id_token : str , verify : bool = False ) -> dict :
38
+ return jwt .decode (id_token , options = {"verify_signature" : verify })
39
+
40
+
36
41
class DiscoveryDocument (TypedDict ):
37
42
"""Discovery document."""
38
43
@@ -95,10 +100,11 @@ class SSOBase:
95
100
client_id : str = NotImplemented
96
101
client_secret : str = NotImplemented
97
102
redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = NotImplemented
98
- scope : ClassVar [List [str ]] = []
99
- additional_headers : ClassVar [Optional [Dict [str , Any ]]] = None
103
+ scope : ClassVar [list [str ]] = []
104
+ additional_headers : ClassVar [Optional [dict [str , Any ]]] = None
100
105
uses_pkce : bool = False
101
106
requires_state : bool = False
107
+ use_id_token_for_user_info : ClassVar [bool ] = False
102
108
103
109
_pkce_challenge_length : int = 96
104
110
@@ -109,7 +115,7 @@ def __init__(
109
115
redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = None ,
110
116
allow_insecure_http : bool = False ,
111
117
use_state : bool = False ,
112
- scope : Optional [List [str ]] = None ,
118
+ scope : Optional [list [str ]] = None ,
113
119
):
114
120
"""Base class (mixin) for all SSO providers."""
115
121
self .client_id : str = client_id
@@ -224,6 +230,18 @@ async def openid_from_response(self, response: dict, session: Optional[httpx.Asy
224
230
"""
225
231
raise NotImplementedError (f"Provider { self .provider } not supported" )
226
232
233
+ async def openid_from_token (self , id_token : dict , session : Optional [httpx .AsyncClient ] = None ) -> OpenID :
234
+ """Converts an ID token from the provider's token endpoint to an OpenID object.
235
+
236
+ Args:
237
+ id_token (dict): The id token data retrieved from the token endpoint.
238
+ session: (Optional[httpx.AsyncClient]): The HTTPX AsyncClient session.
239
+
240
+ Returns:
241
+ OpenID: The user information in a standardized format.
242
+ """
243
+ raise NotImplementedError (f"Provider { self .provider } not supported" )
244
+
227
245
async def get_discovery_document (self ) -> DiscoveryDocument :
228
246
"""Retrieves the discovery document containing useful URLs.
229
247
@@ -257,14 +275,14 @@ async def get_login_url(
257
275
self ,
258
276
* ,
259
277
redirect_uri : Optional [Union [pydantic .AnyHttpUrl , str ]] = None ,
260
- params : Optional [Dict [str , Any ]] = None ,
278
+ params : Optional [dict [str , Any ]] = None ,
261
279
state : Optional [str ] = None ,
262
280
) -> str :
263
281
"""Generates and returns the prepared login URL.
264
282
265
283
Args:
266
284
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
267
- params (Optional[Dict [str, Any]]): Additional query parameters to add to the login request.
285
+ params (Optional[dict [str, Any]]): Additional query parameters to add to the login request.
268
286
state (Optional[str]): The state parameter for the OAuth 2.0 authorization request.
269
287
270
288
Raises:
@@ -304,14 +322,14 @@ async def get_login_redirect(
304
322
self ,
305
323
* ,
306
324
redirect_uri : Optional [str ] = None ,
307
- params : Optional [Dict [str , Any ]] = None ,
325
+ params : Optional [dict [str , Any ]] = None ,
308
326
state : Optional [str ] = None ,
309
327
) -> RedirectResponse :
310
328
"""Constructs and returns a redirect response to the login page of OAuth SSO provider.
311
329
312
330
Args:
313
331
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
314
- params (Optional[Dict [str, Any]]): Additional query parameters to add to the login request.
332
+ params (Optional[dict [str, Any]]): Additional query parameters to add to the login request.
315
333
state (Optional[str]): The state parameter for the OAuth 2.0 authorization request.
316
334
317
335
Returns:
@@ -330,8 +348,8 @@ async def verify_and_process(
330
348
self ,
331
349
request : Request ,
332
350
* ,
333
- params : Optional [Dict [str , Any ]] = None ,
334
- headers : Optional [Dict [str , Any ]] = None ,
351
+ params : Optional [dict [str , Any ]] = None ,
352
+ headers : Optional [dict [str , Any ]] = None ,
335
353
redirect_uri : Optional [str ] = None ,
336
354
convert_response : Literal [True ] = True ,
337
355
) -> Optional [OpenID ]: ...
@@ -341,28 +359,28 @@ async def verify_and_process(
341
359
self ,
342
360
request : Request ,
343
361
* ,
344
- params : Optional [Dict [str , Any ]] = None ,
345
- headers : Optional [Dict [str , Any ]] = None ,
362
+ params : Optional [dict [str , Any ]] = None ,
363
+ headers : Optional [dict [str , Any ]] = None ,
346
364
redirect_uri : Optional [str ] = None ,
347
365
convert_response : Literal [False ],
348
- ) -> Optional [Dict [str , Any ]]: ...
366
+ ) -> Optional [dict [str , Any ]]: ...
349
367
350
368
@requires_async_context
351
369
async def verify_and_process (
352
370
self ,
353
371
request : Request ,
354
372
* ,
355
- params : Optional [Dict [str , Any ]] = None ,
356
- headers : Optional [Dict [str , Any ]] = None ,
373
+ params : Optional [dict [str , Any ]] = None ,
374
+ headers : Optional [dict [str , Any ]] = None ,
357
375
redirect_uri : Optional [str ] = None ,
358
376
convert_response : Union [Literal [True ], Literal [False ]] = True ,
359
- ) -> Union [Optional [OpenID ], Optional [Dict [str , Any ]]]:
377
+ ) -> Union [Optional [OpenID ], Optional [dict [str , Any ]]]:
360
378
"""Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
361
379
362
380
Args:
363
381
request (Request): FastAPI or Starlette request object.
364
- params (Optional[Dict [str, Any]]): Additional query parameters to pass to the provider.
365
- headers (Optional[Dict [str, Any]]): Additional headers to pass to the provider.
382
+ params (Optional[dict [str, Any]]): Additional query parameters to pass to the provider.
383
+ headers (Optional[dict [str, Any]]): Additional headers to pass to the provider.
366
384
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
367
385
convert_response (bool): If True, userinfo response is converted to OpenID object.
368
386
@@ -371,7 +389,7 @@ async def verify_and_process(
371
389
372
390
Returns:
373
391
Optional[OpenID]: User information as OpenID instance (if convert_response == True)
374
- Optional[Dict [str, Any]]: The original JSON response from the API.
392
+ Optional[dict [str, Any]]: The original JSON response from the API.
375
393
"""
376
394
headers = headers or {}
377
395
code = request .query_params .get ("code" )
@@ -433,7 +451,7 @@ async def __aenter__(self) -> "SSOBase":
433
451
434
452
async def __aexit__ (
435
453
self ,
436
- _exc_type : Optional [Type [BaseException ]],
454
+ _exc_type : Optional [type [BaseException ]],
437
455
_exc_val : Optional [BaseException ],
438
456
_exc_tb : Optional [TracebackType ],
439
457
) -> None :
@@ -442,14 +460,14 @@ async def __aexit__(
442
460
443
461
def __exit__ (
444
462
self ,
445
- _exc_type : Optional [Type [BaseException ]],
463
+ _exc_type : Optional [type [BaseException ]],
446
464
_exc_val : Optional [BaseException ],
447
465
_exc_tb : Optional [TracebackType ],
448
466
) -> None :
449
467
return None
450
468
451
469
@property
452
- def _extra_query_params (self ) -> Dict :
470
+ def _extra_query_params (self ) -> dict :
453
471
return {}
454
472
455
473
@overload
@@ -458,8 +476,8 @@ async def process_login(
458
476
code : str ,
459
477
request : Request ,
460
478
* ,
461
- params : Optional [Dict [str , Any ]] = None ,
462
- additional_headers : Optional [Dict [str , Any ]] = None ,
479
+ params : Optional [dict [str , Any ]] = None ,
480
+ additional_headers : Optional [dict [str , Any ]] = None ,
463
481
redirect_uri : Optional [str ] = None ,
464
482
pkce_code_verifier : Optional [str ] = None ,
465
483
convert_response : Literal [True ] = True ,
@@ -471,33 +489,33 @@ async def process_login(
471
489
code : str ,
472
490
request : Request ,
473
491
* ,
474
- params : Optional [Dict [str , Any ]] = None ,
475
- additional_headers : Optional [Dict [str , Any ]] = None ,
492
+ params : Optional [dict [str , Any ]] = None ,
493
+ additional_headers : Optional [dict [str , Any ]] = None ,
476
494
redirect_uri : Optional [str ] = None ,
477
495
pkce_code_verifier : Optional [str ] = None ,
478
496
convert_response : Literal [False ],
479
- ) -> Optional [Dict [str , Any ]]: ...
497
+ ) -> Optional [dict [str , Any ]]: ...
480
498
481
499
@requires_async_context
482
500
async def process_login (
483
501
self ,
484
502
code : str ,
485
503
request : Request ,
486
504
* ,
487
- params : Optional [Dict [str , Any ]] = None ,
488
- additional_headers : Optional [Dict [str , Any ]] = None ,
505
+ params : Optional [dict [str , Any ]] = None ,
506
+ additional_headers : Optional [dict [str , Any ]] = None ,
489
507
redirect_uri : Optional [str ] = None ,
490
508
pkce_code_verifier : Optional [str ] = None ,
491
509
convert_response : Union [Literal [True ], Literal [False ]] = True ,
492
- ) -> Union [Optional [OpenID ], Optional [Dict [str , Any ]]]:
510
+ ) -> Union [Optional [OpenID ], Optional [dict [str , Any ]]]:
493
511
"""Processes login from the callback endpoint to verify the user and request user info endpoint.
494
512
It's a lower-level method, typically, you should use `verify_and_process` instead.
495
513
496
514
Args:
497
515
code (str): The authorization code.
498
516
request (Request): FastAPI or Starlette request object.
499
- params (Optional[Dict [str, Any]]): Additional query parameters to pass to the provider.
500
- additional_headers (Optional[Dict [str, Any]]): Additional headers to be added to all requests.
517
+ params (Optional[dict [str, Any]]): Additional query parameters to pass to the provider.
518
+ additional_headers (Optional[dict [str, Any]]): Additional headers to be added to all requests.
501
519
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
502
520
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
503
521
convert_response (bool): If True, userinfo response is converted to OpenID object.
@@ -507,7 +525,7 @@ async def process_login(
507
525
508
526
Returns:
509
527
Optional[OpenID]: User information in OpenID format if the login was successful (convert_response == True).
510
- Optional[Dict [str, Any]]: Original userinfo API endpoint response.
528
+ Optional[dict [str, Any]]: Original userinfo API endpoint response.
511
529
"""
512
530
if self ._oauth_client is not None : # pragma: no cover
513
531
self ._oauth_client = None
@@ -565,5 +583,9 @@ async def process_login(
565
583
response = await session .get (uri )
566
584
content = response .json ()
567
585
if convert_response :
586
+ if self .use_id_token_for_user_info :
587
+ if not self ._id_token :
588
+ raise SSOLoginError (401 , f"Provider { self .provider !r} did not return id token." )
589
+ return await self .openid_from_token (_decode_id_token (self ._id_token ), session )
568
590
return await self .openid_from_response (content , session )
569
591
return content
0 commit comments