Skip to content

Commit b5f0095

Browse files
committed
Fix non-default schema in KV store (#4655)
* Fix non-default schema in KV store
* Fix custom schema
1 parent efab12a commit b5f0095

File tree

6 files changed

+100
-127
lines changed

6 files changed

+100
-127
lines changed

backend/ee/onyx/server/saml.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from onyx.db.auth import get_user_count
2929
from onyx.db.auth import get_user_db
3030
from onyx.db.engine import get_async_session
31+
from onyx.db.engine import get_async_session_context_manager
3132
from onyx.db.engine import get_session
3233
from onyx.db.models import User
3334
from onyx.utils.logger import setup_logger
@@ -39,13 +40,10 @@
3940

4041
async def upsert_saml_user(email: str) -> User:
4142
logger.debug(f"Attempting to upsert SAML user with email: {email}")
42-
get_async_session_context = contextlib.asynccontextmanager(
43-
get_async_session
44-
) # type:ignore
4543
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
4644
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
4745

48-
async with get_async_session_context() as session:
46+
async with get_async_session_context_manager() as session:
4947
async with get_user_db_context(session) as user_db:
5048
async with get_user_manager_context(user_db) as user_manager:
5149
try:

backend/onyx/auth/users.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
from onyx.db.auth import get_user_db
9393
from onyx.db.auth import SQLAlchemyUserAdminDB
9494
from onyx.db.engine import get_async_session
95-
from onyx.db.engine import get_async_session_with_tenant
95+
from onyx.db.engine import get_async_session_context_manager
9696
from onyx.db.engine import get_session_with_tenant
9797
from onyx.db.models import AccessToken
9898
from onyx.db.models import OAuthAccount
@@ -252,7 +252,7 @@ async def get_by_email(self, user_email: str) -> User:
252252
tenant_id = fetch_ee_implementation_or_noop(
253253
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
254254
)(user_email)
255-
async with get_async_session_with_tenant(tenant_id) as db_session:
255+
async with get_async_session_context_manager(tenant_id) as db_session:
256256
if MULTI_TENANT:
257257
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
258258
db_session, User, OAuthAccount
@@ -295,7 +295,7 @@ async def create(
295295
)
296296
user: User
297297

298-
async with get_async_session_with_tenant(tenant_id) as db_session:
298+
async with get_async_session_context_manager(tenant_id) as db_session:
299299
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
300300
verify_email_is_invited(user_create.email)
301301
verify_email_domain(user_create.email)
@@ -395,7 +395,7 @@ async def oauth_callback(
395395

396396
# Proceed with the tenant context
397397
token = None
398-
async with get_async_session_with_tenant(tenant_id) as db_session:
398+
async with get_async_session_context_manager(tenant_id) as db_session:
399399
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
400400

401401
verify_email_in_whitelist(account_email, tenant_id)
@@ -634,7 +634,7 @@ async def authenticate(
634634
return None
635635

636636
# Create a tenant-specific session
637-
async with get_async_session_with_tenant(tenant_id) as tenant_session:
637+
async with get_async_session_context_manager(tenant_id) as tenant_session:
638638
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
639639
tenant_session, User
640640
)

backend/onyx/db/auth.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from onyx.auth.schemas import UserRole
1818
from onyx.db.api_key import get_api_key_email_pattern
1919
from onyx.db.engine import get_async_session
20-
from onyx.db.engine import get_async_session_with_tenant
20+
from onyx.db.engine import get_async_session_context_manager
2121
from onyx.db.models import AccessToken
2222
from onyx.db.models import OAuthAccount
2323
from onyx.db.models import User
@@ -55,8 +55,9 @@ def get_total_users_count(db_session: Session) -> int:
5555

5656

5757
async def get_user_count(only_admin_users: bool = False) -> int:
58-
async with get_async_session_with_tenant() as session:
59-
stmt = select(func.count(User.id))
58+
async with get_async_session_context_manager() as session:
59+
count_stmt = func.count(User.id) # type: ignore
60+
stmt = select(count_stmt)
6061
if only_admin_users:
6162
stmt = stmt.where(User.role == UserRole.ADMIN)
6263
result = await session.execute(stmt)

backend/onyx/db/engine.py

Lines changed: 73 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import contextmanager
1111
from datetime import datetime
1212
from typing import Any
13-
from typing import ContextManager
13+
from typing import AsyncContextManager
1414

1515
import asyncpg # type: ignore
1616
import boto3
@@ -46,6 +46,7 @@
4646
from onyx.utils.logger import setup_logger
4747
from shared_configs.configs import MULTI_TENANT
4848
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
49+
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
4950
from shared_configs.configs import TENANT_ID_PREFIX
5051
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
5152
from shared_configs.contextvars import get_current_tenant_id
@@ -352,18 +353,6 @@ def provide_iam_token_async(
352353
return _ASYNC_ENGINE
353354

354355

355-
# Listen for events on the synchronous Session class
356-
@event.listens_for(Session, "after_begin")
357-
def _set_search_path(
358-
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
359-
) -> None:
360-
"""Every time a new transaction is started,
361-
set the search_path from the session's info."""
362-
tenant_id = session.info.get("tenant_id")
363-
if tenant_id:
364-
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
365-
366-
367356
engine = get_sqlalchemy_async_engine()
368357
AsyncSessionLocal = sessionmaker( # type: ignore
369358
bind=engine,
@@ -372,33 +361,6 @@ def _set_search_path(
372361
)
373362

374363

375-
@asynccontextmanager
376-
async def get_async_session_with_tenant(
377-
tenant_id: str | None = None,
378-
) -> AsyncGenerator[AsyncSession, None]:
379-
if tenant_id is None:
380-
tenant_id = get_current_tenant_id()
381-
382-
if not is_valid_schema_name(tenant_id):
383-
logger.error(f"Invalid tenant ID: {tenant_id}")
384-
raise ValueError("Invalid tenant ID")
385-
386-
async with AsyncSessionLocal() as session:
387-
session.sync_session.info["tenant_id"] = tenant_id
388-
389-
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
390-
await session.execute(
391-
text(
392-
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
393-
)
394-
)
395-
396-
try:
397-
yield session
398-
finally:
399-
pass
400-
401-
402364
@contextmanager
403365
def get_session_with_current_tenant() -> Generator[Session, None, None]:
404366
tenant_id = get_current_tenant_id()
@@ -416,17 +378,24 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
416378
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
417379

418380

381+
def _set_search_path_on_checkout__listener(
382+
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
383+
) -> None:
384+
"""Listener to make sure we ALWAYS set the search path on checkout."""
385+
tenant_id = get_current_tenant_id()
386+
if tenant_id and is_valid_schema_name(tenant_id):
387+
with dbapi_conn.cursor() as cursor:
388+
cursor.execute(f'SET search_path TO "{tenant_id}"')
389+
390+
419391
@contextmanager
420392
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
421393
"""
422394
Generate a database session for a specific tenant.
423395
"""
424-
if tenant_id is None:
425-
tenant_id = POSTGRES_DEFAULT_SCHEMA
426-
427396
engine = get_sqlalchemy_engine()
428397

429-
event.listen(engine, "checkout", set_search_path_on_checkout)
398+
event.listen(engine, "checkout", _set_search_path_on_checkout__listener)
430399

431400
if not is_valid_schema_name(tenant_id):
432401
raise HTTPException(status_code=400, detail="Invalid tenant ID")
@@ -457,57 +426,84 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
457426
cursor.close()
458427

459428

460-
def set_search_path_on_checkout(
461-
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
462-
) -> None:
429+
def get_session() -> Generator[Session, None, None]:
430+
"""For use w/ Depends for FastAPI endpoints.
431+
432+
Has some additional validation, and likely should be merged
433+
with get_session_context_manager in the future."""
463434
tenant_id = get_current_tenant_id()
464-
if tenant_id and is_valid_schema_name(tenant_id):
465-
with dbapi_conn.cursor() as cursor:
466-
cursor.execute(f'SET search_path TO "{tenant_id}"')
435+
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
436+
raise BasicAuthenticationError(detail="User must authenticate")
437+
438+
if not is_valid_schema_name(tenant_id):
439+
raise HTTPException(status_code=400, detail="Invalid tenant ID")
467440

441+
with get_session_context_manager() as db_session:
442+
yield db_session
468443

469-
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
444+
445+
@contextlib.contextmanager
446+
def get_session_context_manager() -> Generator[Session, None, None]:
447+
"""Context manager for database sessions."""
470448
tenant_id = get_current_tenant_id()
471449
with get_session_with_tenant(tenant_id=tenant_id) as session:
472450
yield session
473451

474452

475-
def get_session() -> Generator[Session, None, None]:
476-
tenant_id = get_current_tenant_id()
477-
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
478-
raise BasicAuthenticationError(detail="User must authenticate")
453+
def _set_search_path_on_transaction__listener(
454+
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
455+
) -> None:
456+
"""Every time a new transaction is started,
457+
set the search_path from the session's info."""
458+
tenant_id = session.info.get("tenant_id")
459+
if tenant_id:
460+
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
479461

480-
engine = get_sqlalchemy_engine()
481462

482-
with Session(engine, expire_on_commit=False) as session:
483-
if MULTI_TENANT:
484-
if not is_valid_schema_name(tenant_id):
485-
raise HTTPException(status_code=400, detail="Invalid tenant ID")
486-
session.execute(text(f'SET search_path = "{tenant_id}"'))
487-
yield session
463+
async def get_async_session(
464+
tenant_id: str | None = None,
465+
) -> AsyncGenerator[AsyncSession, None]:
466+
"""For use w/ Depends for *async* FastAPI endpoints.
488467
468+
For standard `async with ... as ...` use, use get_async_session_context_manager.
469+
"""
470+
471+
if tenant_id is None:
472+
tenant_id = get_current_tenant_id()
489473

490-
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
491-
tenant_id = get_current_tenant_id()
492474
engine = get_sqlalchemy_async_engine()
475+
493476
async with AsyncSession(engine, expire_on_commit=False) as async_session:
494-
if MULTI_TENANT:
495-
if not is_valid_schema_name(tenant_id):
496-
raise HTTPException(status_code=400, detail="Invalid tenant ID")
497-
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
498-
yield async_session
477+
# set the search path on sync session as well to be extra safe
478+
event.listen(
479+
async_session.sync_session,
480+
"after_begin",
481+
_set_search_path_on_transaction__listener,
482+
)
499483

484+
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
485+
await async_session.execute(
486+
text(
487+
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
488+
)
489+
)
500490

501-
def get_session_context_manager() -> ContextManager[Session]:
502-
"""Context manager for database sessions."""
503-
return contextlib.contextmanager(get_session_generator_with_tenant)()
491+
if not is_valid_schema_name(tenant_id):
492+
raise HTTPException(status_code=400, detail="Invalid tenant ID")
504493

494+
# don't need to set the search path for self-hosted + default schema
495+
# this is also true for sync sessions, but just not adding it there for
496+
# now to simplify / not change too much
497+
if MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
498+
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
499+
500+
yield async_session
505501

506-
def get_session_factory() -> sessionmaker[Session]:
507-
global SessionFactory
508-
if SessionFactory is None:
509-
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
510-
return SessionFactory
502+
503+
def get_async_session_context_manager(
504+
tenant_id: str | None = None,
505+
) -> AsyncContextManager[AsyncSession]:
506+
return asynccontextmanager(get_async_session)(tenant_id)
511507

512508

513509
async def warm_up_connections(

0 commit comments

Comments
 (0)