diff --git a/backend/alembic/versions/64bd5677aeb6_add_image_input_support_to_model_config.py b/backend/alembic/versions/64bd5677aeb6_add_image_input_support_to_model_config.py new file mode 100644 index 00000000000..b373114302a --- /dev/null +++ b/backend/alembic/versions/64bd5677aeb6_add_image_input_support_to_model_config.py @@ -0,0 +1,37 @@ +"""Add image input support to model config + +Revision ID: 64bd5677aeb6 +Revises: b30353be4eec +Create Date: 2025-09-28 15:48:12.003612 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "64bd5677aeb6" +down_revision = "b30353be4eec" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "model_configuration", + sa.Column("supports_image_input", sa.Boolean(), nullable=True), + ) + + # Seems to be left over from when model visibility was introduced and a nullable field. + # Set any null is_visible values to False + connection = op.get_bind() + connection.execute( + sa.text( + "UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL" + ) + ) + + +def downgrade() -> None: + op.drop_column("model_configuration", "supports_image_input") diff --git a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py index e1c49349139..42e6db57077 100644 --- a/backend/onyx/db/llm.py +++ b/backend/onyx/db/llm.py @@ -112,6 +112,7 @@ def upsert_llm_provider( name=model_configuration.name, is_visible=model_configuration.is_visible, max_input_tokens=model_configuration.max_input_tokens, + supports_image_input=model_configuration.supports_image_input, ) .on_conflict_do_nothing() ) diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index f897bcdca42..2d39d486695 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -2353,6 +2353,8 @@ class ModelConfiguration(Base): # - The end-user is configuring a model and chooses not to set a max-input-tokens limit. max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True) + supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True) + llm_provider: Mapped["LLMProvider"] = relationship( "LLMProvider", back_populates="model_configurations", diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index 985b8e129d9..aeba171aa46 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -1,8 +1,5 @@ -from typing import Any - from onyx.chat.models import PersonaOverrideConfig from onyx.configs.app_configs import DISABLE_GENERATIVE_AI -from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_TEMPERATURE from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.llm import fetch_default_provider @@ -13,6 +10,8 @@ from onyx.llm.chat_llm import DefaultMultiLLM from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.interfaces import LLM +from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY +from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME from onyx.llm.override_models import LLMOverride from onyx.llm.utils import get_max_input_tokens_from_llm_provider from onyx.llm.utils import model_supports_image_input @@ -24,13 +23,22 @@ logger = setup_logger() -def _build_extra_model_kwargs(provider: str) -> dict[str, Any]: - """Ollama requires us to specify the max context window. 
+def _build_provider_extra_headers( + provider: str, custom_config: dict[str, str] | None +) -> dict[str, str]: + if provider != OLLAMA_PROVIDER_NAME or not custom_config: + return {} - For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value. - TODO: allow model-specific values to be configured via the UI. - """ - return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {} + raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY) + + api_key = raw_api_key.strip() if raw_api_key else None + if not api_key: + return {} + + if not api_key.lower().startswith("bearer "): + api_key = f"Bearer {api_key}" + + return {"Authorization": api_key} def get_main_llm_from_tuple( @@ -272,6 +280,16 @@ def get_llm( ) -> LLM: if temperature is None: temperature = GEN_AI_TEMPERATURE + + extra_headers = build_llm_extra_headers(additional_headers) + + # NOTE: this is needed since Ollama API key is optional + # User may access Ollama cloud via locally hosted instance (logged in) + # or just via the cloud API (not logged in, using API key) + provider_extra_headers = _build_provider_extra_headers(provider, custom_config) + if provider_extra_headers: + extra_headers.update(provider_extra_headers) + return DefaultMultiLLM( model_provider=provider, model_name=model, @@ -282,8 +300,8 @@ def get_llm( timeout=timeout, temperature=temperature, custom_config=custom_config, - extra_headers=build_llm_extra_headers(additional_headers), - model_kwargs=_build_extra_model_kwargs(provider), + extra_headers=extra_headers, + model_kwargs={}, long_term_logger=long_term_logger, max_input_tokens=max_input_tokens, ) diff --git a/backend/onyx/llm/llm_provider_options.py b/backend/onyx/llm/llm_provider_options.py index 462e7003e16..e2e56b7fcfd 100644 --- a/backend/onyx/llm/llm_provider_options.py +++ b/backend/onyx/llm/llm_provider_options.py @@ -39,6 +39,7 @@ class WellKnownLLMProviderDescriptor(BaseModel): model_configurations: list[ModelConfigurationView] default_model: str | None = None default_fast_model: str | None = None + default_api_base: str | None = None # set for providers like Azure, which require a deployment name. deployment_name_required: bool = False # set for providers like Azure, which support a single model per deployment. 
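# --- Illustrative sketch (editor annotation, hedged) -----------------------
# A minimal, standalone rendering of how the _build_provider_extra_headers
# helper added to backend/onyx/llm/factory.py above is expected to behave: it
# only emits an Authorization header for the Ollama provider, and only when an
# OLLAMA_API_KEY value is present in the provider's custom config. The two
# constants mirror the llm_provider_options.py hunk below; the "_sketch" suffix
# marks this as an example, not the shipped helper.

OLLAMA_PROVIDER_NAME = "ollama"
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"


def _build_provider_extra_headers_sketch(
    provider: str, custom_config: dict[str, str] | None
) -> dict[str, str]:
    if provider != OLLAMA_PROVIDER_NAME or not custom_config:
        return {}
    raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
    api_key = raw_api_key.strip() if raw_api_key else None
    if not api_key:
        return {}
    # Accept keys supplied with or without the "Bearer " prefix.
    if not api_key.lower().startswith("bearer "):
        api_key = f"Bearer {api_key}"
    return {"Authorization": api_key}


# Expected behaviour, assuming the logic matches the hunk above:
assert _build_provider_extra_headers_sketch("openai", {"OLLAMA_API_KEY": "abc"}) == {}
assert _build_provider_extra_headers_sketch("ollama", None) == {}
assert _build_provider_extra_headers_sketch(
    "ollama", {"OLLAMA_API_KEY": " abc "}
) == {"Authorization": "Bearer abc"}
assert _build_provider_extra_headers_sketch(
    "ollama", {"OLLAMA_API_KEY": "Bearer abc"}
) == {"Authorization": "Bearer abc"}
# ---------------------------------------------------------------------------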
@@ -95,7 +96,9 @@ class WellKnownLLMProviderDescriptor(BaseModel): for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models)) if "/" not in model and "embed" not in model ][::-1] -BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0" + +OLLAMA_PROVIDER_NAME = "ollama" +OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY" IGNORABLE_ANTHROPIC_MODELS = [ "claude-2", @@ -160,13 +163,15 @@ class WellKnownLLMProviderDescriptor(BaseModel): BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES, ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES, VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES, + OLLAMA_PROVIDER_NAME: [], } _PROVIDER_TO_VISIBLE_MODELS_MAP = { OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES, - BEDROCK_PROVIDER_NAME: [BEDROCK_DEFAULT_MODEL], + BEDROCK_PROVIDER_NAME: [], ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES, VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES, + OLLAMA_PROVIDER_NAME: [], } @@ -185,6 +190,28 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: default_model="gpt-4o", default_fast_model="gpt-4o-mini", ), + WellKnownLLMProviderDescriptor( + name=OLLAMA_PROVIDER_NAME, + display_name="Ollama", + api_key_required=False, + api_base_required=True, + api_version_required=False, + custom_config_keys=[ + CustomConfigKey( + name=OLLAMA_API_KEY_CONFIG_KEY, + display_name="Ollama API Key", + description="Optional API key used when connecting to Ollama Cloud (i.e. API base is https://ollama.com).", + is_required=False, + is_secret=True, + ) + ], + model_configurations=fetch_model_configurations_for_provider( + OLLAMA_PROVIDER_NAME + ), + default_model=None, + default_fast_model=None, + default_api_base="http://127.0.0.1:11434", + ), WellKnownLLMProviderDescriptor( name=ANTHROPIC_PROVIDER_NAME, display_name="Anthropic", @@ -248,7 +275,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: model_configurations=fetch_model_configurations_for_provider( BEDROCK_PROVIDER_NAME ), - default_model=BEDROCK_DEFAULT_MODEL, + default_model=None, default_fast_model=None, ), WellKnownLLMProviderDescriptor( diff --git a/backend/onyx/llm/utils.py b/backend/onyx/llm/utils.py index ffcfb46bae4..f8396a5c73d 100644 --- a/backend/onyx/llm/utils.py +++ b/backend/onyx/llm/utils.py @@ -16,6 +16,7 @@ from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage +from sqlalchemy import select from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION @@ -26,6 +27,9 @@ from onyx.configs.model_configs import GEN_AI_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS +from onyx.db.engine.sql_engine import get_session_with_current_tenant +from onyx.db.models import LLMProvider +from onyx.db.models import ModelConfiguration from onyx.file_store.models import ChatFileType from onyx.file_store.models import InMemoryChatFile from onyx.llm.interfaces import LLM @@ -640,6 +644,30 @@ def get_max_input_tokens_from_llm_provider( def model_supports_image_input(model_name: str, model_provider: str) -> bool: + # TODO: Add support to check model config for any provider + # TODO: Circular import means OLLAMA_PROVIDER_NAME is not available here + + if model_provider == "ollama": + try: + with get_session_with_current_tenant() as db_session: + model_config = 
db_session.scalar( + select(ModelConfiguration) + .join( + LLMProvider, + ModelConfiguration.llm_provider_id == LLMProvider.id, + ) + .where( + ModelConfiguration.name == model_name, + LLMProvider.provider == model_provider, + ) + ) + if model_config and model_config.supports_image_input is not None: + return model_config.supports_image_input + except Exception as e: + logger.warning( + f"Failed to query database for {model_provider} model {model_name} image support: {e}" + ) + model_map = get_model_map() try: model_obj = find_model_obj( diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index 59bdd4b91da..ab444dda4ef 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -4,6 +4,7 @@ from datetime import timezone import boto3 +import httpx from botocore.exceptions import BotoCoreError from botocore.exceptions import ClientError from botocore.exceptions import NoCredentialsError @@ -11,10 +12,12 @@ from fastapi import Depends from fastapi import HTTPException from fastapi import Query +from pydantic import ValidationError from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user from onyx.auth.users import current_chat_accessible_user +from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.db.engine.sql_engine import get_session from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import fetch_existing_llm_providers @@ -40,6 +43,9 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderView from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest +from onyx.server.manage.llm.models import OllamaFinalModelResponse +from onyx.server.manage.llm.models import OllamaModelDetails +from onyx.server.manage.llm.models import OllamaModelsRequest from onyx.server.manage.llm.models import TestLLMRequest from onyx.server.manage.llm.models import VisionProviderResponse from onyx.utils.logger import setup_logger @@ -474,3 +480,100 @@ def get_bedrock_available_models( raise HTTPException( status_code=500, detail=f"Unexpected error fetching Bedrock models: {e}" ) + + +def _get_ollama_available_model_names(api_base: str) -> set[str]: + """Fetch available model names from Ollama server.""" + tags_url = f"{api_base}/api/tags" + try: + response = httpx.get(tags_url, timeout=5.0) + response.raise_for_status() + response_json = response.json() + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to fetch Ollama models: {e}", + ) + + models = response_json.get("models", []) + return {model.get("name") for model in models if model.get("name")} + + +@admin_router.post("/ollama/available-models") +def get_ollama_available_models( + request: OllamaModelsRequest, + _: User | None = Depends(current_admin_user), +) -> list[OllamaFinalModelResponse]: + """Fetch the list of available models from an Ollama server.""" + + cleaned_api_base = request.api_base.strip().rstrip("/") + if not cleaned_api_base: + raise HTTPException( + status_code=400, detail="API base URL is required to fetch Ollama models." 
+ ) + + model_names = _get_ollama_available_model_names(cleaned_api_base) + if not model_names: + raise HTTPException( + status_code=400, + detail="No models found from your Ollama server", + ) + + all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = [] + show_url = f"{cleaned_api_base}/api/show" + + for model_name in model_names: + context_limit: int | None = None + supports_image_input: bool | None = None + try: + show_response = httpx.post( + show_url, + json={"model": model_name}, + timeout=5.0, + ) + show_response.raise_for_status() + show_response_json = show_response.json() + + # Parse the response into the expected format + ollama_model_details = OllamaModelDetails.model_validate(show_response_json) + + # Check if this model supports completion/chat + if not ollama_model_details.supports_completion(): + continue + + # Optimistically access. Context limit is stored as "model_architecture.context" = int + architecture = ollama_model_details.model_info.get( + "general.architecture", "" + ) + context_limit = ollama_model_details.model_info.get( + architecture + ".context_length", None + ) + supports_image_input = ollama_model_details.supports_image_input() + except ValidationError as e: + logger.warning( + "Invalid model details from Ollama server", + extra={"model": model_name, "validation_error": str(e)}, + ) + except Exception as e: + logger.warning( + "Failed to fetch Ollama model details", + extra={"model": model_name, "error": str(e)}, + ) + + # If we fail at any point attempting to extract context limit, + # still allow this model to be used with a fallback max context size + if not context_limit: + context_limit = GEN_AI_MODEL_FALLBACK_MAX_TOKENS + + if not supports_image_input: + supports_image_input = False + + all_models_with_context_size_and_vision.append( + OllamaFinalModelResponse( + name=model_name, + max_input_tokens=context_limit, + supports_image_input=supports_image_input, + ) + ) + + return all_models_with_context_size_and_vision diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 027769a0929..4074cb06db1 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -1,3 +1,4 @@ +from typing import Any from typing import TYPE_CHECKING from pydantic import BaseModel @@ -138,8 +139,9 @@ def from_model( class ModelConfigurationUpsertRequest(BaseModel): name: str - is_visible: bool | None = False + is_visible: bool max_input_tokens: int | None = None + supports_image_input: bool | None = None @classmethod def from_model( @@ -149,12 +151,13 @@ def from_model( name=model_configuration_model.name, is_visible=model_configuration_model.is_visible, max_input_tokens=model_configuration_model.max_input_tokens, + supports_image_input=model_configuration_model.supports_image_input, ) class ModelConfigurationView(BaseModel): name: str - is_visible: bool | None = False + is_visible: bool max_input_tokens: int | None = None supports_image_input: bool @@ -196,3 +199,28 @@ class BedrockModelsRequest(BaseModel): aws_secret_access_key: str | None = None aws_bearer_token_bedrock: str | None = None provider_name: str | None = None # Optional: to save models to existing provider + + +class OllamaModelsRequest(BaseModel): + api_base: str + + +class OllamaFinalModelResponse(BaseModel): + name: str + max_input_tokens: int + supports_image_input: bool + + +class OllamaModelDetails(BaseModel): + """Response model for Ollama /api/show endpoint""" + + model_info: dict[str, Any] + 
capabilities: list[str] = [] + + def supports_completion(self) -> bool: + """Check if this model supports completion/chat""" + return "completion" in self.capabilities + + def supports_image_input(self) -> bool: + """Check if this model supports image input""" + return "vision" in self.capabilities diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py index 85d46ef2fe3..7680e571870 100644 --- a/backend/tests/integration/tests/llm_provider/test_llm_provider.py +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py @@ -92,31 +92,22 @@ def fill_max_input_tokens_and_supports_image_input( ) ], ), - # Test the case in which the basic model-configuration is passed, but its visibility is not - # specified (and thus defaulted to False). - # In this case, since the one model-configuration is also the default-model-name, its - # visibility should be overriden to True. - ( - "gpt-4", - [ModelConfigurationUpsertRequest(name="gpt-4")], - [ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True)], - ), # Test the case in which multiple model-configuration are passed. ( "gpt-4", [ - ModelConfigurationUpsertRequest(name="gpt-4"), - ModelConfigurationUpsertRequest(name="gpt-4o"), + ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True), + ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True), ], [ ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True), - ModelConfigurationUpsertRequest(name="gpt-4o"), + ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True), ], ), # Test the case in which duplicate model-configuration are passed. ( "gpt-4", - [ModelConfigurationUpsertRequest(name="gpt-4")] * 4, + [ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True)] * 4, [ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True)], ), # Test the case in which no model-configurations are passed. @@ -132,10 +123,16 @@ def fill_max_input_tokens_and_supports_image_input( # (`ModelConfiguration(name="gpt-4", is_visible=True, max_input_tokens=None)`). 
( "gpt-4", - [ModelConfigurationUpsertRequest(name="gpt-4o", max_input_tokens=4096)], + [ + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=True, max_input_tokens=4096 + ) + ], [ ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True), - ModelConfigurationUpsertRequest(name="gpt-4o", max_input_tokens=4096), + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=True, max_input_tokens=4096 + ), ], ), ], @@ -182,7 +179,11 @@ def test_create_llm_provider( ( ( "gpt-4", - [ModelConfigurationUpsertRequest(name="gpt-4", max_input_tokens=4096)], + [ + ModelConfigurationUpsertRequest( + name="gpt-4", is_visible=True, max_input_tokens=4096 + ) + ], ), [ ModelConfigurationUpsertRequest( @@ -191,7 +192,7 @@ def test_create_llm_provider( ], ( "gpt-4", - [ModelConfigurationUpsertRequest(name="gpt-4")], + [ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True)], ), [ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True)], ), @@ -201,19 +202,25 @@ def test_create_llm_provider( ( "gpt-4", [ - ModelConfigurationUpsertRequest(name="gpt-4"), + ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True), ModelConfigurationUpsertRequest( - name="gpt-4o", max_input_tokens=4096 + name="gpt-4o", is_visible=True, max_input_tokens=4096 ), ], ), [ ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True), - ModelConfigurationUpsertRequest(name="gpt-4o", max_input_tokens=4096), + ModelConfigurationUpsertRequest( + name="gpt-4o", is_visible=True, max_input_tokens=4096 + ), ], ( "gpt-4", - [ModelConfigurationUpsertRequest(name="gpt-4", max_input_tokens=4096)], + [ + ModelConfigurationUpsertRequest( + name="gpt-4", is_visible=True, max_input_tokens=4096 + ) + ], ), [ ModelConfigurationUpsertRequest( @@ -328,8 +335,8 @@ def test_update_model_configurations( ( "gpt-4", [ - ModelConfigurationUpsertRequest(name="gpt-4o"), - ModelConfigurationUpsertRequest(name="gpt-4"), + ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True), + ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True), ], ), ], diff --git a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx index 92e2f8ea858..cc3f52fa682 100644 --- a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx +++ b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx @@ -176,16 +176,16 @@ export function ConfiguredLLMProviderDisplay({ // then the provider is custom - don't use the default // provider descriptor llmProviderDescriptor={ - isSubset( - defaultProviderDesciptor - ? defaultProviderDesciptor.model_configurations.map( - (model_configuration) => model_configuration.name - ) - : [], - provider.model_configurations.map( - (model_configuration) => model_configuration.name - ) - ) + defaultProviderDesciptor && + (defaultProviderDesciptor.model_configurations.length === 0 || + isSubset( + defaultProviderDesciptor.model_configurations.map( + (model_configuration) => model_configuration.name + ), + provider.model_configurations.map( + (model_configuration) => model_configuration.name + ) + )) ? 
defaultProviderDesciptor : null } diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index e3ca2ff9779..38513f7ce91 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -5,6 +5,7 @@ import Text from "@/components/ui/text"; import { Separator } from "@/components/ui/separator"; import { Button } from "@/components/ui/button"; import { Form, Formik } from "formik"; +import type { FormikProps } from "formik"; import { FiTrash } from "react-icons/fi"; import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; import { @@ -13,18 +14,79 @@ import { MultiSelectField, FileUploadFormField, } from "@/components/Field"; -import { useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { useSWRConfig } from "swr"; import { LLMProviderView, - ModelConfigurationUpsertRequest, + ModelConfiguration, WellKnownLLMProviderDescriptor, } from "./interfaces"; +import { dynamicProviderConfigs, fetchModels } from "./utils"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; +function AutoFetchModelsOnEdit({ + llmProviderDescriptor, + existingLlmProvider, + values, + setFieldValue, + setIsFetchingModels, + setFetchModelsError, + setPopup, +}: { + llmProviderDescriptor: WellKnownLLMProviderDescriptor; + existingLlmProvider?: LLMProviderView; + values: any; + setFieldValue: FormikProps["setFieldValue"]; + setIsFetchingModels: (loading: boolean) => void; + setFetchModelsError: (error: string) => void; + setPopup?: (popup: PopupSpec) => void; +}) { + const hasAutoFetchedRef = useRef(false); + + useEffect(() => { + if (!existingLlmProvider) { + return; + } + + const config = dynamicProviderConfigs[llmProviderDescriptor.name]; + if (!config) { + return; + } + + if (hasAutoFetchedRef.current) { + return; + } + + if (config.isDisabled(values)) { + return; + } + + hasAutoFetchedRef.current = true; + fetchModels( + llmProviderDescriptor, + existingLlmProvider, + values, + setFieldValue, + setIsFetchingModels, + setFetchModelsError, + setPopup + ); + }, [ + existingLlmProvider, + llmProviderDescriptor, + setFieldValue, + setFetchModelsError, + setIsFetchingModels, + setPopup, + values, + ]); + + return null; +} + export function LLMProviderUpdateForm({ llmProviderDescriptor, onClose, @@ -53,12 +115,22 @@ export function LLMProviderUpdateForm({ const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + // Helper function to get current model configurations + const getCurrentModelConfigurations = (values: any): ModelConfiguration[] => { + return values.fetched_model_configurations?.length > 0 + ? values.fetched_model_configurations + : llmProviderDescriptor.model_configurations; + }; + // Define the initial values based on the provider's requirements const initialValues = { name: existingLlmProvider?.name || (firstTimeConfiguration ? "Default" : ""), api_key: existingLlmProvider?.api_key ?? "", - api_base: existingLlmProvider?.api_base ?? "", + api_base: + existingLlmProvider?.api_base ?? + llmProviderDescriptor.default_api_base ?? + "", api_version: existingLlmProvider?.api_version ?? 
"", // For Azure OpenAI, combine api_base and api_version into target_uri target_uri: @@ -100,6 +172,9 @@ export function LLMProviderUpdateForm({ .filter((modelConfiguration) => modelConfiguration.is_visible) .map((modelConfiguration) => modelConfiguration.name) as string[]), + // Store fetched model configurations in form state instead of mutating props + fetched_model_configurations: [] as ModelConfiguration[], + // Helper field to force re-renders when model list updates _modelListUpdated: 0, }; @@ -174,6 +249,7 @@ export function LLMProviderUpdateForm({ is_public: Yup.boolean().required(), groups: Yup.array().of(Yup.number()), selected_model_names: Yup.array().of(Yup.string()), + fetched_model_configurations: Yup.array(), }); const customLinkRenderer = ({ href, children }: any) => { @@ -184,99 +260,6 @@ export function LLMProviderUpdateForm({ ); }; - const fetchBedrockModels = async (values: any, setFieldValue: any) => { - if (llmProviderDescriptor.name !== "bedrock") { - return; - } - - setIsFetchingModels(true); - setFetchModelsError(""); - - try { - const response = await fetch("/api/admin/llm/bedrock/available-models", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - aws_region_name: values.custom_config?.AWS_REGION_NAME, - aws_access_key_id: values.custom_config?.AWS_ACCESS_KEY_ID, - aws_secret_access_key: values.custom_config?.AWS_SECRET_ACCESS_KEY, - aws_bearer_token_bedrock: - values.custom_config?.AWS_BEARER_TOKEN_BEDROCK, - provider_name: existingLlmProvider?.name, // Save models to existing provider if editing - }), - }); - - if (!response.ok) { - const errorData = await response.json(); - throw new Error(errorData.detail || "Failed to fetch models"); - } - - const availableModels: string[] = await response.json(); - - // Update the model configurations with the fetched models - const updatedModelConfigs = availableModels.map((modelName) => { - // Find existing configuration to preserve is_visible setting - const existingConfig = llmProviderDescriptor.model_configurations.find( - (config) => config.name === modelName - ); - - return { - name: modelName, - is_visible: existingConfig?.is_visible ?? 
false, // Preserve existing visibility or default to false - max_input_tokens: null, - supports_image_input: false, // Will be determined by the backend - }; - }); - - // Update the descriptor and form values - llmProviderDescriptor.model_configurations = updatedModelConfigs; - - // Update selected model names to only include previously visible models that are available - const previouslySelectedModels = values.selected_model_names || []; - const stillAvailableSelectedModels = previouslySelectedModels.filter( - (modelName: string) => availableModels.includes(modelName) - ); - setFieldValue("selected_model_names", stillAvailableSelectedModels); - - // Set a default model if none is set - if ( - (!values.default_model_name || - !availableModels.includes(values.default_model_name)) && - availableModels.length > 0 - ) { - setFieldValue("default_model_name", availableModels[0]); - } - - // Clear fast model if it's not in the new list - if ( - values.fast_default_model_name && - !availableModels.includes(values.fast_default_model_name) - ) { - setFieldValue("fast_default_model_name", null); - } - - // Force a re-render by updating a timestamp or counter - setFieldValue("_modelListUpdated", Date.now()); - - setPopup?.({ - message: `Successfully fetched ${availableModels.length} models for the selected region (including cross-region inference models).`, - type: "success", - }); - } catch (error) { - const errorMessage = - error instanceof Error ? error.message : "Unknown error"; - setFetchModelsError(errorMessage); - setPopup?.({ - message: `Failed to fetch models: ${errorMessage}`, - type: "error", - }); - } finally { - setIsFetchingModels(false); - } - }; - return ( ({ + model_configurations: getCurrentModelConfigurations(values).map( + (modelConfiguration): ModelConfiguration => ({ name: modelConfiguration.name, is_visible: visibleModels.includes(modelConfiguration.name), - max_input_tokens: null, + max_input_tokens: modelConfiguration.max_input_tokens ?? null, + supports_image_input: modelConfiguration.supports_image_input, }) ), }; @@ -422,175 +406,165 @@ export function LLMProviderUpdateForm({ setSubmitting(false); }} > - {(formikProps) => ( -
- {!firstTimeConfiguration && ( - - )} - - {llmProviderDescriptor.api_key_required && ( - - )} - - {llmProviderDescriptor.name === "azure" ? ( - { + // Get current model configurations for this render + const currentModelConfigurations = getCurrentModelConfigurations( + formikProps.values + ); + const dynamicConfig = + dynamicProviderConfigs[llmProviderDescriptor.name]; + + return ( + + - ) : ( - <> - {llmProviderDescriptor.api_base_required && ( - - )} + {!firstTimeConfiguration && ( + + )} - {llmProviderDescriptor.api_version_required && ( - - )} - - )} + {llmProviderDescriptor.api_key_required && ( + + )} - {llmProviderDescriptor.custom_config_keys?.map((customConfigKey) => { - if (customConfigKey.key_type === "text_input") { - return ( -
+ {llmProviderDescriptor.name === "azure" ? ( + + ) : ( + <> + {llmProviderDescriptor.api_base_required && ( - {customConfigKey.description} - - } - placeholder={customConfigKey.default_value || undefined} - type={customConfigKey.is_secret ? "password" : "text"} + name="api_base" + label="API Base" + placeholder="API Base" /> -
- ); - } else if (customConfigKey.key_type === "file_input") { - return ( - - ); - } else { - throw new Error("Unreachable; there should only exist 2 options"); - } - })} - - {/* Bedrock-specific fetch models button */} - {llmProviderDescriptor.name === "bedrock" && ( -
- - {fetchModelsError && ( - {fetchModelsError} - )} + {llmProviderDescriptor.api_version_required && ( + + )} + + )} - - Enter your AWS region, then click this button to fetch available - Bedrock models. -
- If you're updating your existing provider, you'll need - to click this button to fetch the latest models. -
-
- )} - - {!firstTimeConfiguration && ( - <> - - - {llmProviderDescriptor.model_configurations.length > 0 ? ( - ({ - // don't clean up names here to give admins descriptive names / handle duplicates - // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 - name: modelConfiguration.name, - value: modelConfiguration.name, - }) + {llmProviderDescriptor.custom_config_keys?.map( + (customConfigKey) => { + if (customConfigKey.key_type === "text_input") { + return ( +
+ + {customConfigKey.description} + + } + placeholder={customConfigKey.default_value || undefined} + type={customConfigKey.is_secret ? "password" : "text"} + /> +
+ ); + } else if (customConfigKey.key_type === "file_input") { + return ( + + ); + } else { + throw new Error( + "Unreachable; there should only exist 2 options" + ); + } + } + )} + {/* Fetch models button - automatically shows for supported providers */} + {dynamicConfig && ( +
+ - {llmProviderDescriptor.deployment_name_required && ( - - )} + {fetchModelsError && ( + + {fetchModelsError} + + )} + + + Retrieve the latest available models for this provider. + +
+ )} + + {!firstTimeConfiguration && ( + <> + - {!llmProviderDescriptor.single_model_supported && - (llmProviderDescriptor.model_configurations.length > 0 ? ( + {currentModelConfigurations.length > 0 ? ( ({ // don't clean up names here to give admins descriptive names / handle duplicates // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 @@ -598,127 +572,161 @@ export function LLMProviderUpdateForm({ value: modelConfiguration.name, }) )} - includeDefault maxHeight="max-h-56" /> ) : ( - ))} + )} - <> - - - {showAdvancedOptions && ( - <> - {llmProviderDescriptor.model_configurations.length > 0 && ( -
- ({ - value: modelConfiguration.name, - // don't clean up names here to give admins descriptive names / handle duplicates - // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 - label: modelConfiguration.name, - }) - )} - onChange={(selected) => - formikProps.setFieldValue( - "selected_model_names", - selected - ) - } - /> -
- )} - - + {llmProviderDescriptor.deployment_name_required && ( + )} + + {!llmProviderDescriptor.single_model_supported && + (currentModelConfigurations.length > 0 ? ( + ({ + // don't clean up names here to give admins descriptive names / handle duplicates + // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 + name: modelConfiguration.name, + value: modelConfiguration.name, + }) + )} + includeDefault + maxHeight="max-h-56" + /> + ) : ( + + ))} + + <> + + + {showAdvancedOptions && ( + <> + {currentModelConfigurations.length > 0 && ( +
+ ({ + value: modelConfiguration.name, + // don't clean up names here to give admins descriptive names / handle duplicates + // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 + label: modelConfiguration.name, + }) + )} + onChange={(selected) => + formikProps.setFieldValue( + "selected_model_names", + selected + ) + } + /> +
+ )} + + + )} + - - )} - - {/* NOTE: this is above the test button to make sure it's visible */} - {testError && {testError}} - -
- - {existingLlmProvider && ( - + {existingLlmProvider && ( + - )} -
- - )} + mutate(LLM_PROVIDERS_ADMIN_URL); + onClose(); + }} + > + Delete + + )} + + + ); + }}
); } diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index a19d0c8ce6c..6c32ba96bd8 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -1,3 +1,5 @@ +import { PopupSpec } from "@/components/admin/connectors/Popup"; + export interface CustomConfigKey { name: string; display_name: string; @@ -10,14 +12,11 @@ export interface CustomConfigKey { export type CustomConfigKeyType = "text_input" | "file_input"; -export interface ModelConfigurationUpsertRequest { +export interface ModelConfiguration { name: string; is_visible: boolean; max_input_tokens: number | null; -} - -export interface ModelConfiguration extends ModelConfigurationUpsertRequest { - supports_image_input: boolean; + supports_image_input: boolean | null; } export interface WellKnownLLMProviderDescriptor { @@ -34,6 +33,7 @@ export interface WellKnownLLMProviderDescriptor { model_configurations: ModelConfiguration[]; default_model: string | null; default_fast_model: string | null; + default_api_base: string | null; is_public: boolean; groups: number[]; } @@ -81,3 +81,28 @@ export interface LLMProviderDescriptor { groups: number[]; model_configurations: ModelConfiguration[]; } + +export interface OllamaModelResponse { + name: string; + max_input_tokens: number; + supports_image_input: boolean; +} + +export interface DynamicProviderConfig< + TApiResponse = any, + TProcessedResponse = ModelConfiguration, +> { + endpoint: string; + isDisabled: (values: any) => boolean; + disabledReason: string; + buildRequestBody: (args: { + values: any; + existingLlmProvider?: LLMProviderView; + }) => Record; + processResponse: ( + data: TApiResponse, + llmProviderDescriptor: WellKnownLLMProviderDescriptor + ) => TProcessedResponse[]; + getModelNames: (data: TApiResponse) => string[]; + successMessage: (count: number) => string; +} diff --git a/web/src/app/admin/configuration/llm/utils.ts b/web/src/app/admin/configuration/llm/utils.ts index fc3e80b812d..aa59f7c4b6a 100644 --- a/web/src/app/admin/configuration/llm/utils.ts +++ b/web/src/app/admin/configuration/llm/utils.ts @@ -11,6 +11,14 @@ import { OpenAISVG, QwenIcon, } from "@/components/icons/icons"; +import { + WellKnownLLMProviderDescriptor, + LLMProviderView, + DynamicProviderConfig, + OllamaModelResponse, + ModelConfiguration, +} from "./interfaces"; +import { PopupSpec } from "@/components/admin/connectors/Popup"; export const getProviderIcon = ( providerName: string, @@ -62,3 +70,158 @@ export const getProviderIcon = ( export const isAnthropic = (provider: string, modelName: string) => provider === "anthropic" || modelName.toLowerCase().includes("claude"); + +export const dynamicProviderConfigs: Record< + string, + DynamicProviderConfig +> = { + bedrock: { + endpoint: "/api/admin/llm/bedrock/available-models", + isDisabled: (values) => !values.custom_config?.AWS_REGION_NAME, + disabledReason: "AWS region is required to fetch Bedrock models", + buildRequestBody: ({ values, existingLlmProvider }) => ({ + aws_region_name: values.custom_config?.AWS_REGION_NAME, + aws_access_key_id: values.custom_config?.AWS_ACCESS_KEY_ID, + aws_secret_access_key: values.custom_config?.AWS_SECRET_ACCESS_KEY, + aws_bearer_token_bedrock: values.custom_config?.AWS_BEARER_TOKEN_BEDROCK, + provider_name: existingLlmProvider?.name, + }), + processResponse: (data: string[], llmProviderDescriptor) => + data.map((modelName) => { + const existingConfig = 
llmProviderDescriptor.model_configurations.find( + (config) => config.name === modelName + ); + return { + name: modelName, + is_visible: existingConfig?.is_visible ?? false, + max_input_tokens: null, + supports_image_input: existingConfig?.supports_image_input ?? null, + }; + }), + getModelNames: (data: string[]) => data, + successMessage: (count: number) => + `Successfully fetched ${count} models for the selected region (including cross-region inference models).`, + }, + ollama: { + endpoint: "/api/admin/llm/ollama/available-models", + isDisabled: (values) => !values.api_base, + disabledReason: "API Base is required to fetch Ollama models", + buildRequestBody: ({ values }) => ({ + api_base: values.api_base, + }), + processResponse: (data: OllamaModelResponse[], llmProviderDescriptor) => + data.map((modelData) => { + const existingConfig = llmProviderDescriptor.model_configurations.find( + (config) => config.name === modelData.name + ); + return { + name: modelData.name, + is_visible: existingConfig?.is_visible ?? true, + max_input_tokens: modelData.max_input_tokens, + supports_image_input: modelData.supports_image_input, + }; + }), + getModelNames: (data: OllamaModelResponse[]) => + data.map((model) => model.name), + successMessage: (count: number) => + `Successfully fetched ${count} models from Ollama.`, + }, +}; + +export const fetchModels = async ( + llmProviderDescriptor: WellKnownLLMProviderDescriptor, + existingLlmProvider: LLMProviderView | undefined, + values: any, + setFieldValue: any, + setIsFetchingModels: (loading: boolean) => void, + setFetchModelsError: (error: string) => void, + setPopup?: (popup: PopupSpec) => void +) => { + const config = dynamicProviderConfigs[llmProviderDescriptor.name]; + if (!config) { + return; + } + + if (config.isDisabled(values)) { + setFetchModelsError(config.disabledReason); + return; + } + + setIsFetchingModels(true); + setFetchModelsError(""); + + try { + const response = await fetch(config.endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify( + config.buildRequestBody({ values, existingLlmProvider }) + ), + }); + + if (!response.ok) { + let errorMessage = "Failed to fetch models"; + try { + const errorData = await response.json(); + errorMessage = errorData.detail || errorMessage; + } catch { + // ignore JSON parsing errors and use the fallback message + } + throw new Error(errorMessage); + } + + const availableModels = await response.json(); + const updatedModelConfigs = config.processResponse( + availableModels, + llmProviderDescriptor + ); + const availableModelNames = config.getModelNames(availableModels); + + // Store the updated model configurations in form state instead of mutating props + setFieldValue("fetched_model_configurations", updatedModelConfigs); + + // Update selected model names to only include previously visible models that are available + const previouslySelectedModels = values.selected_model_names || []; + const stillAvailableSelectedModels = previouslySelectedModels.filter( + (modelName: string) => availableModelNames.includes(modelName) + ); + setFieldValue("selected_model_names", stillAvailableSelectedModels); + + // Set a default model if none is set + if ( + (!values.default_model_name || + !availableModelNames.includes(values.default_model_name)) && + availableModelNames.length > 0 + ) { + setFieldValue("default_model_name", availableModelNames[0]); + } + + // Clear fast model if it's not in the new list + if ( + values.fast_default_model_name && + 
!availableModelNames.includes(values.fast_default_model_name) + ) { + setFieldValue("fast_default_model_name", null); + } + + // Force a re-render by updating a timestamp or counter + setFieldValue("_modelListUpdated", Date.now()); + + setPopup?.({ + message: config.successMessage(availableModelNames.length), + type: "success", + }); + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : "Unknown error"; + setFetchModelsError(errorMessage); + setPopup?.({ + message: `Failed to fetch models: ${errorMessage}`, + type: "error", + }); + } finally { + setIsFetchingModels(false); + } +}; diff --git a/web/src/app/chat/components/input/ChatInputBar.tsx b/web/src/app/chat/components/input/ChatInputBar.tsx index dda968de675..70429bad932 100644 --- a/web/src/app/chat/components/input/ChatInputBar.tsx +++ b/web/src/app/chat/components/input/ChatInputBar.tsx @@ -660,7 +660,7 @@ export const ChatInputBar = React.memo(function ChatInputBar({ diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts index 6a042c5617a..1514ac1b94e 100644 --- a/web/src/lib/hooks.ts +++ b/web/src/lib/hooks.ts @@ -955,6 +955,40 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = { "ai21.jamba-instruct-v1:0": "Jamba Instruct", "ai21.j2-ultra-v1": "J2 Ultra", "ai21.j2-mid-v1": "J2 Mid", + + // Ollama cloud models + "gpt-oss:20b-cloud": "gpt-oss 20B Cloud", + "gpt-oss:120b-cloud": "gpt-oss 120B Cloud", + "deepseek-v3.1:671b-cloud": "DeepSeek-v3.1 671B Cloud", + "kimi-k2:1t": "Kimi K2 1T Cloud", + "qwen3-coder:480b-cloud": "Qwen3-Coder 480B Cloud", + + // Ollama models in litellm map (disjoint from ollama's supported model list) + // https://models.litellm.ai --> provider ollama + codegeex4: "CodeGeeX 4", + codegemma: "CodeGemma", + codellama: "CodeLLama", + "deepseek-coder-v2-base": "DeepSeek-Coder-v2 Base", + "deepseek-coder-v2-instruct": "DeepSeek-Coder-v2 Instruct", + "deepseek-coder-v2-lite-base": "DeepSeek-Coder-v2 Lite Base", + "deepseek-coder-v2-lite-instruct": "DeepSeek-Coder-v2 Lite Instruct", + "internlm2_5-20b-chat": "InternLM 2.5 20B Chat", + llama2: "Llama 2", + "llama2-uncensored": "Llama 2 Uncensored", + "llama2:13b": "Llama 2 13B", + "llama2:70b": "Llama 2 70B", + "llama2:7b": "Llama 2 7B", + llama3: "Llama 3", + "llama3:70b": "Llama 3 70B", + "llama3:8b": "Llama 3 8B", + mistral: "Mistral", // Mistral 7b + "mistral-7B-Instruct-v0.1": "Mistral 7B Instruct v0.1", + "mistral-7B-Instruct-v0.2": "Mistral 7B Instruct v0.2", + "mistral-large-instruct-2407": "Mistral Large Instruct 24.07", + "mixtral-8x22B-Instruct-v0.1": "Mixtral 8x22B Instruct v0.1", + "mixtral8x7B-Instruct-v0.1": "Mixtral 8x7B Instruct v0.1", + "orca-mini": "Orca Mini", + vicuna: "Vicuna", }; export function getDisplayNameForModel(modelName: string): string { @@ -972,29 +1006,6 @@ export function getDisplayNameForModel(modelName: string): string { return MODEL_DISPLAY_NAMES[modelName] || modelName; } -export const defaultModelsByProvider: { [name: string]: string[] } = { - openai: [ - "gpt-4", - "gpt-4o", - "gpt-4o-mini", - "gpt-4.1", - "o3-mini", - "o1-mini", - "o1", - "o4-mini", - "o3", - ], - bedrock: [ - "meta.llama3-1-70b-instruct-v1:0", - "meta.llama3-1-8b-instruct-v1:0", - "anthropic.claude-3-opus-20240229-v1:0", - "mistral.mistral-large-2402-v1:0", - "anthropic.claude-3-5-sonnet-20241022-v2:0", - "anthropic.claude-3-7-sonnet-20250219-v1:0", - ], - anthropic: ["claude-3-opus-20240229", "claude-3-5-sonnet-20241022"], -}; - // Get source metadata for configured sources - deduplicated by source type function 
getConfiguredSources( availableSources: ValidSources[]