2
2
from collections .abc import Callable
3
3
from typing import Any
4
4
from typing import Dict
5
+ from typing import TypeVar
5
6
6
7
from fastapi import Depends
7
8
from fastapi_users .models import ID
8
9
from fastapi_users .models import UP
9
10
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
10
11
from fastapi_users_db_sqlalchemy .access_token import SQLAlchemyAccessTokenDatabase
11
12
from sqlalchemy import func
13
+ from sqlalchemy import Select
12
14
from sqlalchemy .ext .asyncio import AsyncSession
13
15
from sqlalchemy .future import select
14
16
from sqlalchemy .orm import Session
15
17
16
- from onyx .auth .invited_users import get_invited_users
17
18
from onyx .auth .schemas import UserRole
18
19
from onyx .db .api_key import get_api_key_email_pattern
19
20
from onyx .db .engine import get_async_session
25
26
fetch_versioned_implementation_with_fallback ,
26
27
)
27
28
29
+ T = TypeVar ("T" , bound = tuple [Any , ...])
30
+
28
31
29
32
def get_default_admin_user_emails () -> list [str ]:
30
33
"""Returns a list of emails who should default to Admin role.
@@ -37,31 +40,44 @@ def get_default_admin_user_emails() -> list[str]:
37
40
return get_default_admin_user_emails_fn ()
38
41
39
42
40
- def get_total_users_count (db_session : Session ) -> int :
43
+ def _add_live_user_count_where_clause (
44
+ select_stmt : Select [T ],
45
+ only_admin_users : bool ,
46
+ ) -> Select [T ]:
41
47
"""
42
- Returns the total number of users in the system.
43
- This is the sum of users and invited users .
48
+ Builds a SQL column expression that can be used to filter out
49
+ users who should not be included in the live user count .
44
50
"""
45
- user_count = (
46
- db_session .query (User )
47
- .filter (
48
- ~ User .email .endswith (get_api_key_email_pattern ()), # type: ignore
49
- User .role != UserRole .EXT_PERM_USER ,
50
- )
51
- .count ()
51
+ select_stmt = select_stmt .where (~ User .email .endswith (get_api_key_email_pattern ())) # type: ignore
52
+ if only_admin_users :
53
+ return select_stmt .where (User .role == UserRole .ADMIN )
54
+
55
+ return select_stmt .where (
56
+ User .role != UserRole .EXT_PERM_USER ,
52
57
)
53
- invited_users = len (get_invited_users ())
54
- return user_count + invited_users
58
+
59
+
60
+ def get_live_users_count (db_session : Session ) -> int :
61
+ """
62
+ Returns the number of users in the system.
63
+ This does NOT include invited users, "users" pulled in
64
+ from external connectors, or API keys.
65
+ """
66
+ count_stmt = func .count (User .id ) # type: ignore
67
+ select_stmt = select (count_stmt )
68
+ select_stmt_w_filters = _add_live_user_count_where_clause (select_stmt , False )
69
+ user_count = db_session .scalar (select_stmt_w_filters )
70
+ if user_count is None :
71
+ raise RuntimeError ("Was not able to fetch the user count." )
72
+ return user_count
55
73
56
74
57
75
async def get_user_count (only_admin_users : bool = False ) -> int :
58
76
async with get_async_session_context_manager () as session :
59
77
count_stmt = func .count (User .id ) # type: ignore
60
78
stmt = select (count_stmt )
61
- if only_admin_users :
62
- stmt = stmt .where (User .role == UserRole .ADMIN )
63
- result = await session .execute (stmt )
64
- user_count = result .scalar ()
79
+ stmt_w_filters = _add_live_user_count_where_clause (stmt , only_admin_users )
80
+ user_count = await session .scalar (stmt_w_filters )
65
81
if user_count is None :
66
82
raise RuntimeError ("Was not able to fetch the user count." )
67
83
return user_count
0 commit comments