37 changes: 37 additions & 0 deletions backend/alembic/README.md
@@ -20,3 +20,40 @@ To run all un-applied migrations:
To undo migrations:
`alembic downgrade -X`
where X is the number of migrations you want to undo from the current state

### Multi-tenant migrations

For multi-tenant deployments, you can use additional options:

**Upgrade all tenants:**
```bash
alembic -x upgrade_all_tenants=true upgrade head
```

**Upgrade specific tenant schemas:**
```bash
alembic -x specific_schema_names=tenant_12345678-1234-1234-1234-123456789012 upgrade head
```
Multiple schemas can be passed as a comma-separated list.

**Upgrade tenants within an alphabetical range:**
```bash
# Upgrade tenants 100-200 when sorted alphabetically (positions 100 to 200)
alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x tenant_range_end=200 upgrade head

# Upgrade tenants starting from position 1000 alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_start=1000 upgrade head

# Upgrade first 500 tenants alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_end=500 upgrade head
```

**Continue on error (for batch operations):**
```bash
alembic -x upgrade_all_tenants=true -x continue=true upgrade head
```

The tenant range filtering works by (see the sketch below):
1. Sorting tenant IDs alphabetically
2. Using 1-based position numbers (1st, 2nd, 3rd tenant, etc.)
3. Filtering to the specified range of positions
4. Always including non-tenant schemas (like 'public')
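
To make the position semantics concrete, here is a minimal, self-contained Python sketch (illustration only, assuming the standard `tenant_` schema prefix; it roughly mirrors the `filter_tenants_by_range` helper added to `env.py`):

```python
# Illustration of position-based tenant selection (assumes the "tenant_" prefix).
TENANT_PREFIX = "tenant_"


def select_by_position(
    schemas: list[str], start: int | None, end: int | None
) -> list[str]:
    tenants = sorted(s for s in schemas if s.startswith(TENANT_PREFIX))
    non_tenants = [s for s in schemas if not s.startswith(TENANT_PREFIX)]
    lo = (start - 1) if start is not None else 0  # 1-based -> 0-based
    hi = end if end is not None else len(tenants)  # inclusive end position
    selected = set(tenants[max(0, lo) : min(len(tenants), hi)])
    # Non-tenant schemas (e.g. "public") are always kept; original order is preserved.
    return [s for s in schemas if s in selected or s in non_tenants]


schemas = ["public", "tenant_c", "tenant_a", "tenant_b", "tenant_d"]
print(select_by_position(schemas, 2, 3))
# -> ['public', 'tenant_c', 'tenant_b']  (alphabetical positions 2 and 3)
```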
279 changes: 233 additions & 46 deletions backend/alembic/env.py
@@ -21,7 +21,10 @@
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import (
MULTI_TENANT,
TENANT_ID_PREFIX,
)
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine.sql_engine import SqlEngine
@@ -69,33 +72,151 @@ def include_object(
return True


def get_schema_options() -> tuple[str, bool, bool, bool]:
def filter_tenants_by_range(
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
) -> list[str]:
"""
Filter tenant IDs by alphabetical position range.

Args:
tenant_ids: List of tenant IDs to filter
start_range: Starting position in alphabetically sorted list (1-based, inclusive)
end_range: Ending position in alphabetically sorted list (1-based, inclusive)

Returns:
Filtered list of tenant IDs in their original order
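
Example (assuming TENANT_ID_PREFIX == "tenant_"):
    filter_tenants_by_range(["public", "tenant_b", "tenant_a"], 1, 1)
    returns ["public", "tenant_a"] (alphabetical position 1 plus
    non-tenant schemas, original order preserved)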
"""
if start_range is None and end_range is None:
return tenant_ids

# Separate tenant IDs from non-tenant schemas
tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)]
non_tenant_schemas = [
tid for tid in tenant_ids if not tid.startswith(TENANT_ID_PREFIX)
]

# Sort tenant schemas alphabetically.
# NOTE: a schema created after a worker has fetched the list of tenant IDs
# can be missed. We accept this risk for now; simply re-running the
# migration will pick it up.
sorted_tenant_schemas = sorted(tenant_schemas)

# Apply range filtering (convert to 0-based indexing)
start_idx = (start_range - 1) if start_range is not None else 0
end_idx = end_range if end_range is not None else len(sorted_tenant_schemas)

# Ensure indices are within bounds
start_idx = max(0, start_idx)
end_idx = min(len(sorted_tenant_schemas), end_idx)

# Get the filtered tenant schemas
filtered_tenant_schemas = sorted_tenant_schemas[start_idx:end_idx]

# Combine with non-tenant schemas and preserve original order
filtered_tenants = []
for tenant_id in tenant_ids:
if tenant_id in filtered_tenant_schemas or tenant_id in non_tenant_schemas:
filtered_tenants.append(tenant_id)

return filtered_tenants


def get_schema_options() -> (
tuple[bool, bool, bool, int | None, int | None, list[str] | None]
):
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
if "=" in arg:
key, value = arg.split("=", 1)
x_args[key.strip()] = value.strip()
else:
raise ValueError(f"Invalid argument: {arg}")

create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"

# Continue on error for individual tenants.
# Only applies to online migrations.
continue_on_error = x_args.get("continue", "false").lower() == "true"

if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
and not upgrade_all_tenants
):
# Tenant range filtering
tenant_range_start = None
tenant_range_end = None

if "tenant_range_start" in x_args:
try:
tenant_range_start = int(x_args["tenant_range_start"])
except ValueError:
raise ValueError(
f"Invalid tenant_range_start value: {x_args['tenant_range_start']}. Must be an integer."
)

if "tenant_range_end" in x_args:
try:
tenant_range_end = int(x_args["tenant_range_end"])
except ValueError:
raise ValueError(
f"Invalid tenant_range_end value: {x_args['tenant_range_end']}. Must be an integer."
)

# Validate range
if tenant_range_start is not None and tenant_range_end is not None:
if tenant_range_start > tenant_range_end:
raise ValueError(
f"tenant_range_start ({tenant_range_start}) cannot be greater than tenant_range_end ({tenant_range_end})"
)

# Specific schema names filtering (replaces both schema_name and the old tenant_ids approach)
specific_schema_names = None
if "specific_schema_names" in x_args:
schema_names_str = x_args["specific_schema_names"].strip()
if schema_names_str:
# Split by comma and strip whitespace
specific_schema_names = [
name.strip() for name in schema_names_str.split(",") if name.strip()
]
if specific_schema_names:
logger.info(f"Specific schema names specified: {specific_schema_names}")

# Validate that only one method is used at a time
range_filtering = tenant_range_start is not None or tenant_range_end is not None
specific_filtering = (
specific_schema_names is not None and len(specific_schema_names) > 0
)

if range_filtering and specific_filtering:
raise ValueError(
"Cannot use both tenant range filtering (tenant_range_start/tenant_range_end) "
"and specific schema filtering (specific_schema_names) at the same time. "
"Please use only one filtering method."
)

if upgrade_all_tenants and specific_filtering:
raise ValueError(
"Cannot use both upgrade_all_tenants=true and specific_schema_names at the same time. "
"Use either upgrade_all_tenants=true for all tenants, or specific_schema_names for specific schemas."
)

# If any filtering parameters are specified, we're not doing the default single schema migration
if range_filtering:
upgrade_all_tenants = True

# Validate multi-tenant requirements
if MULTI_TENANT and not upgrade_all_tenants and not specific_filtering:
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."
"In multi-tenant mode, you must specify either upgrade_all_tenants=true "
"or provide specific_schema_names. Cannot run default migration."
)

return schema_name, create_schema, upgrade_all_tenants, continue_on_error
return (
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
specific_schema_names,
)
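
# For example, running:
#   alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x continue=true upgrade head
# yields, per the parsing above:
#   create_schema=True, upgrade_all_tenants=True, continue_on_error=True,
#   tenant_range_start=100, tenant_range_end=None, specific_schema_names=None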


def do_run_migrations(
@@ -142,10 +263,12 @@ def provide_iam_token_for_alembic(

async def run_async_migrations() -> None:
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
specific_schema_names,
) = get_schema_options()

# without init_engine, subsequent engine calls fail hard intentionally
@@ -164,12 +287,50 @@ def event_provide_iam_token_for_alembic(
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)

if upgrade_all_tenants:
if specific_schema_names:
# Use specific schema names directly without fetching all tenants
logger.info(f"Migrating specific schema names: {specific_schema_names}")

i_schema = 0
num_schemas = len(specific_schema_names)
for schema in specific_schema_names:
i_schema += 1
logger.info(
f"Migrating schema: index={i_schema} num_schemas={num_schemas} schema={schema}"
)
try:
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
logger.error("--continue=true is not set, raising exception!")
raise

logger.warning("--continue=true is set, continuing to next schema.")

elif upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()

filtered_tenant_schemas = filter_tenants_by_range(
tenant_schemas, tenant_range_start, tenant_range_end
)

if tenant_range_start is not None or tenant_range_end is not None:
logger.info(
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
)
logger.info(
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
)

i_tenant = 0
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
num_tenants = len(filtered_tenant_schemas)
for schema in filtered_tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
Expand All @@ -190,17 +351,13 @@ def event_provide_iam_token_for_alembic(
logger.warning("--continue=true is set, continuing to next schema.")

else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
# This should not happen in the new design: in multi-tenant mode we require
# either upgrade_all_tenants=true or specific_schema_names, and in
# non-multi-tenant mode specific_schema_names should be used with the
# default schema.
raise ValueError(
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
"or specific_schema_names for specific schemas."
)

await engine.dispose()

@@ -221,10 +378,37 @@ def run_migrations_offline() -> None:
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)

schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
(
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
specific_schema_names,
) = get_schema_options()
url = build_connection_string()

if upgrade_all_tenants:
if specific_schema_names:
# Use specific schema names directly without fetching all tenants
logger.info(f"Migrating specific schema names: {specific_schema_names}")

for schema in specific_schema_names:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)

with context.begin_transaction():
context.run_migrations()

elif upgrade_all_tenants:
engine = create_async_engine(url)

if USE_IAM_AUTH:
@@ -238,7 +422,19 @@ def event_provide_iam_token_for_alembic_offline(
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()

for schema in tenant_schemas:
filtered_tenant_schemas = filter_tenants_by_range(
tenant_schemas, tenant_range_start, tenant_range_end
)

if tenant_range_start is not None or tenant_range_end is not None:
logger.info(
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
)
logger.info(
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
)

for schema in filtered_tenant_schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
@@ -254,21 +450,12 @@ def event_provide_iam_token_for_alembic_offline(
with context.begin_transaction():
context.run_migrations()
else:
logger.info(f"Migrating schema: {schema_name}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
# This should not happen in the new design
raise ValueError(
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
"or specific_schema_names for specific schemas."
)

with context.begin_transaction():
context.run_migrations()


def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")