Skip to content

Commit f6d8f5c

Browse files
authored
Migrate tenant upgrades to data plane (#3051)
* add provisioning on data plane * functional but scrappy * minor cleanup * minor clean up * k * simplify * update provisioning * improve import logic * ensure proper conditional * minor pydantic update * minor config update * nit
1 parent 1fb4cdf commit f6d8f5c

File tree

16 files changed

+356
-226
lines changed

16 files changed

+356
-226
lines changed

backend/danswer/auth/users.py

Lines changed: 12 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
from httpx_oauth.oauth2 import BaseOAuth2
4949
from httpx_oauth.oauth2 import OAuth2Token
5050
from pydantic import BaseModel
51-
from sqlalchemy import select
5251
from sqlalchemy import text
5352
from sqlalchemy.orm import attributes
5453
from sqlalchemy.orm import Session
@@ -83,21 +82,19 @@
8382
from danswer.db.engine import get_async_session_with_tenant
8483
from danswer.db.engine import get_session
8584
from danswer.db.engine import get_session_with_tenant
86-
from danswer.db.engine import get_sqlalchemy_engine
8785
from danswer.db.models import AccessToken
8886
from danswer.db.models import OAuthAccount
8987
from danswer.db.models import User
90-
from danswer.db.models import UserTenantMapping
9188
from danswer.db.users import get_user_by_email
9289
from danswer.utils.logger import setup_logger
9390
from danswer.utils.telemetry import optional_telemetry
9491
from danswer.utils.telemetry import RecordType
9592
from danswer.utils.variable_functionality import fetch_versioned_implementation
93+
from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id
94+
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
9695
from shared_configs.configs import MULTI_TENANT
97-
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
9896
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
9997

100-
10198
logger = setup_logger()
10299

103100

@@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None:
190187
)
191188

192189

193-
def get_tenant_id_for_email(email: str) -> str:
194-
if not MULTI_TENANT:
195-
return POSTGRES_DEFAULT_SCHEMA
196-
# Implement logic to get tenant_id from the mapping table
197-
with Session(get_sqlalchemy_engine()) as db_session:
198-
result = db_session.execute(
199-
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
200-
)
201-
tenant_id = result.scalar_one_or_none()
202-
if tenant_id is None:
203-
raise exceptions.UserNotExists()
204-
return tenant_id
205-
206-
207190
def send_user_verification_email(
208191
user_email: str,
209192
token: str,
@@ -238,19 +221,7 @@ async def create(
238221
safe: bool = False,
239222
request: Optional[Request] = None,
240223
) -> User:
241-
try:
242-
tenant_id = (
243-
get_tenant_id_for_email(user_create.email)
244-
if MULTI_TENANT
245-
else POSTGRES_DEFAULT_SCHEMA
246-
)
247-
except exceptions.UserNotExists:
248-
raise HTTPException(status_code=401, detail="User not found")
249-
250-
if not tenant_id:
251-
raise HTTPException(
252-
status_code=401, detail="User does not belong to an organization"
253-
)
224+
tenant_id = await get_or_create_tenant_id(user_create.email)
254225

255226
async with get_async_session_with_tenant(tenant_id) as db_session:
256227
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -271,7 +242,7 @@ async def create(
271242
user_create.role = UserRole.ADMIN
272243
else:
273244
user_create.role = UserRole.BASIC
274-
user = None
245+
275246
try:
276247
user = await super().create(user_create, safe=safe, request=request) # type: ignore
277248
except exceptions.UserAlreadyExists:
@@ -292,7 +263,9 @@ async def create(
292263
else:
293264
raise exceptions.UserAlreadyExists()
294265

295-
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
266+
finally:
267+
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
268+
296269
return user
297270

298271
async def oauth_callback(
@@ -308,19 +281,12 @@ async def oauth_callback(
308281
associate_by_email: bool = False,
309282
is_verified_by_default: bool = False,
310283
) -> models.UOAP:
311-
# Get tenant_id from mapping table
312-
try:
313-
tenant_id = (
314-
get_tenant_id_for_email(account_email)
315-
if MULTI_TENANT
316-
else POSTGRES_DEFAULT_SCHEMA
317-
)
318-
except exceptions.UserNotExists:
319-
raise HTTPException(status_code=401, detail="User not found")
284+
tenant_id = await get_or_create_tenant_id(account_email)
320285

321286
if not tenant_id:
322287
raise HTTPException(status_code=401, detail="User not found")
323288

289+
# Proceed with the tenant context
324290
token = None
325291
async with get_async_session_with_tenant(tenant_id) as db_session:
326292
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -371,9 +337,9 @@ async def oauth_callback(
371337
# Explicitly set the Postgres schema for this session to ensure
372338
# OAuth account creation happens in the correct tenant schema
373339
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
374-
user = await self.user_db.add_oauth_account(
375-
user, oauth_account_dict
376-
)
340+
341+
# Add OAuth account
342+
await self.user_db.add_oauth_account(user, oauth_account_dict)
377343
await self.on_after_register(user, request)
378344

379345
else:

backend/danswer/background/celery/apps/beat.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ def _update_tenant_tasks(self) -> None:
119119
else:
120120
logger.info("Schedule is up to date, no changes needed")
121121

122-
except (AttributeError, KeyError) as e:
123-
logger.exception(f"Failed to process task configuration: {str(e)}")
124-
except Exception as e:
125-
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
122+
except (AttributeError, KeyError):
123+
logger.exception("Failed to process task configuration")
124+
except Exception:
125+
logger.exception("Unexpected error updating tenant tasks")
126126

127127
def _should_update_schedule(
128128
self, current_schedule: dict, new_schedule: dict

backend/danswer/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,14 @@ def get_application() -> FastAPI:
277277
prefix="/auth",
278278
tags=["auth"],
279279
)
280+
280281
include_router_with_global_prefix_prepended(
281282
application,
282283
fastapi_users.get_register_router(UserRead, UserCreate),
283284
prefix="/auth",
284285
tags=["auth"],
285286
)
287+
286288
include_router_with_global_prefix_prepended(
287289
application,
288290
fastapi_users.get_reset_password_router(),

backend/danswer/server/manage/users.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from danswer.auth.users import current_admin_user
3131
from danswer.auth.users import current_curator_or_admin_user
3232
from danswer.auth.users import current_user
33-
from danswer.auth.users import get_tenant_id_for_email
3433
from danswer.auth.users import optional_user
3534
from danswer.configs.app_configs import AUTH_TYPE
3635
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
@@ -66,7 +65,8 @@
6665
from ee.danswer.db.user_group import remove_curator_status__no_commit
6766
from ee.danswer.server.tenants.billing import register_tenant_users
6867
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
69-
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
68+
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
69+
from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant
7070
from shared_configs.configs import MULTI_TENANT
7171

7272
logger = setup_logger()

backend/danswer/server/query_and_chat/chat_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def stream_generator() -> Generator[str, None, None]:
359359
yield json.dumps(packet) if isinstance(packet, dict) else packet
360360

361361
except Exception as e:
362-
logger.exception(f"Error in chat message streaming: {e}")
362+
logger.exception("Error in chat message streaming")
363363
yield json.dumps({"error": str(e)})
364364

365365
finally:

backend/danswer/server/query_and_chat/query_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def stream_generator() -> Generator[str, None, None]:
279279
):
280280
yield json.dumps(packet) if isinstance(packet, dict) else packet
281281
except Exception as e:
282-
logger.exception(f"Error in search answer streaming: {e}")
282+
logger.exception("Error in search answer streaming")
283283
yield json.dumps({"error": str(e)})
284284

285285
return StreamingResponse(stream_generator(), media_type="application/json")

backend/ee/danswer/server/tenants/api.py

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,23 @@
77
from danswer.auth.users import auth_backend
88
from danswer.auth.users import current_admin_user
99
from danswer.auth.users import get_jwt_strategy
10-
from danswer.auth.users import get_tenant_id_for_email
1110
from danswer.auth.users import User
1211
from danswer.configs.app_configs import WEB_DOMAIN
1312
from danswer.db.engine import get_session_with_tenant
1413
from danswer.db.notification import create_notification
1514
from danswer.db.users import get_user_by_email
1615
from danswer.server.settings.store import load_settings
1716
from danswer.server.settings.store import store_settings
18-
from danswer.setup import setup_danswer
1917
from danswer.utils.logger import setup_logger
2018
from ee.danswer.auth.users import current_cloud_superuser
2119
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
2220
from ee.danswer.server.tenants.access import control_plane_dep
2321
from ee.danswer.server.tenants.billing import fetch_billing_information
2422
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
2523
from ee.danswer.server.tenants.models import BillingInformation
26-
from ee.danswer.server.tenants.models import CreateTenantRequest
2724
from ee.danswer.server.tenants.models import ImpersonateRequest
2825
from ee.danswer.server.tenants.models import ProductGatingRequest
29-
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
30-
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
31-
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
32-
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
33-
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
34-
from shared_configs.configs import MULTI_TENANT
26+
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
3527
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
3628

3729
stripe.api_key = STRIPE_SECRET_KEY
@@ -40,52 +32,6 @@
4032
router = APIRouter(prefix="/tenants")
4133

4234

43-
@router.post("/create")
44-
def create_tenant(
45-
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
46-
) -> dict[str, str]:
47-
if not MULTI_TENANT:
48-
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
49-
50-
tenant_id = create_tenant_request.tenant_id
51-
email = create_tenant_request.initial_admin_email
52-
token = None
53-
54-
if user_owns_a_tenant(email):
55-
raise HTTPException(
56-
status_code=409, detail="User already belongs to an organization"
57-
)
58-
59-
try:
60-
if not ensure_schema_exists(tenant_id):
61-
logger.info(f"Created schema for tenant {tenant_id}")
62-
else:
63-
logger.info(f"Schema already exists for tenant {tenant_id}")
64-
65-
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
66-
run_alembic_migrations(tenant_id)
67-
68-
with get_session_with_tenant(tenant_id) as db_session:
69-
setup_danswer(db_session, tenant_id)
70-
71-
configure_default_api_keys(db_session)
72-
73-
add_users_to_tenant([email], tenant_id)
74-
75-
return {
76-
"status": "success",
77-
"message": f"Tenant {tenant_id} created successfully",
78-
}
79-
except Exception as e:
80-
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
81-
raise HTTPException(
82-
status_code=500, detail=f"Failed to create tenant: {str(e)}"
83-
)
84-
finally:
85-
if token is not None:
86-
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
87-
88-
8935
@router.post("/product-gating")
9036
def gate_product(
9137
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)

backend/ee/danswer/server/tenants/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,8 @@ class CheckoutSessionCreationResponse(BaseModel):
3333

3434
class ImpersonateRequest(BaseModel):
3535
email: str
36+
37+
38+
class TenantCreationPayload(BaseModel):
39+
tenant_id: str
40+
email: str

0 commit comments

Comments
 (0)