diff --git a/backend/alembic/README.md b/backend/alembic/README.md index b7d294dd435..f24d6fd0be9 100644 --- a/backend/alembic/README.md +++ b/backend/alembic/README.md @@ -20,3 +20,44 @@ To run all un-applied migrations: To undo migrations: `alembic downgrade -X` where X is the number of migrations you want to undo from the current state + +### Multi-tenant migrations + +For multi-tenant deployments, you can use additional options: + +**Upgrade all tenants:** +```bash +alembic -x upgrade_all_tenants=true upgrade head +``` + +**Upgrade specific schemas:** +```bash +# Single schema +alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012 upgrade head + +# Multiple schemas (comma-separated) +alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012,public,another_tenant upgrade head +``` + +**Upgrade tenants within an alphabetical range:** +```bash +# Upgrade tenants 100-200 when sorted alphabetically (positions 100 to 200) +alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x tenant_range_end=200 upgrade head + +# Upgrade tenants starting from position 1000 alphabetically +alembic -x upgrade_all_tenants=true -x tenant_range_start=1000 upgrade head + +# Upgrade first 500 tenants alphabetically +alembic -x upgrade_all_tenants=true -x tenant_range_end=500 upgrade head +``` + +**Continue on error (for batch operations):** +```bash +alembic -x upgrade_all_tenants=true -x continue=true upgrade head +``` + +The tenant range filtering works by: +1. Sorting tenant IDs alphabetically +2. Using 1-based position numbers (1st, 2nd, 3rd tenant, etc.) +3. Filtering to the specified range of positions +4. Non-tenant schemas (like 'public') are always included diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 1106673a87d..d33d5c37a37 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -21,7 +21,11 @@ from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.sql.schema import SchemaItem from onyx.configs.constants import SSL_CERT_FILE -from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA +from shared_configs.configs import ( + MULTI_TENANT, + POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE, + TENANT_ID_PREFIX, +) from onyx.db.models import Base from celery.backends.database.session import ResultModelBase # type: ignore from onyx.db.engine.sql_engine import SqlEngine @@ -69,15 +73,67 @@ def include_object( return True -def get_schema_options() -> tuple[str, bool, bool, bool]: +def filter_tenants_by_range( + tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None +) -> list[str]: + """ + Filter tenant IDs by alphabetical position range. + + Args: + tenant_ids: List of tenant IDs to filter + start_range: Starting position in alphabetically sorted list (1-based, inclusive) + end_range: Ending position in alphabetically sorted list (1-based, inclusive) + + Returns: + Filtered list of tenant IDs in their original order + """ + if start_range is None and end_range is None: + return tenant_ids + + # Separate tenant IDs from non-tenant schemas + tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)] + non_tenant_schemas = [ + tid for tid in tenant_ids if not tid.startswith(TENANT_ID_PREFIX) + ] + + # Sort tenant schemas alphabetically. + # NOTE: can cause missed schemas if a schema is created in between workers + # fetching of all tenant IDs. We accept this risk for now. Just re-running + # the migration will fix the issue. + sorted_tenant_schemas = sorted(tenant_schemas) + + # Apply range filtering (0-based indexing) + start_idx = start_range if start_range is not None else 0 + end_idx = end_range if end_range is not None else len(sorted_tenant_schemas) + + # Ensure indices are within bounds + start_idx = max(0, start_idx) + end_idx = min(len(sorted_tenant_schemas), end_idx) + + # Get the filtered tenant schemas + filtered_tenant_schemas = sorted_tenant_schemas[start_idx:end_idx] + + # Combine with non-tenant schemas and preserve original order + filtered_tenants = [] + for tenant_id in tenant_ids: + if tenant_id in filtered_tenant_schemas or tenant_id in non_tenant_schemas: + filtered_tenants.append(tenant_id) + + return filtered_tenants + + +def get_schema_options() -> ( + tuple[bool, bool, bool, int | None, int | None, list[str] | None] +): x_args_raw = context.get_x_argument() x_args = {} for arg in x_args_raw: - for pair in arg.split(","): - if "=" in pair: - key, value = pair.split("=", 1) - x_args[key.strip()] = value.strip() - schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA) + if "=" in arg: + key, value = arg.split("=", 1) + x_args[key.strip()] = value.strip() + else: + raise ValueError(f"Invalid argument: {arg}") + create_schema = x_args.get("create_schema", "true").lower() == "true" upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true" @@ -85,17 +141,81 @@ def get_schema_options() -> tuple[str, bool, bool, bool]: # only applies to online migrations continue_on_error = x_args.get("continue", "false").lower() == "true" - if ( - MULTI_TENANT - and schema_name == POSTGRES_DEFAULT_SCHEMA - and not upgrade_all_tenants - ): + # Tenant range filtering + tenant_range_start = None + tenant_range_end = None + + if "tenant_range_start" in x_args: + try: + tenant_range_start = int(x_args["tenant_range_start"]) + except ValueError: + raise ValueError( + f"Invalid tenant_range_start value: {x_args['tenant_range_start']}. Must be an integer." + ) + + if "tenant_range_end" in x_args: + try: + tenant_range_end = int(x_args["tenant_range_end"]) + except ValueError: + raise ValueError( + f"Invalid tenant_range_end value: {x_args['tenant_range_end']}. Must be an integer." + ) + + # Validate range + if tenant_range_start is not None and tenant_range_end is not None: + if tenant_range_start > tenant_range_end: + raise ValueError( + f"tenant_range_start ({tenant_range_start}) cannot be greater than tenant_range_end ({tenant_range_end})" + ) + + # Specific schema names filtering (replaces both schema_name and the old tenant_ids approach) + schemas = None + if "schemas" in x_args: + schema_names_str = x_args["schemas"].strip() + if schema_names_str: + # Split by comma and strip whitespace + schemas = [ + name.strip() for name in schema_names_str.split(",") if name.strip() + ] + if schemas: + logger.info(f"Specific schema names specified: {schemas}") + + # Validate that only one method is used at a time + range_filtering = tenant_range_start is not None or tenant_range_end is not None + specific_filtering = schemas is not None and len(schemas) > 0 + + if range_filtering and specific_filtering: raise ValueError( - "Cannot run default migrations in public schema when multi-tenancy is enabled. " - "Please specify a tenant-specific schema." + "Cannot use both tenant range filtering (tenant_range_start/tenant_range_end) " + "and specific schema filtering (schemas) at the same time. " + "Please use only one filtering method." ) - return schema_name, create_schema, upgrade_all_tenants, continue_on_error + if upgrade_all_tenants and specific_filtering: + raise ValueError( + "Cannot use both upgrade_all_tenants=true and schemas at the same time. " + "Use either upgrade_all_tenants=true for all tenants, or schemas for specific schemas." + ) + + # If any filtering parameters are specified, we're not doing the default single schema migration + if range_filtering: + upgrade_all_tenants = True + + # Validate multi-tenant requirements + if MULTI_TENANT and not upgrade_all_tenants and not specific_filtering: + raise ValueError( + "In multi-tenant mode, you must specify either upgrade_all_tenants=true " + "or provide schemas. Cannot run default migration." + ) + + return ( + create_schema, + upgrade_all_tenants, + continue_on_error, + tenant_range_start, + tenant_range_end, + schemas, + ) def do_run_migrations( @@ -142,12 +262,17 @@ def provide_iam_token_for_alembic( async def run_async_migrations() -> None: ( - schema_name, create_schema, upgrade_all_tenants, continue_on_error, + tenant_range_start, + tenant_range_end, + schemas, ) = get_schema_options() + if not schemas and not MULTI_TENANT: + schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE] + # without init_engine, subsequent engine calls fail hard intentionally SqlEngine.init_engine(pool_size=20, max_overflow=5) @@ -164,12 +289,50 @@ def event_provide_iam_token_for_alembic( ) -> None: provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams) - if upgrade_all_tenants: + if schemas: + # Use specific schema names directly without fetching all tenants + logger.info(f"Migrating specific schema names: {schemas}") + + i_schema = 0 + num_schemas = len(schemas) + for schema in schemas: + i_schema += 1 + logger.info( + f"Migrating schema: index={i_schema} num_schemas={num_schemas} schema={schema}" + ) + try: + async with engine.connect() as connection: + await connection.run_sync( + do_run_migrations, + schema_name=schema, + create_schema=create_schema, + ) + except Exception as e: + logger.error(f"Error migrating schema {schema}: {e}") + if not continue_on_error: + logger.error("--continue=true is not set, raising exception!") + raise + + logger.warning("--continue=true is set, continuing to next schema.") + + elif upgrade_all_tenants: tenant_schemas = get_all_tenant_ids() + filtered_tenant_schemas = filter_tenants_by_range( + tenant_schemas, tenant_range_start, tenant_range_end + ) + + if tenant_range_start is not None or tenant_range_end is not None: + logger.info( + f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}" + ) + logger.info( + f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}" + ) + i_tenant = 0 - num_tenants = len(tenant_schemas) - for schema in tenant_schemas: + num_tenants = len(filtered_tenant_schemas) + for schema in filtered_tenant_schemas: i_tenant += 1 logger.info( f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}" @@ -190,17 +353,13 @@ def event_provide_iam_token_for_alembic( logger.warning("--continue=true is set, continuing to next schema.") else: - try: - logger.info(f"Migrating schema: {schema_name}") - async with engine.connect() as connection: - await connection.run_sync( - do_run_migrations, - schema_name=schema_name, - create_schema=create_schema, - ) - except Exception as e: - logger.error(f"Error migrating schema {schema_name}: {e}") - raise + # This should not happen in the new design since we require either + # upgrade_all_tenants=true or schemas in multi-tenant mode + # and for non-multi-tenant mode, we should use schemas with the default schema + raise ValueError( + "No migration target specified. Use either upgrade_all_tenants=true for all tenants " + "or schemas for specific schemas." + ) await engine.dispose() @@ -221,10 +380,37 @@ def run_migrations_offline() -> None: # without init_engine, subsequent engine calls fail hard intentionally SqlEngine.init_engine(pool_size=20, max_overflow=5) - schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options() + ( + create_schema, + upgrade_all_tenants, + continue_on_error, + tenant_range_start, + tenant_range_end, + schemas, + ) = get_schema_options() url = build_connection_string() - if upgrade_all_tenants: + if schemas: + # Use specific schema names directly without fetching all tenants + logger.info(f"Migrating specific schema names: {schemas}") + + for schema in schemas: + logger.info(f"Migrating schema: {schema}") + context.configure( + url=url, + target_metadata=target_metadata, # type: ignore + literal_binds=True, + include_object=include_object, + version_table_schema=schema, + include_schemas=True, + script_location=config.get_main_option("script_location"), + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + elif upgrade_all_tenants: engine = create_async_engine(url) if USE_IAM_AUTH: @@ -238,7 +424,19 @@ def event_provide_iam_token_for_alembic_offline( tenant_schemas = get_all_tenant_ids() engine.sync_engine.dispose() - for schema in tenant_schemas: + filtered_tenant_schemas = filter_tenants_by_range( + tenant_schemas, tenant_range_start, tenant_range_end + ) + + if tenant_range_start is not None or tenant_range_end is not None: + logger.info( + f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}" + ) + logger.info( + f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}" + ) + + for schema in filtered_tenant_schemas: logger.info(f"Migrating schema: {schema}") context.configure( url=url, @@ -254,21 +452,12 @@ def event_provide_iam_token_for_alembic_offline( with context.begin_transaction(): context.run_migrations() else: - logger.info(f"Migrating schema: {schema_name}") - context.configure( - url=url, - target_metadata=target_metadata, # type: ignore - literal_binds=True, - include_object=include_object, - version_table_schema=schema_name, - include_schemas=True, - script_location=config.get_main_option("script_location"), - dialect_opts={"paramstyle": "named"}, + # This should not happen in the new design + raise ValueError( + "No migration target specified. Use either upgrade_all_tenants=true for all tenants " + "or schemas for specific schemas." ) - with context.begin_transaction(): - context.run_migrations() - def run_migrations_online() -> None: logger.info("run_migrations_online starting.") diff --git a/backend/ee/onyx/server/tenants/schema_management.py b/backend/ee/onyx/server/tenants/schema_management.py index ac92c85797e..48f178b5a40 100644 --- a/backend/ee/onyx/server/tenants/schema_management.py +++ b/backend/ee/onyx/server/tenants/schema_management.py @@ -34,7 +34,7 @@ def run_alembic_migrations(schema_name: str) -> None: # Mimic command-line options by adding 'cmd_opts' to the config alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore - alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore + alembic_cfg.cmd_opts.x = [f"schemas={schema_name}"] # type: ignore # Run migrations programmatically command.upgrade(alembic_cfg, "head") diff --git a/backend/scripts/list_tenants.py b/backend/scripts/list_tenants.py new file mode 100755 index 00000000000..2ad2103d40e --- /dev/null +++ b/backend/scripts/list_tenants.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +""" +Tenant Count Script +Simple script to count the number of tenants in the database. +Used by the parallel migration script to determine how to split work. +""" + +import sys + +# Add the backend directory to the Python path +sys.path.append("/opt/onyx/backend") + +from onyx.db.engine.tenant_utils import get_all_tenant_ids +from onyx.db.engine.sql_engine import SqlEngine +from shared_configs.configs import TENANT_ID_PREFIX + + +def main() -> None: + try: + # Initialize the database engine with conservative settings + SqlEngine.init_engine(pool_size=5, max_overflow=2) + + # Get all tenant IDs + tenant_ids = get_all_tenant_ids() + + # Filter to only tenant schemas (not public or other system schemas) + tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)] + + # Print all tenant IDs, one per line + for tenant_id in tenant_schemas: + print(tenant_id) + + except Exception as e: + print(f"Error getting tenant IDs: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main()