From bfd0638a98b28e239d96e37c6c5a4de8df016b4c Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 8 May 2025 14:41:40 -0700 Subject: [PATCH 1/4] Fix user count --- backend/onyx/db/auth.py | 11 +++++------ backend/onyx/server/manage/users.py | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/backend/onyx/db/auth.py b/backend/onyx/db/auth.py index 974d1b83ee4..dadb7c0ce56 100644 --- a/backend/onyx/db/auth.py +++ b/backend/onyx/db/auth.py @@ -13,7 +13,6 @@ from sqlalchemy.future import select from sqlalchemy.orm import Session -from onyx.auth.invited_users import get_invited_users from onyx.auth.schemas import UserRole from onyx.db.api_key import get_api_key_email_pattern from onyx.db.engine import get_async_session @@ -37,10 +36,11 @@ def get_default_admin_user_emails() -> list[str]: return get_default_admin_user_emails_fn() -def get_total_users_count(db_session: Session) -> int: +def get_live_users_count(db_session: Session) -> int: """ - Returns the total number of users in the system. - This is the sum of users and invited users. + Returns the number of users in the system. + This does NOT include invited users, "users" pulled in + from external connectors, or API keys. """ user_count = ( db_session.query(User) @@ -50,8 +50,7 @@ def get_total_users_count(db_session: Session) -> int: ) .count() ) - invited_users = len(get_invited_users()) - return user_count + invited_users + return user_count async def get_user_count(only_admin_users: bool = False) -> int: diff --git a/backend/onyx/server/manage/users.py b/backend/onyx/server/manage/users.py index a179128dc6a..9c93b4e93ce 100644 --- a/backend/onyx/server/manage/users.py +++ b/backend/onyx/server/manage/users.py @@ -44,7 +44,7 @@ from onyx.configs.constants import AuthType from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME from onyx.db.api_key import is_api_key_email_address -from onyx.db.auth import get_total_users_count +from onyx.db.auth import get_live_users_count from onyx.db.engine import get_session from onyx.db.models import AccessToken from onyx.db.models import User @@ -343,7 +343,7 @@ def bulk_invite_users( logger.info("Registering tenant users") fetch_ee_implementation_or_noop( "onyx.server.tenants.billing", "register_tenant_users", None - )(tenant_id, get_total_users_count(db_session)) + )(tenant_id, get_live_users_count(db_session)) return number_of_invited_users except Exception as e: @@ -379,7 +379,7 @@ def remove_invited_user( if MULTI_TENANT and not DEV_MODE: fetch_ee_implementation_or_noop( "onyx.server.tenants.billing", "register_tenant_users", None - )(tenant_id, get_total_users_count(db_session)) + )(tenant_id, get_live_users_count(db_session)) except Exception: logger.error( "Request to update number of seats taken in control plane failed. " From d705d851e604f8796c778284b8860e038df86986 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 8 May 2025 14:50:09 -0700 Subject: [PATCH 2/4] Add helper + fix async function as well --- backend/onyx/db/auth.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/backend/onyx/db/auth.py b/backend/onyx/db/auth.py index dadb7c0ce56..0af94a1b9e2 100644 --- a/backend/onyx/db/auth.py +++ b/backend/onyx/db/auth.py @@ -2,6 +2,7 @@ from collections.abc import Callable from typing import Any from typing import Dict +from typing import TypeVar from fastapi import Depends from fastapi_users.models import ID @@ -9,6 +10,7 @@ from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase from sqlalchemy import func +from sqlalchemy import Select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import Session @@ -24,6 +26,8 @@ fetch_versioned_implementation_with_fallback, ) +T = TypeVar("T") + def get_default_admin_user_emails() -> list[str]: """Returns a list of emails who should default to Admin role. @@ -36,20 +40,35 @@ def get_default_admin_user_emails() -> list[str]: return get_default_admin_user_emails_fn() +def _add_live_user_count_where_clause( + select_stmt: Select[T], + only_admin_users: bool, +) -> Select[T]: + """ + Builds a SQL column expression that can be used to filter out + users who should not be included in the live user count. + """ + select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore + if only_admin_users: + return select_stmt.where(User.role == UserRole.ADMIN) + else: + return select_stmt.where( + User.role != UserRole.EXT_PERM_USER, + ) + + def get_live_users_count(db_session: Session) -> int: """ Returns the number of users in the system. This does NOT include invited users, "users" pulled in from external connectors, or API keys. """ - user_count = ( - db_session.query(User) - .filter( - ~User.email.endswith(get_api_key_email_pattern()), # type: ignore - User.role != UserRole.EXT_PERM_USER, - ) - .count() - ) + count_stmt = func.count(User.id) # type: ignore + select_stmt = select(count_stmt) + select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False) + user_count = db_session.scalar(select_stmt_w_filters) + if user_count is None: + raise RuntimeError("Was not able to fetch the user count.") return user_count @@ -57,10 +76,8 @@ async def get_user_count(only_admin_users: bool = False) -> int: async with get_async_session_context_manager() as session: count_stmt = func.count(User.id) # type: ignore stmt = select(count_stmt) - if only_admin_users: - stmt = stmt.where(User.role == UserRole.ADMIN) - result = await session.execute(stmt) - user_count = result.scalar() + stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users) + user_count = await session.scalar(stmt_w_filters) if user_count is None: raise RuntimeError("Was not able to fetch the user count.") return user_count From bac3e3e6be139825b3a0f7dfe3c536f9012fe65c Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 8 May 2025 15:00:25 -0700 Subject: [PATCH 3/4] fix mypy --- backend/onyx/db/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/onyx/db/auth.py b/backend/onyx/db/auth.py index 0af94a1b9e2..4aa2d3acebf 100644 --- a/backend/onyx/db/auth.py +++ b/backend/onyx/db/auth.py @@ -26,7 +26,7 @@ fetch_versioned_implementation_with_fallback, ) -T = TypeVar("T") +T = TypeVar("T", bound=tuple[Any, ...]) def get_default_admin_user_emails() -> list[str]: From 6e6f6c97fbc29f6ae97105fcd139d93eaf33553f Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 8 May 2025 17:17:00 -0700 Subject: [PATCH 4/4] Address RK comment --- backend/onyx/db/auth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/onyx/db/auth.py b/backend/onyx/db/auth.py index 4aa2d3acebf..c9e5c57f797 100644 --- a/backend/onyx/db/auth.py +++ b/backend/onyx/db/auth.py @@ -51,10 +51,10 @@ def _add_live_user_count_where_clause( select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore if only_admin_users: return select_stmt.where(User.role == UserRole.ADMIN) - else: - return select_stmt.where( - User.role != UserRole.EXT_PERM_USER, - ) + + return select_stmt.where( + User.role != UserRole.EXT_PERM_USER, + ) def get_live_users_count(db_session: Session) -> int: