Skip to content

Commit 5921196

Browse files
authored
feat:use id token for linkedin userinfo (#229)
* feat:use id token for linkedin userinfo + remove Python 3.8 support * chore: implement missing tests * chore: reorganize imports in test files for clarity
1 parent 2f5db65 commit 5921196

19 files changed

+150
-136
lines changed

.github/workflows/lint.yml

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ jobs:
88
strategy:
99
matrix:
1010
python-version:
11-
- "3.8"
1211
- "3.9"
1312
- "3.10"
1413
- "3.11"

.github/workflows/test.yml

-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ jobs:
1111
strategy:
1212
matrix:
1313
python-version:
14-
- "3.8"
1514
- "3.9"
1615
- "3.10"
1716
- "3.11"

examples/generic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""This is an example usage of fastapi-sso."""
22

3-
from typing import Any, Dict, Union
3+
from typing import Any, Union
44
from httpx import AsyncClient
55
import uvicorn
66
from fastapi import FastAPI, HTTPException
@@ -23,7 +23,7 @@
2323
# and then python examples/generic.py
2424

2525

26-
def convert_openid(response: Dict[str, Any], _client: Union[AsyncClient, None]) -> OpenID:
26+
def convert_openid(response: dict[str, Any], _client: Union[AsyncClient, None]) -> OpenID:
2727
"""Convert user information returned by OIDC"""
2828
print(response)
2929
return OpenID(display_name=response["sub"])

examples/linkedin.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
sso = LinkedInSSO(
1414
client_id=CLIENT_ID,
1515
client_secret=CLIENT_SECRET,
16-
redirect_uri="http://localhost:5000/auth/callback",
16+
redirect_uri="http://localhost:5050/auth/callback",
1717
allow_insecure_http=True,
1818
)
1919

@@ -34,4 +34,4 @@ async def auth_callback(request: Request):
3434

3535

3636
if __name__ == "__main__":
37-
uvicorn.run(app="examples.linkedin:app", host="127.0.0.1", port=5000)
37+
uvicorn.run(app="examples.linkedin:app", host="127.0.0.1", port=5050)

fastapi_sso/pkce.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import base64
44
import hashlib
55
import os
6-
from typing import Tuple
76

87

98
def get_code_verifier(length: int = 96) -> str:
@@ -13,7 +12,7 @@ def get_code_verifier(length: int = 96) -> str:
1312
return base64.urlsafe_b64encode(os.urandom(bytes_length)).decode("utf-8").replace("=", "")[:length]
1413

1514

16-
def get_pkce_challenge_pair(verifier_length: int = 96) -> Tuple[str, str]:
15+
def get_pkce_challenge_pair(verifier_length: int = 96) -> tuple[str, str]:
1716
"""Get tuple of (verifier, challenge) for PKCE challenge."""
1817
code_verifier = get_code_verifier(verifier_length)
1918
code_challenge = (

fastapi_sso/sso/base.py

+55-33
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import sys
88
import warnings
99
from types import TracebackType
10-
from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypedDict, TypeVar, Union, overload
10+
from typing import Any, ClassVar, Literal, Optional, TypedDict, TypeVar, Union, overload
1111

1212
import httpx
13+
import jwt
1314
import pydantic
1415
from oauthlib.oauth2 import WebApplicationClient
1516
from starlette.exceptions import HTTPException
@@ -33,6 +34,10 @@
3334
P = ParamSpec("P")
3435

3536

37+
def _decode_id_token(id_token: str, verify: bool = False) -> dict:
38+
return jwt.decode(id_token, options={"verify_signature": verify})
39+
40+
3641
class DiscoveryDocument(TypedDict):
3742
"""Discovery document."""
3843

@@ -95,10 +100,11 @@ class SSOBase:
95100
client_id: str = NotImplemented
96101
client_secret: str = NotImplemented
97102
redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = NotImplemented
98-
scope: ClassVar[List[str]] = []
99-
additional_headers: ClassVar[Optional[Dict[str, Any]]] = None
103+
scope: ClassVar[list[str]] = []
104+
additional_headers: ClassVar[Optional[dict[str, Any]]] = None
100105
uses_pkce: bool = False
101106
requires_state: bool = False
107+
use_id_token_for_user_info: ClassVar[bool] = False
102108

103109
_pkce_challenge_length: int = 96
104110

@@ -109,7 +115,7 @@ def __init__(
109115
redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = None,
110116
allow_insecure_http: bool = False,
111117
use_state: bool = False,
112-
scope: Optional[List[str]] = None,
118+
scope: Optional[list[str]] = None,
113119
):
114120
"""Base class (mixin) for all SSO providers."""
115121
self.client_id: str = client_id
@@ -224,6 +230,18 @@ async def openid_from_response(self, response: dict, session: Optional[httpx.Asy
224230
"""
225231
raise NotImplementedError(f"Provider {self.provider} not supported")
226232

233+
async def openid_from_token(self, id_token: dict, session: Optional[httpx.AsyncClient] = None) -> OpenID:
234+
"""Converts an ID token from the provider's token endpoint to an OpenID object.
235+
236+
Args:
237+
id_token (dict): The id token data retrieved from the token endpoint.
238+
session: (Optional[httpx.AsyncClient]): The HTTPX AsyncClient session.
239+
240+
Returns:
241+
OpenID: The user information in a standardized format.
242+
"""
243+
raise NotImplementedError(f"Provider {self.provider} not supported")
244+
227245
async def get_discovery_document(self) -> DiscoveryDocument:
228246
"""Retrieves the discovery document containing useful URLs.
229247
@@ -257,14 +275,14 @@ async def get_login_url(
257275
self,
258276
*,
259277
redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = None,
260-
params: Optional[Dict[str, Any]] = None,
278+
params: Optional[dict[str, Any]] = None,
261279
state: Optional[str] = None,
262280
) -> str:
263281
"""Generates and returns the prepared login URL.
264282
265283
Args:
266284
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
267-
params (Optional[Dict[str, Any]]): Additional query parameters to add to the login request.
285+
params (Optional[dict[str, Any]]): Additional query parameters to add to the login request.
268286
state (Optional[str]): The state parameter for the OAuth 2.0 authorization request.
269287
270288
Raises:
@@ -304,14 +322,14 @@ async def get_login_redirect(
304322
self,
305323
*,
306324
redirect_uri: Optional[str] = None,
307-
params: Optional[Dict[str, Any]] = None,
325+
params: Optional[dict[str, Any]] = None,
308326
state: Optional[str] = None,
309327
) -> RedirectResponse:
310328
"""Constructs and returns a redirect response to the login page of OAuth SSO provider.
311329
312330
Args:
313331
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
314-
params (Optional[Dict[str, Any]]): Additional query parameters to add to the login request.
332+
params (Optional[dict[str, Any]]): Additional query parameters to add to the login request.
315333
state (Optional[str]): The state parameter for the OAuth 2.0 authorization request.
316334
317335
Returns:
@@ -330,8 +348,8 @@ async def verify_and_process(
330348
self,
331349
request: Request,
332350
*,
333-
params: Optional[Dict[str, Any]] = None,
334-
headers: Optional[Dict[str, Any]] = None,
351+
params: Optional[dict[str, Any]] = None,
352+
headers: Optional[dict[str, Any]] = None,
335353
redirect_uri: Optional[str] = None,
336354
convert_response: Literal[True] = True,
337355
) -> Optional[OpenID]: ...
@@ -341,28 +359,28 @@ async def verify_and_process(
341359
self,
342360
request: Request,
343361
*,
344-
params: Optional[Dict[str, Any]] = None,
345-
headers: Optional[Dict[str, Any]] = None,
362+
params: Optional[dict[str, Any]] = None,
363+
headers: Optional[dict[str, Any]] = None,
346364
redirect_uri: Optional[str] = None,
347365
convert_response: Literal[False],
348-
) -> Optional[Dict[str, Any]]: ...
366+
) -> Optional[dict[str, Any]]: ...
349367

350368
@requires_async_context
351369
async def verify_and_process(
352370
self,
353371
request: Request,
354372
*,
355-
params: Optional[Dict[str, Any]] = None,
356-
headers: Optional[Dict[str, Any]] = None,
373+
params: Optional[dict[str, Any]] = None,
374+
headers: Optional[dict[str, Any]] = None,
357375
redirect_uri: Optional[str] = None,
358376
convert_response: Union[Literal[True], Literal[False]] = True,
359-
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
377+
) -> Union[Optional[OpenID], Optional[dict[str, Any]]]:
360378
"""Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
361379
362380
Args:
363381
request (Request): FastAPI or Starlette request object.
364-
params (Optional[Dict[str, Any]]): Additional query parameters to pass to the provider.
365-
headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider.
382+
params (Optional[dict[str, Any]]): Additional query parameters to pass to the provider.
383+
headers (Optional[dict[str, Any]]): Additional headers to pass to the provider.
366384
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
367385
convert_response (bool): If True, userinfo response is converted to OpenID object.
368386
@@ -371,7 +389,7 @@ async def verify_and_process(
371389
372390
Returns:
373391
Optional[OpenID]: User information as OpenID instance (if convert_response == True)
374-
Optional[Dict[str, Any]]: The original JSON response from the API.
392+
Optional[dict[str, Any]]: The original JSON response from the API.
375393
"""
376394
headers = headers or {}
377395
code = request.query_params.get("code")
@@ -433,7 +451,7 @@ async def __aenter__(self) -> "SSOBase":
433451

434452
async def __aexit__(
435453
self,
436-
_exc_type: Optional[Type[BaseException]],
454+
_exc_type: Optional[type[BaseException]],
437455
_exc_val: Optional[BaseException],
438456
_exc_tb: Optional[TracebackType],
439457
) -> None:
@@ -442,14 +460,14 @@ async def __aexit__(
442460

443461
def __exit__(
444462
self,
445-
_exc_type: Optional[Type[BaseException]],
463+
_exc_type: Optional[type[BaseException]],
446464
_exc_val: Optional[BaseException],
447465
_exc_tb: Optional[TracebackType],
448466
) -> None:
449467
return None
450468

451469
@property
452-
def _extra_query_params(self) -> Dict:
470+
def _extra_query_params(self) -> dict:
453471
return {}
454472

455473
@overload
@@ -458,8 +476,8 @@ async def process_login(
458476
code: str,
459477
request: Request,
460478
*,
461-
params: Optional[Dict[str, Any]] = None,
462-
additional_headers: Optional[Dict[str, Any]] = None,
479+
params: Optional[dict[str, Any]] = None,
480+
additional_headers: Optional[dict[str, Any]] = None,
463481
redirect_uri: Optional[str] = None,
464482
pkce_code_verifier: Optional[str] = None,
465483
convert_response: Literal[True] = True,
@@ -471,33 +489,33 @@ async def process_login(
471489
code: str,
472490
request: Request,
473491
*,
474-
params: Optional[Dict[str, Any]] = None,
475-
additional_headers: Optional[Dict[str, Any]] = None,
492+
params: Optional[dict[str, Any]] = None,
493+
additional_headers: Optional[dict[str, Any]] = None,
476494
redirect_uri: Optional[str] = None,
477495
pkce_code_verifier: Optional[str] = None,
478496
convert_response: Literal[False],
479-
) -> Optional[Dict[str, Any]]: ...
497+
) -> Optional[dict[str, Any]]: ...
480498

481499
@requires_async_context
482500
async def process_login(
483501
self,
484502
code: str,
485503
request: Request,
486504
*,
487-
params: Optional[Dict[str, Any]] = None,
488-
additional_headers: Optional[Dict[str, Any]] = None,
505+
params: Optional[dict[str, Any]] = None,
506+
additional_headers: Optional[dict[str, Any]] = None,
489507
redirect_uri: Optional[str] = None,
490508
pkce_code_verifier: Optional[str] = None,
491509
convert_response: Union[Literal[True], Literal[False]] = True,
492-
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
510+
) -> Union[Optional[OpenID], Optional[dict[str, Any]]]:
493511
"""Processes login from the callback endpoint to verify the user and request user info endpoint.
494512
It's a lower-level method, typically, you should use `verify_and_process` instead.
495513
496514
Args:
497515
code (str): The authorization code.
498516
request (Request): FastAPI or Starlette request object.
499-
params (Optional[Dict[str, Any]]): Additional query parameters to pass to the provider.
500-
additional_headers (Optional[Dict[str, Any]]): Additional headers to be added to all requests.
517+
params (Optional[dict[str, Any]]): Additional query parameters to pass to the provider.
518+
additional_headers (Optional[dict[str, Any]]): Additional headers to be added to all requests.
501519
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
502520
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
503521
convert_response (bool): If True, userinfo response is converted to OpenID object.
@@ -507,7 +525,7 @@ async def process_login(
507525
508526
Returns:
509527
Optional[OpenID]: User information in OpenID format if the login was successful (convert_response == True).
510-
Optional[Dict[str, Any]]: Original userinfo API endpoint response.
528+
Optional[dict[str, Any]]: Original userinfo API endpoint response.
511529
"""
512530
if self._oauth_client is not None: # pragma: no cover
513531
self._oauth_client = None
@@ -565,5 +583,9 @@ async def process_login(
565583
response = await session.get(uri)
566584
content = response.json()
567585
if convert_response:
586+
if self.use_id_token_for_user_info:
587+
if not self._id_token:
588+
raise SSOLoginError(401, f"Provider {self.provider!r} did not return id token.")
589+
return await self.openid_from_token(_decode_id_token(self._id_token), session)
568590
return await self.openid_from_response(content, session)
569591
return content

fastapi_sso/sso/bitbucket.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""BitBucket SSO Oauth Helper class"""
22

3-
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
3+
from typing import TYPE_CHECKING, ClassVar, Optional, Union
44

55
import pydantic
66

@@ -23,7 +23,7 @@ def __init__(
2323
client_secret: str,
2424
redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = None,
2525
allow_insecure_http: bool = False,
26-
scope: Optional[List[str]] = None,
26+
scope: Optional[list[str]] = None,
2727
):
2828
super().__init__(
2929
client_id=client_id,

fastapi_sso/sso/discord.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Discord SSO Oauth Helper class"""
22

3-
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
3+
from typing import TYPE_CHECKING, ClassVar, Optional, Union
44

55
import pydantic
66

@@ -22,7 +22,7 @@ def __init__(
2222
client_secret: str,
2323
redirect_uri: Optional[Union[pydantic.AnyHttpUrl, str]] = None,
2424
allow_insecure_http: bool = False,
25-
scope: Optional[List[str]] = None,
25+
scope: Optional[list[str]] = None,
2626
):
2727
super().__init__(
2828
client_id=client_id,

fastapi_sso/sso/generic.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
6+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
77

88
from fastapi_sso.sso.base import DiscoveryDocument, OpenID, SSOBase
99

@@ -16,10 +16,10 @@
1616
def create_provider(
1717
*,
1818
name: str = "generic",
19-
default_scope: Optional[List[str]] = None,
19+
default_scope: Optional[list[str]] = None,
2020
discovery_document: Union[DiscoveryDocument, Callable[[SSOBase], DiscoveryDocument]],
21-
response_convertor: Optional[Callable[[Dict[str, Any], Optional["httpx.AsyncClient"]], OpenID]] = None
22-
) -> Type[SSOBase]:
21+
response_convertor: Optional[Callable[[dict[str, Any], Optional["httpx.AsyncClient"]], OpenID]] = None
22+
) -> type[SSOBase]:
2323
"""A factory to create a generic OAuth client usable with almost any OAuth provider.
2424
Returns a class.
2525

0 commit comments

Comments
 (0)