10
10
from contextlib import contextmanager
11
11
from datetime import datetime
12
12
from typing import Any
13
- from typing import ContextManager
13
+ from typing import AsyncContextManager
14
14
15
15
import asyncpg # type: ignore
16
16
import boto3
46
46
from onyx .utils .logger import setup_logger
47
47
from shared_configs .configs import MULTI_TENANT
48
48
from shared_configs .configs import POSTGRES_DEFAULT_SCHEMA
49
+ from shared_configs .configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
49
50
from shared_configs .configs import TENANT_ID_PREFIX
50
51
from shared_configs .contextvars import CURRENT_TENANT_ID_CONTEXTVAR
51
52
from shared_configs .contextvars import get_current_tenant_id
@@ -352,18 +353,6 @@ def provide_iam_token_async(
352
353
return _ASYNC_ENGINE
353
354
354
355
355
- # Listen for events on the synchronous Session class
356
- @event .listens_for (Session , "after_begin" )
357
- def _set_search_path (
358
- session : Session , transaction : Any , connection : Any , * args : Any , ** kwargs : Any
359
- ) -> None :
360
- """Every time a new transaction is started,
361
- set the search_path from the session's info."""
362
- tenant_id = session .info .get ("tenant_id" )
363
- if tenant_id :
364
- connection .exec_driver_sql (f'SET search_path = "{ tenant_id } "' )
365
-
366
-
367
356
engine = get_sqlalchemy_async_engine ()
368
357
AsyncSessionLocal = sessionmaker ( # type: ignore
369
358
bind = engine ,
@@ -372,33 +361,6 @@ def _set_search_path(
372
361
)
373
362
374
363
375
- @asynccontextmanager
376
- async def get_async_session_with_tenant (
377
- tenant_id : str | None = None ,
378
- ) -> AsyncGenerator [AsyncSession , None ]:
379
- if tenant_id is None :
380
- tenant_id = get_current_tenant_id ()
381
-
382
- if not is_valid_schema_name (tenant_id ):
383
- logger .error (f"Invalid tenant ID: { tenant_id } " )
384
- raise ValueError ("Invalid tenant ID" )
385
-
386
- async with AsyncSessionLocal () as session :
387
- session .sync_session .info ["tenant_id" ] = tenant_id
388
-
389
- if POSTGRES_IDLE_SESSIONS_TIMEOUT :
390
- await session .execute (
391
- text (
392
- f"SET idle_in_transaction_session_timeout = { POSTGRES_IDLE_SESSIONS_TIMEOUT } "
393
- )
394
- )
395
-
396
- try :
397
- yield session
398
- finally :
399
- pass
400
-
401
-
402
364
@contextmanager
403
365
def get_session_with_current_tenant () -> Generator [Session , None , None ]:
404
366
tenant_id = get_current_tenant_id ()
@@ -416,17 +378,24 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
416
378
CURRENT_TENANT_ID_CONTEXTVAR .reset (token )
417
379
418
380
381
+ def _set_search_path_on_checkout__listener (
382
+ dbapi_conn : Any , connection_record : Any , connection_proxy : Any
383
+ ) -> None :
384
+ """Listener to make sure we ALWAYS set the search path on checkout."""
385
+ tenant_id = get_current_tenant_id ()
386
+ if tenant_id and is_valid_schema_name (tenant_id ):
387
+ with dbapi_conn .cursor () as cursor :
388
+ cursor .execute (f'SET search_path TO "{ tenant_id } "' )
389
+
390
+
419
391
@contextmanager
420
392
def get_session_with_tenant (* , tenant_id : str ) -> Generator [Session , None , None ]:
421
393
"""
422
394
Generate a database session for a specific tenant.
423
395
"""
424
- if tenant_id is None :
425
- tenant_id = POSTGRES_DEFAULT_SCHEMA
426
-
427
396
engine = get_sqlalchemy_engine ()
428
397
429
- event .listen (engine , "checkout" , set_search_path_on_checkout )
398
+ event .listen (engine , "checkout" , _set_search_path_on_checkout__listener )
430
399
431
400
if not is_valid_schema_name (tenant_id ):
432
401
raise HTTPException (status_code = 400 , detail = "Invalid tenant ID" )
@@ -457,57 +426,84 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
457
426
cursor .close ()
458
427
459
428
460
- def set_search_path_on_checkout (
461
- dbapi_conn : Any , connection_record : Any , connection_proxy : Any
462
- ) -> None :
429
+ def get_session () -> Generator [Session , None , None ]:
430
+ """For use w/ Depends for FastAPI endpoints.
431
+
432
+ Has some additional validation, and likely should be merged
433
+ with get_session_context_manager in the future."""
463
434
tenant_id = get_current_tenant_id ()
464
- if tenant_id and is_valid_schema_name (tenant_id ):
465
- with dbapi_conn .cursor () as cursor :
466
- cursor .execute (f'SET search_path TO "{ tenant_id } "' )
435
+ if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT :
436
+ raise BasicAuthenticationError (detail = "User must authenticate" )
437
+
438
+ if not is_valid_schema_name (tenant_id ):
439
+ raise HTTPException (status_code = 400 , detail = "Invalid tenant ID" )
467
440
441
+ with get_session_context_manager () as db_session :
442
+ yield db_session
468
443
469
- def get_session_generator_with_tenant () -> Generator [Session , None , None ]:
444
+
445
+ @contextlib .contextmanager
446
+ def get_session_context_manager () -> Generator [Session , None , None ]:
447
+ """Context manager for database sessions."""
470
448
tenant_id = get_current_tenant_id ()
471
449
with get_session_with_tenant (tenant_id = tenant_id ) as session :
472
450
yield session
473
451
474
452
475
- def get_session () -> Generator [Session , None , None ]:
476
- tenant_id = get_current_tenant_id ()
477
- if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT :
478
- raise BasicAuthenticationError (detail = "User must authenticate" )
453
+ def _set_search_path_on_transaction__listener (
454
+ session : Session , transaction : Any , connection : Any , * args : Any , ** kwargs : Any
455
+ ) -> None :
456
+ """Every time a new transaction is started,
457
+ set the search_path from the session's info."""
458
+ tenant_id = session .info .get ("tenant_id" )
459
+ if tenant_id :
460
+ connection .exec_driver_sql (f'SET search_path = "{ tenant_id } "' )
479
461
480
- engine = get_sqlalchemy_engine ()
481
462
482
- with Session (engine , expire_on_commit = False ) as session :
483
- if MULTI_TENANT :
484
- if not is_valid_schema_name (tenant_id ):
485
- raise HTTPException (status_code = 400 , detail = "Invalid tenant ID" )
486
- session .execute (text (f'SET search_path = "{ tenant_id } "' ))
487
- yield session
463
+ async def get_async_session (
464
+ tenant_id : str | None = None ,
465
+ ) -> AsyncGenerator [AsyncSession , None ]:
466
+ """For use w/ Depends for *async* FastAPI endpoints.
488
467
468
+ For standard `async with ... as ...` use, use get_async_session_context_manager.
469
+ """
470
+
471
+ if tenant_id is None :
472
+ tenant_id = get_current_tenant_id ()
489
473
490
- async def get_async_session () -> AsyncGenerator [AsyncSession , None ]:
491
- tenant_id = get_current_tenant_id ()
492
474
engine = get_sqlalchemy_async_engine ()
475
+
493
476
async with AsyncSession (engine , expire_on_commit = False ) as async_session :
494
- if MULTI_TENANT :
495
- if not is_valid_schema_name (tenant_id ):
496
- raise HTTPException (status_code = 400 , detail = "Invalid tenant ID" )
497
- await async_session .execute (text (f'SET search_path = "{ tenant_id } "' ))
498
- yield async_session
477
+ # set the search path on sync session as well to be extra safe
478
+ event .listen (
479
+ async_session .sync_session ,
480
+ "after_begin" ,
481
+ _set_search_path_on_transaction__listener ,
482
+ )
499
483
484
+ if POSTGRES_IDLE_SESSIONS_TIMEOUT :
485
+ await async_session .execute (
486
+ text (
487
+ f"SET idle_in_transaction_session_timeout = { POSTGRES_IDLE_SESSIONS_TIMEOUT } "
488
+ )
489
+ )
500
490
501
- def get_session_context_manager () -> ContextManager [Session ]:
502
- """Context manager for database sessions."""
503
- return contextlib .contextmanager (get_session_generator_with_tenant )()
491
+ if not is_valid_schema_name (tenant_id ):
492
+ raise HTTPException (status_code = 400 , detail = "Invalid tenant ID" )
504
493
494
+ # don't need to set the search path for self-hosted + default schema
495
+ # this is also true for sync sessions, but just not adding it there for
496
+ # now to simplify / not change too much
497
+ if MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE :
498
+ await async_session .execute (text (f'SET search_path = "{ tenant_id } "' ))
499
+
500
+ yield async_session
505
501
506
- def get_session_factory () -> sessionmaker [ Session ]:
507
- global SessionFactory
508
- if SessionFactory is None :
509
- SessionFactory = sessionmaker ( bind = get_sqlalchemy_engine ())
510
- return SessionFactory
502
+
503
+ def get_async_session_context_manager (
504
+ tenant_id : str | None = None ,
505
+ ) -> AsyncContextManager [ AsyncSession ]:
506
+ return asynccontextmanager ( get_async_session )( tenant_id )
511
507
512
508
513
509
async def warm_up_connections (
0 commit comments