Skip to content

OAuth2 & OpenID auth handler #67

@asaf

Description

@asaf

Support protecting endpoints with OpenID-issued tokens.

At this point, NOOP and JWT_OIDC are sufficient (JWT_LOCAL is not needed).

Configuration:

class AuthType(Enum):
    """Authentication strategies for protecting API endpoints.

    Presumably selected via the FLUX0_AUTH_TYPE environment variable — confirm.
    """

    # NOTE(review): intended as a pass-through handler with no token check — confirm.
    NOOP = "noop"
    # JWT verified with a locally configured decode key (see JWTAuthLocal).
    JWT_LOCAL = "jwt_local"
    # JWT issued by an OpenID Connect provider.
    JWT_OIDC = "jwt_oidc"

class NoopAuthSettings(BaseModel):
    """Settings for the NOOP auth type."""

    # Defaults to NOOP; presumably lets this model be told apart from the other
    # members of the AuthSettings union — confirm.
    auth_type: AuthType = AuthType.NOOP
    # NOTE(review): looks like a placeholder field. It has no default, so it is
    # required whenever NoopAuthSettings is constructed — confirm or remove.
    moo: str

class JWTLocalAuthSettings(BaseSettings):  # BaseSettings (unlike NoopAuthSettings' BaseModel)
    """Settings for verifying JWTs with a locally configured key.

    Values are loaded with the ``FLUX0_AUTH_`` prefix (presumably e.g.
    ``FLUX0_AUTH_SECRET_KEY`` / ``FLUX0_AUTH_ALGORITHM`` — confirm).

    NOTE(review): JWTAuthLocal reads ``decode_key``, ``alg``, ``iss`` and ``aud``
    from ``settings.jwt_local``, which do not match the fields declared here.
    Unlike NoopAuthSettings, there is also no ``auth_type`` field. Confirm which
    side is authoritative.
    """

    # Key used to verify token signatures.
    secret_key: str
    # JWS algorithm name; defaults to HMAC-SHA256.
    algorithm: str = "HS256"

    model_config = SettingsConfigDict(env_prefix="FLUX0_AUTH_")


# Union of the supported auth settings models.
# NOTE(review): no settings model for AuthType.JWT_OIDC is part of this union yet.
AuthSettings = Union[NoopAuthSettings, JWTLocalAuthSettings]

Besides FLUX0_AUTH_TYPE, some auth types require extra environment variables, such as the OIDC well-known configuration URL.

A starting point:

class JWTAuthBase(AuthHandler):
    """Base auth handler that authenticates requests carrying a JWT bearer token.

    Subclasses decide how a token is verified by implementing
    :meth:`decode_token` and :meth:`get_decode_key`. On success, the token's
    ``sub`` claim is used to look up the user in ``user_store``. A user is
    created the first time a subject is seen.
    """

    user_store: UserStore

    # auto_error=False makes HTTPBearer return None for a missing or malformed
    # Authorization header instead of raising its own error. This handler then
    # controls the response (401 plus a Bearer challenge) itself.
    _bearer_scheme = HTTPBearer(auto_error=False)

    def __init__(self, user_store: UserStore):
        self.user_store = user_store

    async def __call__(self, request: Request) -> User:
        """Authenticate *request* and return the matching (possibly new) user.

        Raises:
            HTTPException: 401 when the bearer token is missing, fails
                verification, or lacks a usable ``sub`` claim.
        """
        credentials = await self._bearer_scheme(request)
        if not credentials:
            raise self._unauthorized("Invalid token")
        token = credentials.credentials

        try:
            payload = self.decode_token(token, self.get_decode_key(token))
        except jwt.PyJWTError as e:
            raise self._unauthorized(str(e)) from e

        # .get() so that a token without "sub" yields a 401 rather than a KeyError (500).
        sub = payload.get("sub")
        if not sub or not isinstance(sub, str):
            raise self._unauthorized("Invalid token")

        user = await self.user_store.read_user_by_sub(sub)
        if not user:
            # Name first-seen users from the token's "name" claim when present,
            # otherwise from their subject, instead of a fixed placeholder name.
            name = payload.get("name")
            user = await self.user_store.create_user(
                sub=sub, name=name if isinstance(name, str) and name else sub
            )

        return user

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 error that carries the RFC 6750 ``Bearer`` challenge."""
        return HTTPException(
            status_code=401,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    @abstractmethod
    def decode_token(self, token: str, decode_key: str) -> dict[str, Any]:
        """Verify *token* with *decode_key* and return its claims.

        Only ``jwt.PyJWTError`` raised here is translated into a 401 by
        :meth:`__call__`; other exceptions propagate.
        """
        ...

    @abstractmethod
    def get_decode_key(self, token: str) -> str:
        """Return the key used to verify *token*.

        Only ``jwt.PyJWTError`` raised here is translated into a 401 by
        :meth:`__call__`; other exceptions propagate.
        """
        ...


class JWTAuthLocal(JWTAuthBase):
    """JWT auth handler whose verification key and expected claims come from
    ``settings.jwt_local`` rather than being resolved per token."""

    def decode_token(self, token: str, decode_key: str) -> Any:
        """Verify *token* against *decode_key* and the configured issuer/audience.

        The token must carry ``exp``, ``iss``, ``aud`` and ``sub``. Its
        signature must use the configured algorithm (compared upper-cased).
        """
        local_cfg = settings.jwt_local
        allowed_algs = [local_cfg.alg.upper()]
        required_claims = ["exp", "iss", "aud", "sub"]
        return jwt.decode(
            token,
            decode_key,
            algorithms=allowed_algs,
            issuer=local_cfg.iss,
            audience=local_cfg.aud,
            options={"require": required_claims},
        )

    def get_decode_key(self, token: str) -> str:
        """Return the statically configured decode key; *token* is not consulted."""
        return settings.jwt_local.decode_key

Metadata

Metadata

Assignees

Labels

api (API module)

Projects

Status

Ready

Milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions