Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 52 additions & 71 deletions backend/alembic/versions/abbfec3a5ac5_merge_prompt_into_persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,76 +16,59 @@
branch_labels = None
depends_on = None

MAX_PROMPT_LENGTH = 5_000_000


def upgrade() -> None:
"""NOTE: Prompts without any Personas will just be lost."""
# Step 1: Add new columns to persona table (only if they don't exist)
op.add_column(
"persona",
sa.Column("system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True),
)
op.add_column(
"persona",
sa.Column("task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True),
)
op.add_column(
"persona",
sa.Column(
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
),
)

# Check if columns exist before adding them
connection = op.get_bind()
inspector = sa.inspect(connection)
existing_columns = [col["name"] for col in inspector.get_columns("persona")]

if "system_prompt" not in existing_columns:
op.add_column(
"persona", sa.Column("system_prompt", sa.String(length=8000), nullable=True)
)

if "task_prompt" not in existing_columns:
op.add_column(
"persona", sa.Column("task_prompt", sa.String(length=8000), nullable=True)
)

if "datetime_aware" not in existing_columns:
op.add_column(
"persona",
sa.Column(
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
),
)

# Step 2: Migrate data from prompt table to persona table (only if tables exist)
existing_tables = inspector.get_table_names()

if "prompt" in existing_tables and "persona__prompt" in existing_tables:
# For personas that have associated prompts, copy the prompt data
op.execute(
"""
UPDATE persona
SET
system_prompt = p.system_prompt,
task_prompt = p.task_prompt,
datetime_aware = p.datetime_aware
FROM (
-- Get the first prompt for each persona (in case there are multiple)
SELECT DISTINCT ON (pp.persona_id)
pp.persona_id,
pr.system_prompt,
pr.task_prompt,
pr.datetime_aware
FROM persona__prompt pp
JOIN prompt pr ON pp.prompt_id = pr.id
) p
WHERE persona.id = p.persona_id
# For personas that have associated prompts, copy the prompt data
op.execute(
"""
)
UPDATE persona
SET
system_prompt = p.system_prompt,
task_prompt = p.task_prompt,
datetime_aware = p.datetime_aware
FROM (
-- Get the first prompt for each persona (in case there are multiple)
SELECT DISTINCT ON (pp.persona_id)
pp.persona_id,
pr.system_prompt,
pr.task_prompt,
pr.datetime_aware
FROM persona__prompt pp
JOIN prompt pr ON pp.prompt_id = pr.id
) p
WHERE persona.id = p.persona_id
"""
)

# Step 3: Update chat_message references
# Since chat messages referenced prompt_id, we need to update them to use persona_id
# This is complex as we need to map from prompt_id to persona_id

# Check if chat_message has prompt_id column
chat_message_columns = [
col["name"] for col in inspector.get_columns("chat_message")
]
if "prompt_id" in chat_message_columns:
op.execute(
"""
ALTER TABLE chat_message
DROP CONSTRAINT IF EXISTS chat_message__prompt_fk
"""
)
op.drop_column("chat_message", "prompt_id")
# Step 3: Update chat_message references
# Since chat messages referenced prompt_id, we need to update them to use persona_id
# This is complex as we need to map from prompt_id to persona_id
op.execute(
"""
ALTER TABLE chat_message
DROP CONSTRAINT IF EXISTS chat_message__prompt_fk
"""
)
op.drop_column("chat_message", "prompt_id")

# Step 4: Handle personas without prompts - set default values if needed (always run this)
op.execute(
Expand All @@ -99,26 +82,24 @@ def upgrade() -> None:
)

# Step 5: Drop the persona__prompt association table (if it exists)
if "persona__prompt" in existing_tables:
op.drop_table("persona__prompt")
op.drop_table("persona__prompt")

# Step 6: Drop the prompt table (if it exists)
if "prompt" in existing_tables:
op.drop_table("prompt")
op.drop_table("prompt")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: The `persona__prompt` and `prompt` tables may not exist in every environment, so dropping them unconditionally could cause the migration to fail.


# Step 7: Make system_prompt and task_prompt non-nullable after migration (only if they exist)
op.alter_column(
"persona",
"system_prompt",
existing_type=sa.String(length=8000),
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
nullable=False,
server_default=None,
)

op.alter_column(
"persona",
"task_prompt",
existing_type=sa.String(length=8000),
existing_type=sa.String(length=MAX_PROMPT_LENGTH),
nullable=False,
server_default=None,
)
Expand All @@ -132,8 +113,8 @@ def downgrade() -> None:
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=False),
sa.Column("system_prompt", sa.String(length=8000), nullable=False),
sa.Column("task_prompt", sa.String(length=8000), nullable=False),
sa.Column("system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
sa.Column("task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
sa.Column(
"datetime_aware", sa.Boolean(), nullable=False, server_default="true"
),
Expand Down
55 changes: 55 additions & 0 deletions backend/alembic/versions/b7ec9b5b505f_adjust_prompt_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""adjust prompt length
Revision ID: b7ec9b5b505f
Revises: abbfec3a5ac5
Create Date: 2025-09-10 18:51:15.629197
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "b7ec9b5b505f"
down_revision = "abbfec3a5ac5"
branch_labels = None
depends_on = None


# Length passed to sa.String(...) for the persona.system_prompt and
# persona.task_prompt columns altered by this migration.
MAX_PROMPT_LENGTH = 5_000_000


def upgrade() -> None:
    """Widen the persona prompt columns to ``MAX_PROMPT_LENGTH``.

    The parent revision originally created ``system_prompt`` and
    ``task_prompt`` with a length of 8000, so databases that already ran it
    still carry the shorter limit; this revision resizes both columns.
    """
    # NOTE(review): existing_type is declared as TEXT, while the parent
    # revision creates these columns with sa.String(...) - confirm which is
    # intended.
    for prompt_column in ("system_prompt", "task_prompt"):
        op.alter_column(
            "persona",
            prompt_column,
            existing_type=sa.TEXT(),
            type_=sa.String(length=MAX_PROMPT_LENGTH),
            existing_nullable=False,
        )


def downgrade() -> None:
    """Revert the type change made by ``upgrade`` on the persona prompt columns.

    The columns live on the ``persona`` table: the parent revision merged the
    prompt fields into ``persona`` and dropped the ``prompt`` table, so this
    must alter ``persona`` (the same table ``upgrade`` touches), not ``prompt``.
    """
    for prompt_column in ("system_prompt", "task_prompt"):
        op.alter_column(
            "persona",
            prompt_column,
            existing_type=sa.String(length=MAX_PROMPT_LENGTH),
            type_=sa.TEXT(),
            existing_nullable=False,
        )
8 changes: 6 additions & 2 deletions backend/onyx/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@

logger = setup_logger()

PROMPT_LENGTH = 5_000_000


class Base(DeclarativeBase):
__abstract__ = True
Expand Down Expand Up @@ -2583,9 +2585,11 @@ class Persona(Base):

# Prompt fields merged from Prompt table
system_prompt: Mapped[str | None] = mapped_column(
String(length=8000), nullable=True
String(length=PROMPT_LENGTH), nullable=True
)
task_prompt: Mapped[str | None] = mapped_column(
String(length=PROMPT_LENGTH), nullable=True
)
task_prompt: Mapped[str | None] = mapped_column(String(length=8000), nullable=True)
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)

uploaded_image_id: Mapped[str | None] = mapped_column(String, nullable=True)
Expand Down
Loading