Skip to content

Commit f21ff02

Browse files
WevesZhipengHe
authored andcommitted
Fix user count (onyx-dot-app#4677)
* Fix user count * Add helper + fix async function as well * fix mypy * Address RK comment
1 parent d248bee commit f21ff02

File tree

2 files changed

+36
-20
lines changed

2 files changed

+36
-20
lines changed

backend/onyx/db/auth.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22
from collections.abc import Callable
33
from typing import Any
44
from typing import Dict
5+
from typing import TypeVar
56

67
from fastapi import Depends
78
from fastapi_users.models import ID
89
from fastapi_users.models import UP
910
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
1011
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
1112
from sqlalchemy import func
13+
from sqlalchemy import Select
1214
from sqlalchemy.ext.asyncio import AsyncSession
1315
from sqlalchemy.future import select
1416
from sqlalchemy.orm import Session
1517

16-
from onyx.auth.invited_users import get_invited_users
1718
from onyx.auth.schemas import UserRole
1819
from onyx.db.api_key import get_api_key_email_pattern
1920
from onyx.db.engine import get_async_session
@@ -25,6 +26,8 @@
2526
fetch_versioned_implementation_with_fallback,
2627
)
2728

29+
T = TypeVar("T", bound=tuple[Any, ...])
30+
2831

2932
def get_default_admin_user_emails() -> list[str]:
3033
"""Returns a list of emails who should default to Admin role.
@@ -37,31 +40,44 @@ def get_default_admin_user_emails() -> list[str]:
3740
return get_default_admin_user_emails_fn()
3841

3942

40-
def get_total_users_count(db_session: Session) -> int:
43+
def _add_live_user_count_where_clause(
44+
select_stmt: Select[T],
45+
only_admin_users: bool,
46+
) -> Select[T]:
4147
"""
42-
Returns the total number of users in the system.
43-
This is the sum of users and invited users.
48+
Builds a SQL column expression that can be used to filter out
49+
users who should not be included in the live user count.
4450
"""
45-
user_count = (
46-
db_session.query(User)
47-
.filter(
48-
~User.email.endswith(get_api_key_email_pattern()), # type: ignore
49-
User.role != UserRole.EXT_PERM_USER,
50-
)
51-
.count()
51+
select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
52+
if only_admin_users:
53+
return select_stmt.where(User.role == UserRole.ADMIN)
54+
55+
return select_stmt.where(
56+
User.role != UserRole.EXT_PERM_USER,
5257
)
53-
invited_users = len(get_invited_users())
54-
return user_count + invited_users
58+
59+
60+
def get_live_users_count(db_session: Session) -> int:
61+
"""
62+
Returns the number of users in the system.
63+
This does NOT include invited users, "users" pulled in
64+
from external connectors, or API keys.
65+
"""
66+
count_stmt = func.count(User.id) # type: ignore
67+
select_stmt = select(count_stmt)
68+
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
69+
user_count = db_session.scalar(select_stmt_w_filters)
70+
if user_count is None:
71+
raise RuntimeError("Was not able to fetch the user count.")
72+
return user_count
5573

5674

5775
async def get_user_count(only_admin_users: bool = False) -> int:
5876
async with get_async_session_context_manager() as session:
5977
count_stmt = func.count(User.id) # type: ignore
6078
stmt = select(count_stmt)
61-
if only_admin_users:
62-
stmt = stmt.where(User.role == UserRole.ADMIN)
63-
result = await session.execute(stmt)
64-
user_count = result.scalar()
79+
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
80+
user_count = await session.scalar(stmt_w_filters)
6581
if user_count is None:
6682
raise RuntimeError("Was not able to fetch the user count.")
6783
return user_count

backend/onyx/server/manage/users.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from onyx.configs.constants import AuthType
4545
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
4646
from onyx.db.api_key import is_api_key_email_address
47-
from onyx.db.auth import get_total_users_count
47+
from onyx.db.auth import get_live_users_count
4848
from onyx.db.engine import get_session
4949
from onyx.db.models import AccessToken
5050
from onyx.db.models import User
@@ -343,7 +343,7 @@ def bulk_invite_users(
343343
logger.info("Registering tenant users")
344344
fetch_ee_implementation_or_noop(
345345
"onyx.server.tenants.billing", "register_tenant_users", None
346-
)(tenant_id, get_total_users_count(db_session))
346+
)(tenant_id, get_live_users_count(db_session))
347347

348348
return number_of_invited_users
349349
except Exception as e:
@@ -379,7 +379,7 @@ def remove_invited_user(
379379
if MULTI_TENANT and not DEV_MODE:
380380
fetch_ee_implementation_or_noop(
381381
"onyx.server.tenants.billing", "register_tenant_users", None
382-
)(tenant_id, get_total_users_count(db_session))
382+
)(tenant_id, get_live_users_count(db_session))
383383
except Exception:
384384
logger.error(
385385
"Request to update number of seats taken in control plane failed. "

0 commit comments

Comments
 (0)