Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions backend/onyx/db/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
from collections.abc import Callable
from typing import Any
from typing import Dict
from typing import TypeVar

from fastapi import Depends
from fastapi_users.models import ID
from fastapi_users.models import UP
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import Session

from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import get_api_key_email_pattern
from onyx.db.engine import get_async_session
Expand All @@ -25,6 +26,8 @@
fetch_versioned_implementation_with_fallback,
)

T = TypeVar("T", bound=tuple[Any, ...])


def get_default_admin_user_emails() -> list[str]:
"""Returns a list of emails who should default to Admin role.
Expand All @@ -37,31 +40,44 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()


def get_total_users_count(db_session: Session) -> int:
def _add_live_user_count_where_clause(
select_stmt: Select[T],
only_admin_users: bool,
) -> Select[T]:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
Builds a SQL column expression that can be used to filter out
users who should not be included in the live user count.
"""
user_count = (
db_session.query(User)
.filter(
~User.email.endswith(get_api_key_email_pattern()), # type: ignore
User.role != UserRole.EXT_PERM_USER,
)
.count()
select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
if only_admin_users:
return select_stmt.where(User.role == UserRole.ADMIN)

return select_stmt.where(
User.role != UserRole.EXT_PERM_USER,
)
invited_users = len(get_invited_users())
return user_count + invited_users


def get_live_users_count(db_session: Session) -> int:
"""
Returns the number of users in the system.
This does NOT include invited users, "users" pulled in
from external connectors, or API keys.
"""
count_stmt = func.count(User.id) # type: ignore
select_stmt = select(count_stmt)
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
user_count = db_session.scalar(select_stmt_w_filters)
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
return user_count


async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_context_manager() as session:
count_stmt = func.count(User.id) # type: ignore
stmt = select(count_stmt)
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)
user_count = result.scalar()
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
user_count = await session.scalar(stmt_w_filters)
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
return user_count
Expand Down
6 changes: 3 additions & 3 deletions backend/onyx/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from onyx.configs.constants import AuthType
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.api_key import is_api_key_email_address
from onyx.db.auth import get_total_users_count
from onyx.db.auth import get_live_users_count
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import User
Expand Down Expand Up @@ -343,7 +343,7 @@ def bulk_invite_users(
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))
)(tenant_id, get_live_users_count(db_session))

return number_of_invited_users
except Exception as e:
Expand Down Expand Up @@ -379,7 +379,7 @@ def remove_invited_user(
if MULTI_TENANT and not DEV_MODE:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))
)(tenant_id, get_live_users_count(db_session))
except Exception:
logger.error(
"Request to update number of seats taken in control plane failed. "
Expand Down