Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions backend/ee/onyx/server/reporting/usage_export_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from ee.onyx.db.usage_export import write_usage_report
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
from onyx.auth.schemas import UserStatus
from onyx.configs.constants import FileOrigin
from onyx.db.users import list_users
from onyx.db.users import get_all_users
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store
Expand Down Expand Up @@ -84,15 +83,15 @@ def generate_user_report(
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
) as temp_file:
csvwriter = csv.writer(temp_file, delimiter=",")
csvwriter.writerow(["user_id", "status"])
csvwriter.writerow(["user_id", "is_active"])

users = list_users(db_session)
users = get_all_users(db_session)
for user in users:
user_skeleton = UserSkeleton(
user_id=str(user.id),
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
is_active=user.is_active,
)
csvwriter.writerow([user_skeleton.user_id, user_skeleton.status])
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])

temp_file.seek(0)
file_store.save_file(
Expand Down
4 changes: 1 addition & 3 deletions backend/ee/onyx/server/reporting/usage_export_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from pydantic import BaseModel

from onyx.auth.schemas import UserStatus


class FlowType(str, Enum):
CHAT = "chat"
Expand All @@ -22,7 +20,7 @@ class ChatMessageSkeleton(BaseModel):

class UserSkeleton(BaseModel):
user_id: str
status: UserStatus
is_active: bool


class UsageReportMetadata(BaseModel):
Expand Down
6 changes: 0 additions & 6 deletions backend/onyx/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ def is_web_login(self) -> bool:
]


class UserStatus(str, Enum):
LIVE = "live"
INVITED = "invited"
DEACTIVATED = "deactivated"


class UserRead(schemas.BaseUser[uuid.UUID]):
role: UserRole

Expand Down
103 changes: 99 additions & 4 deletions backend/onyx/db/users.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from collections.abc import Sequence
from typing import Any
from uuid import UUID

from fastapi import HTTPException
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import KeyedColumnElement

from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
Expand Down Expand Up @@ -90,8 +95,10 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
)


def list_users(
db_session: Session, email_filter_string: str = "", include_external: bool = False
def get_all_users(
db_session: Session,
email_filter_string: str | None = None,
include_external: bool = False,
) -> Sequence[User]:
"""List all users. No pagination as of now, as the # of users
is assumed to be relatively small (<< 1 million)"""
Expand All @@ -102,21 +109,109 @@ def list_users(
if not include_external:
where_clause.append(User.role != UserRole.EXT_PERM_USER)

if email_filter_string:
if email_filter_string is not None:
where_clause.append(User.email.ilike(f"%{email_filter_string}%")) # type: ignore

stmt = stmt.where(*where_clause)

return db_session.scalars(stmt).unique().all()


def _get_accepted_user_where_clause(
email_filter_string: str | None = None,
roles_filter: list[UserRole] = [],
include_external: bool = False,
is_active_filter: bool | None = None,
) -> list[ColumnElement[bool]]:
"""
Generates a SQLAlchemy where clause for filtering users based on the provided parameters.
This is used to build the filters for the function that retrieves the users for the users table in the admin panel.

Parameters:
- email_filter_string: A substring to filter user emails. Only users whose emails contain this substring will be included.
- is_active_filter: When True, only active users will be included. When False, only inactive users will be included.
- roles_filter: A list of user roles to filter by. Only users with roles in this list will be included.
- include_external: If False, external permissioned users will be excluded.

Returns:
- list: A list of conditions to be used in a SQLAlchemy query to filter users.
"""

# Access table columns directly via __table__.c to get proper SQLAlchemy column types
# This ensures type checking works correctly for SQL operations like ilike, endswith, and is_
email_col: KeyedColumnElement[Any] = User.__table__.c.email
is_active_col: KeyedColumnElement[Any] = User.__table__.c.is_active

where_clause: list[ColumnElement[bool]] = [
expression.not_(email_col.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN))
]

if not include_external:
where_clause.append(User.role != UserRole.EXT_PERM_USER)

if email_filter_string is not None:
where_clause.append(email_col.ilike(f"%{email_filter_string}%"))

if roles_filter:
where_clause.append(User.role.in_(roles_filter))

if is_active_filter is not None:
where_clause.append(is_active_col.is_(is_active_filter))

return where_clause


def get_page_of_filtered_users(
db_session: Session,
page_size: int,
page_num: int,
email_filter_string: str | None = None,
is_active_filter: bool | None = None,
roles_filter: list[UserRole] = [],
include_external: bool = False,
) -> Sequence[User]:
users_stmt = select(User)

where_clause = _get_accepted_user_where_clause(
email_filter_string=email_filter_string,
roles_filter=roles_filter,
include_external=include_external,
is_active_filter=is_active_filter,
)
# Apply pagination
users_stmt = users_stmt.offset((page_num) * page_size).limit(page_size)
# Apply filtering
users_stmt = users_stmt.where(*where_clause)

return db_session.scalars(users_stmt).unique().all()


def get_total_filtered_users_count(
db_session: Session,
email_filter_string: str | None = None,
is_active_filter: bool | None = None,
roles_filter: list[UserRole] = [],
include_external: bool = False,
) -> int:
where_clause = _get_accepted_user_where_clause(
email_filter_string=email_filter_string,
roles_filter=roles_filter,
include_external=include_external,
is_active_filter=is_active_filter,
)
total_count_stmt = select(func.count()).select_from(User)
# Apply filtering
total_count_stmt = total_count_stmt.where(*where_clause)

return db_session.scalar(total_count_stmt) or 0


def get_user_by_email(email: str, db_session: Session) -> User | None:
user = (
db_session.query(User)
.filter(func.lower(User.email) == func.lower(email))
.first()
)

return user


Expand Down
16 changes: 9 additions & 7 deletions backend/onyx/server/documents/cc_pair.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from datetime import datetime
from http import HTTPStatus

Expand Down Expand Up @@ -48,7 +47,8 @@
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.documents.models import ConnectorCredentialPairMetadata
from onyx.server.documents.models import DocumentSyncStatus
from onyx.server.documents.models import PaginatedIndexAttempts
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
Expand All @@ -64,7 +64,7 @@ def get_cc_pair_index_attempts(
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedIndexAttempts:
) -> PaginatedReturn[IndexAttemptSnapshot]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session, user, get_editable=False
)
Expand All @@ -82,10 +82,12 @@ def get_cc_pair_index_attempts(
page=page,
page_size=page_size,
)
return PaginatedIndexAttempts.from_models(
index_attempt_models=index_attempts,
page=page,
total_pages=math.ceil(total_count / page_size),
return PaginatedReturn(
items=[
IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt)
for index_attempt in index_attempts
],
total_items=total_count,
)


Expand Down
35 changes: 16 additions & 19 deletions backend/onyx/server/documents/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime
from typing import Any
from typing import Generic
from typing import TypeVar
from uuid import UUID

from pydantic import BaseModel
Expand All @@ -19,6 +21,8 @@
from onyx.db.models import IndexAttemptError as DbIndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.utils import mask_credential_dict


Expand Down Expand Up @@ -201,26 +205,19 @@ def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError":
)


class PaginatedIndexAttempts(BaseModel):
index_attempts: list[IndexAttemptSnapshot]
page: int
total_pages: int
# These are the types currently supported by the pagination hook
# More api endpoints can be refactored and be added here for use with the pagination hook
PaginatedType = TypeVar(
"PaginatedType",
IndexAttemptSnapshot,
FullUserSnapshot,
InvitedUserSnapshot,
)

@classmethod
def from_models(
cls,
index_attempt_models: list[IndexAttempt],
page: int,
total_pages: int,
) -> "PaginatedIndexAttempts":
return cls(
index_attempts=[
IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt_model)
for index_attempt_model in index_attempt_models
],
page=page,
total_pages=total_pages,
)

class PaginatedReturn(BaseModel, Generic[PaginatedType]):
items: list[PaginatedType]
total_items: int


class CCPairFullInfo(BaseModel):
Expand Down
Loading
Loading