@@ -0,0 +1,37 @@
"""Add image input support to model config

Revision ID: 64bd5677aeb6
Revises: b30353be4eec
Create Date: 2025-09-28 15:48:12.003612

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "64bd5677aeb6"
down_revision = "b30353be4eec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "model_configuration",
        sa.Column("supports_image_input", sa.Boolean(), nullable=True),
    )

    # This appears to be left over from when model visibility was introduced as a
    # nullable field. Set any NULL is_visible values to False.
    connection = op.get_bind()
    connection.execute(
        sa.text(
            "UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
        )
    )


def downgrade() -> None:
    op.drop_column("model_configuration", "supports_image_input")
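
After running alembic upgrade head, a minimal sanity check for both effects of this migration (the new nullable column, and no remaining NULL is_visible rows); it assumes DATABASE_URL points at the Onyx Postgres database:

# Post-migration sanity check; DATABASE_URL is an assumption, not part of this PR.
import os

import sqlalchemy as sa

engine = sa.create_engine(os.environ["DATABASE_URL"])
with engine.connect() as conn:
    # The new column should exist (and be nullable for now).
    columns = sa.inspect(engine).get_columns("model_configuration")
    assert any(col["name"] == "supports_image_input" for col in columns)

    # The backfill should have removed any NULL is_visible rows.
    null_count = conn.execute(
        sa.text("SELECT count(*) FROM model_configuration WHERE is_visible IS NULL")
    ).scalar()
    assert null_count == 0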
1 change: 1 addition & 0 deletions backend/onyx/db/llm.py
@@ -112,6 +112,7 @@ def upsert_llm_provider(
                name=model_configuration.name,
                is_visible=model_configuration.is_visible,
                max_input_tokens=model_configuration.max_input_tokens,
                supports_image_input=model_configuration.supports_image_input,
            )
            .on_conflict_do_nothing()
        )
2 changes: 2 additions & 0 deletions backend/onyx/db/models.py
@@ -2353,6 +2353,8 @@ class ModelConfiguration(Base):
    # - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
    max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)

    supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)

    llm_provider: Mapped["LLMProvider"] = relationship(
        "LLMProvider",
        back_populates="model_configurations",
40 changes: 29 additions & 11 deletions backend/onyx/llm/factory.py
@@ -1,8 +1,5 @@
from typing import Any

from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
@@ -13,6 +10,8 @@
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input
@@ -24,13 +23,22 @@
logger = setup_logger()


def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
    """Ollama requires us to specify the max context window.

    For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
    TODO: allow model-specific values to be configured via the UI.
    """
    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}

def _build_provider_extra_headers(
    provider: str, custom_config: dict[str, str] | None
) -> dict[str, str]:
    if provider != OLLAMA_PROVIDER_NAME or not custom_config:
        return {}

    raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)

    api_key = raw_api_key.strip() if raw_api_key else None
    if not api_key:
        return {}

    if not api_key.lower().startswith("bearer "):
        api_key = f"Bearer {api_key}"

    return {"Authorization": api_key}


def get_main_llm_from_tuple(
@@ -272,6 +280,16 @@ def get_llm(
) -> LLM:
    if temperature is None:
        temperature = GEN_AI_TEMPERATURE

    extra_headers = build_llm_extra_headers(additional_headers)

    # NOTE: needed because the Ollama API key is optional. Users may reach
    # Ollama Cloud through a locally hosted instance (logged in) or directly
    # via the cloud API (not logged in, authenticating with an API key).
    provider_extra_headers = _build_provider_extra_headers(provider, custom_config)
    if provider_extra_headers:
        extra_headers.update(provider_extra_headers)

    return DefaultMultiLLM(
        model_provider=provider,
        model_name=model,
@@ -282,8 +300,8 @@ def get_llm(
        timeout=timeout,
        temperature=temperature,
        custom_config=custom_config,
        extra_headers=build_llm_extra_headers(additional_headers),
        model_kwargs=_build_extra_model_kwargs(provider),
        extra_headers=extra_headers,
        model_kwargs={},
Contributor:

why is this gone?

Member Author:

1. build_llm_extra_headers is still invoked; it's just that the additional header we need for Ollama can't be passed into that function, so I append it to the result and pass in the combined "extra_headers".
2. The model kwargs were originally there because the user had to set a context limit for every Ollama model. I don't believe this exists any more --> I was never prompted to do this in Ollama, and I don't see an easy way to change the context limits it already has set.

        long_term_logger=long_term_logger,
        max_input_tokens=max_input_tokens,
    )
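
To illustrate the new header path: a minimal sketch, assuming the backend package imports cleanly outside the running app, showing that a bare key and an already-prefixed key both normalize to a single Authorization header, while non-Ollama providers are left untouched:

# Sketch only: exercises _build_provider_extra_headers exactly as defined above.
from onyx.llm.factory import _build_provider_extra_headers
from onyx.llm.llm_provider_options import (
    OLLAMA_API_KEY_CONFIG_KEY,
    OLLAMA_PROVIDER_NAME,
)

bare = _build_provider_extra_headers(
    OLLAMA_PROVIDER_NAME, {OLLAMA_API_KEY_CONFIG_KEY: "abc123"}
)
prefixed = _build_provider_extra_headers(
    OLLAMA_PROVIDER_NAME, {OLLAMA_API_KEY_CONFIG_KEY: "Bearer abc123"}
)
other = _build_provider_extra_headers("openai", {OLLAMA_API_KEY_CONFIG_KEY: "abc123"})

assert bare == {"Authorization": "Bearer abc123"}
assert prefixed == {"Authorization": "Bearer abc123"}
assert other == {}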
33 changes: 30 additions & 3 deletions backend/onyx/llm/llm_provider_options.py
@@ -39,6 +39,7 @@ class WellKnownLLMProviderDescriptor(BaseModel):
    model_configurations: list[ModelConfigurationView]
    default_model: str | None = None
    default_fast_model: str | None = None
    default_api_base: str | None = None
    # set for providers like Azure, which require a deployment name.
    deployment_name_required: bool = False
    # set for providers like Azure, which support a single model per deployment.
@@ -95,7 +96,9 @@ class WellKnownLLMProviderDescriptor(BaseModel):
    for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
    if "/" not in model and "embed" not in model
][::-1]
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"

OLLAMA_PROVIDER_NAME = "ollama"
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"

IGNORABLE_ANTHROPIC_MODELS = [
    "claude-2",
@@ -160,13 +163,15 @@ class WellKnownLLMProviderDescriptor(BaseModel):
    BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
    ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
    VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
    OLLAMA_PROVIDER_NAME: [],
}

_PROVIDER_TO_VISIBLE_MODELS_MAP = {
    OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES,
    BEDROCK_PROVIDER_NAME: [BEDROCK_DEFAULT_MODEL],
    BEDROCK_PROVIDER_NAME: [],
    ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
    VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
    OLLAMA_PROVIDER_NAME: [],
}


@@ -185,6 +190,28 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
default_model="gpt-4o",
default_fast_model="gpt-4o-mini",
),
WellKnownLLMProviderDescriptor(
name=OLLAMA_PROVIDER_NAME,
display_name="Ollama",
api_key_required=False,
api_base_required=True,
api_version_required=False,
custom_config_keys=[
CustomConfigKey(
name=OLLAMA_API_KEY_CONFIG_KEY,
display_name="Ollama API Key",
description="Optional API key used when connecting to Ollama Cloud (i.e. API base is https://ollama.com).",
is_required=False,
is_secret=True,
)
],
model_configurations=fetch_model_configurations_for_provider(
OLLAMA_PROVIDER_NAME
),
default_model=None,
default_fast_model=None,
default_api_base="http://127.0.0.1:11434",
),
        WellKnownLLMProviderDescriptor(
            name=ANTHROPIC_PROVIDER_NAME,
            display_name="Anthropic",
@@ -248,7 +275,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
            model_configurations=fetch_model_configurations_for_provider(
                BEDROCK_PROVIDER_NAME
            ),
            default_model=BEDROCK_DEFAULT_MODEL,
            default_model=None,
            default_fast_model=None,
        ),
        WellKnownLLMProviderDescriptor(
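
The Ollama descriptor added above is meant to cover two setups: a locally hosted instance (no key) and Ollama Cloud (optional key). A hypothetical illustration of the values an admin would end up saving in each case (field names here are illustrative, not the exact upsert schema):

# Hypothetical illustration only; field names are not the exact upsert schema.
local_ollama = {
    "provider": "ollama",
    "api_base": "http://127.0.0.1:11434",  # matches default_api_base above
    "custom_config": {},  # no API key needed for a local instance
}

ollama_cloud = {
    "provider": "ollama",
    "api_base": "https://ollama.com",
    "custom_config": {"OLLAMA_API_KEY": "<your-ollama-cloud-key>"},
}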
28 changes: 28 additions & 0 deletions backend/onyx/llm/utils.py
@@ -16,6 +16,7 @@
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy import select

from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
@@ -26,6 +27,9 @@
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
@@ -640,6 +644,30 @@ def get_max_input_tokens_from_llm_provider(


def model_supports_image_input(model_name: str, model_provider: str) -> bool:
    # TODO: add support for checking the model config for any provider.
    # TODO: a circular import means OLLAMA_PROVIDER_NAME is not available here.

    if model_provider == "ollama":
        try:
            with get_session_with_current_tenant() as db_session:
                model_config = db_session.scalar(
                    select(ModelConfiguration)
                    .join(
                        LLMProvider,
                        ModelConfiguration.llm_provider_id == LLMProvider.id,
                    )
                    .where(
                        ModelConfiguration.name == model_name,
                        LLMProvider.provider == model_provider,
                    )
                )
                if model_config and model_config.supports_image_input is not None:
                    return model_config.supports_image_input
        except Exception as e:
            logger.warning(
                f"Failed to query database for {model_provider} model {model_name} image support: {e}"
            )

    model_map = get_model_map()
    try:
        model_obj = find_model_obj(
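
A minimal pytest-style sketch (an assumption-heavy test, not part of this PR) of the intended resolution order: for Ollama models the stored flag wins when it is set, otherwise the existing litellm model-map lookup still applies. It stubs out the tenant session and assumes the backend package imports cleanly:

# Hypothetical test sketch; stubs the DB session so no real database is needed.
from unittest.mock import MagicMock, patch

from onyx.llm.utils import model_supports_image_input


def test_ollama_db_flag_wins() -> None:
    fake_config = MagicMock(supports_image_input=True)

    with patch("onyx.llm.utils.get_session_with_current_tenant") as mock_cm:
        # get_session_with_current_tenant() is used as a context manager.
        mock_cm.return_value.__enter__.return_value.scalar.return_value = fake_config
        assert model_supports_image_input("llava:7b", "ollama") is True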
103 changes: 103 additions & 0 deletions backend/onyx/server/manage/llm/api.py
@@ -4,17 +4,20 @@
from datetime import timezone

import boto3
import httpx
from botocore.exceptions import BotoCoreError
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import ValidationError
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.db.engine.sql_engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
@@ -40,6 +43,9 @@
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
@@ -474,3 +480,100 @@ def get_bedrock_available_models(
        raise HTTPException(
            status_code=500, detail=f"Unexpected error fetching Bedrock models: {e}"
        )


def _get_ollama_available_model_names(api_base: str) -> set[str]:
    """Fetch available model names from Ollama server."""
    tags_url = f"{api_base}/api/tags"
    try:
        response = httpx.get(tags_url, timeout=5.0)
        response.raise_for_status()
        response_json = response.json()
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch Ollama models: {e}",
        )

    models = response_json.get("models", [])
    return {model.get("name") for model in models if model.get("name")}
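
For reference, the helper above relies only on the models[*].name field of Ollama's /api/tags response. A trimmed, illustrative sketch of the payload shape it expects (real responses carry additional fields such as size and digest):

# Illustrative /api/tags payload, trimmed to the field this helper actually reads.
example_tags_response = {
    "models": [
        {"name": "llama3.1:8b"},
        {"name": "qwen2.5:14b"},
        {"name": "llava:7b"},
    ]
}
# _get_ollama_available_model_names reduces this to:
# {"llama3.1:8b", "qwen2.5:14b", "llava:7b"}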


@admin_router.post("/ollama/available-models")
def get_ollama_available_models(
    request: OllamaModelsRequest,
    _: User | None = Depends(current_admin_user),
) -> list[OllamaFinalModelResponse]:
    """Fetch the list of available models from an Ollama server."""

    cleaned_api_base = request.api_base.strip().rstrip("/")
    if not cleaned_api_base:
        raise HTTPException(
            status_code=400, detail="API base URL is required to fetch Ollama models."
        )

    model_names = _get_ollama_available_model_names(cleaned_api_base)
    if not model_names:
        raise HTTPException(
            status_code=400,
            detail="No models found from your Ollama server",
        )

    all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
    show_url = f"{cleaned_api_base}/api/show"

    for model_name in model_names:
        context_limit: int | None = None
        supports_image_input: bool | None = None
        try:
            show_response = httpx.post(
                show_url,
                json={"model": model_name},
                timeout=5.0,
            )
            show_response.raise_for_status()
            show_response_json = show_response.json()

            # Parse the response into the expected format
            ollama_model_details = OllamaModelDetails.model_validate(show_response_json)

            # Check whether this model supports completion/chat
            if not ollama_model_details.supports_completion():
                continue

            # Optimistic access. The context limit is stored as an int under
            # "<architecture>.context_length", e.g. "llama.context_length".
            architecture = ollama_model_details.model_info.get(
                "general.architecture", ""
            )
            context_limit = ollama_model_details.model_info.get(
                architecture + ".context_length", None
            )
            supports_image_input = ollama_model_details.supports_image_input()
        except ValidationError as e:
            logger.warning(
                "Invalid model details from Ollama server",
                extra={"model": model_name, "validation_error": str(e)},
            )
        except Exception as e:
            logger.warning(
                "Failed to fetch Ollama model details",
                extra={"model": model_name, "error": str(e)},
            )

        # If extracting the context limit fails at any point, still allow this
        # model to be used with a fallback max context size.
        if not context_limit:
            context_limit = GEN_AI_MODEL_FALLBACK_MAX_TOKENS

        if not supports_image_input:
            supports_image_input = False

        all_models_with_context_size_and_vision.append(
            OllamaFinalModelResponse(
                name=model_name,
                max_input_tokens=context_limit,
                supports_image_input=supports_image_input,
            )
        )

    return all_models_with_context_size_and_vision
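
A quick way to exercise the new endpoint once the backend is running. The host, port, and /admin/llm prefix are assumptions (check where admin_router is actually mounted), and an admin session is required:

# Sketch only: URL prefix and port are assumptions, and authentication is omitted.
import httpx

resp = httpx.post(
    "http://localhost:8080/admin/llm/ollama/available-models",
    json={"api_base": "http://127.0.0.1:11434"},
    timeout=30.0,
)
resp.raise_for_status()
for model in resp.json():
    print(model["name"], model["max_input_tokens"], model["supports_image_input"])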