Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions soauth/api/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ async def code(
return KeyRefreshResponse(
access_token=key_content.access_token,
refresh_token=key_content.refresh_token,
profile_data=key_content.profile_data,
access_token_expires=key_content.access_token_expires,
refresh_token_expires=key_content.refresh_token_expires,
redirect=redirect,
Expand Down
2 changes: 2 additions & 0 deletions soauth/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
class KeyRefreshResponse(BaseModel):
access_token: str
refresh_token: str
profile_data: dict[str, str | None]
redirect: str | None = None
access_token_expires: datetime
refresh_token_expires: datetime



class APIKeyCreationResponse(BaseModel):
app_name: str
app_id: UUID
Expand Down
1 change: 1 addition & 0 deletions soauth/core/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class UserData(BaseModel):
user_id: UUID
user_name: str
full_name: str | None
profile_image: str | None
email: str | None
grants: set[str] | None
group_names: list[str] | None
Expand Down
10 changes: 10 additions & 0 deletions soauth/database/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class User(SQLModel, table=True):

gh_access_token: str | None = None
# gh_refresh_token: str | None = None
gh_profile_image_url: str | None = None
gh_last_logged_in: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)), default=None
)
Expand Down Expand Up @@ -112,6 +113,14 @@ def get_effective_grants(self, include_groups: bool = True) -> set[str]:
def has_effective_grant(self, grant: str) -> bool:
"""Check if user has grant either individually or through groups."""
return grant in self.get_effective_grants()

def to_public_profile_data(self) -> dict[str, str | None]:
"""Convert user data to public profile format."""
return {
"username": self.user_name,
"full_name": self.full_name,
"profile_image": self.gh_profile_image_url,
}

def to_core(self, include_groups=True) -> UserData:
return UserData(
Expand All @@ -124,4 +133,5 @@ def to_core(self, include_groups=True) -> UserData:
group_ids=[str(x.group_id) for x in self.groups]
if include_groups
else None,
profile_image=self.gh_profile_image_url,
)
7 changes: 6 additions & 1 deletion soauth/service/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from soauth.database.app import App
from soauth.database.user import User
from soauth.service.provider import AuthProvider
from soauth.service.user import read_by_id

from .auth import create_auth_key
from .refresh import (
Expand Down Expand Up @@ -81,10 +82,12 @@ async def primary(
refresh_key=refresh_key, settings=settings, conn=conn
)
await log.ainfo("primary.auth_key_created")
profile_data = user.to_public_profile_data()

return KeyRefreshResponse(
access_token=encoded_auth_key,
refresh_token=encoded_refresh_key,
profile_data=profile_data,
access_token_expires=auth_key_expires,
refresh_token_expires=refresh_key.expires_at,
)
Expand Down Expand Up @@ -164,12 +167,14 @@ async def secondary(
)

await log.ainfo("secondary.auth_key_created")

user = await read_by_id(user_id=refresh_key.user_id, conn=conn)
profile_data = user.to_public_profile_data()
return KeyRefreshResponse(
access_token=encoded_auth_key,
refresh_token=encoded_refresh_key,
access_token_expires=auth_key_expires,
refresh_token_expires=refresh_key.expires_at,
profile_data=profile_data,
)


Expand Down
8 changes: 5 additions & 3 deletions soauth/service/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,10 @@ async def user_from_access_token(
)

username = user_info["login"].lower()

profile_image = user_info["avatar_url"]
log = log.bind(
user_name=username, email=user_info["email"], full_name=user_info["name"]
user_name=username, email=user_info["email"], full_name=user_info["name"],
profile_image=profile_image,
)

user_email = user_info["email"]
Expand All @@ -304,20 +305,21 @@ async def user_from_access_token(
user = await read_by_name(user_name=username, conn=conn)
user.email = user_email
user.full_name = user_info["name"]
user.gh_profile_image_url = profile_image
log = log.bind(user_read=True, user_created=False)
except UserNotFound:
user = await create_user(
user_name=username,
email=user_email,
full_name=user_info["name"],
profile_image=profile_image,
grants="",
conn=conn,
log=log,
)
log = log.bind(user_created=True, user_read=False)

log.bind(user_id=user.user_id)

user.gh_access_token = access_token
user.gh_last_logged_in = gh_last_logged_in

Expand Down
4 changes: 3 additions & 1 deletion soauth/service/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ async def create(
user_name: str,
email: str,
full_name: str,
profile_image: str | None,
grants: str,
conn: AsyncSession,
log: FilteringBoundLogger,
Expand All @@ -43,7 +44,8 @@ async def create(

try:
user = User(
user_name=user_name, email=email, grants=grants, full_name=full_name
user_name=user_name, email=email, grants=grants, full_name=full_name,
gh_profile_image_url=profile_image,
)
except IntegrityError:
await log.ainfo("user.create.exists")
Expand Down
21 changes: 19 additions & 2 deletions soauth/toolkit/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def lifespan(app: FastAPI):
There is a global setup function that can be used defined in the `fastapi.py`
file, for FastAPI services.
"""
import json

import httpx
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -301,7 +302,7 @@ def key_expired_handler(request: Request, exc: KeyExpiredError) -> RedirectRespo
return response

content = KeyRefreshResponse.model_validate_json(response.content)

response = RedirectResponse(request.url, status_code=302)

response.set_cookie(
Expand Down Expand Up @@ -331,6 +332,13 @@ def key_expired_handler(request: Request, exc: KeyExpiredError) -> RedirectRespo
httponly=False,
)

response.set_cookie(
key="profile_data",
value=json.dumps(content.profile_data),
expires=content.access_token_expires,
httponly=False,
)

log.info("tk.starlette.expired.refreshed")

return response
Expand Down Expand Up @@ -363,6 +371,8 @@ def key_decode_handler(request: Request, exc: KeyDecodeError) -> RedirectRespons
response.delete_cookie(refresh_token_name)
response.delete_cookie("valid_refresh_token")
response.delete_cookie("validate_access_token")
response.delete_cookie("profile_data")

log.info("tk.starlette.decode.redirecting")

return response
Expand Down Expand Up @@ -458,6 +468,13 @@ async def handle_redirect(code: str, state: str, request: Request) -> RedirectRe
httponly=False,
)

response.set_cookie(
key="profile_data",
value=json.dumps(content.profile_data),
expires=content.access_token_expires,
httponly=False,
)

return response


Expand Down Expand Up @@ -487,5 +504,5 @@ async def logout(request: Request) -> RedirectResponse:
response.delete_cookie(access_token_name)
response.delete_cookie("valid_refresh_token")
response.delete_cookie("validate_access_token")

response.delete_cookie("profile_data")
return response
1 change: 1 addition & 0 deletions tests/test_service/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ async def user(session_manager, logger):
user_name="admin",
email="admin@simonsobservatory.org",
full_name="Admin User",
profile_image=None,
grants="admin",
conn=conn,
log=logger,
Expand Down
1 change: 1 addition & 0 deletions tests/test_service/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ async def test_create_group(server_settings, session_manager, logger, user):
user_name="test_user2",
email="asfasdf@salsdfasd.com",
full_name="Test User 2",
profile_image=None,
grants="",
conn=conn,
log=logger,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_service/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ async def test_create_user(server_settings, session_manager, logger):
user_name="test_user",
email="test_user@email.com",
full_name="Test User",
profile_image=None,
grants="",
conn=conn,
log=logger,
Expand Down Expand Up @@ -75,6 +76,7 @@ async def test_user_effective_grants_with_group(
user_name="test_user",
email="test_user@email.com",
full_name="Test User Group Grants",
profile_image=None,
grants="user_grant",
conn=conn,
log=logger,
Expand Down Expand Up @@ -120,6 +122,7 @@ async def test_user_with_no_groups(server_settings, session_manager, logger):
user_name="test_user_no_groups",
email="test@user.com",
full_name="Test User No Groups",
profile_image=None,
grants="",
conn=conn,
log=logger,
Expand Down
Loading