diff --git a/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py b/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py
index afa60b19536..341a5a5389c 100644
--- a/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py
+++ b/backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py
@@ -24,7 +24,7 @@
depends_on = None
-class ModelConfiguration(BaseModel):
+class _SimpleModelConfiguration(BaseModel):
# Configure model to read from attributes
model_config = ConfigDict(from_attributes=True)
@@ -82,7 +82,7 @@ def upgrade() -> None:
)
model_configurations = [
- ModelConfiguration.model_validate(model_configuration)
+ _SimpleModelConfiguration.model_validate(model_configuration)
for model_configuration in connection.execute(
sa.select(
model_configuration_table.c.id,
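
The rename above is presumably meant to avoid confusion with the application's own `ModelConfiguration` model; the leading underscore marks the class as private to this migration. A minimal sketch of the migration-local schema, with fields inferred from the `SELECT` it feeds (the real definition may declare more columns):

```python
# Hedged sketch -- fields inferred from the surrounding SELECT; the real
# class in the migration may carry additional columns.
from pydantic import BaseModel, ConfigDict


class _SimpleModelConfiguration(BaseModel):
    # from_attributes lets model_validate() read SQLAlchemy Row objects
    # directly, rather than requiring plain dicts.
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
```
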
diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py
index e6446961b0a..3f559ff0e66 100644
--- a/backend/onyx/server/manage/llm/models.py
+++ b/backend/onyx/server/manage/llm/models.py
@@ -4,6 +4,7 @@
from pydantic import Field
from onyx.llm.utils import get_max_input_tokens
+from onyx.llm.utils import model_supports_image_input
if TYPE_CHECKING:
@@ -152,6 +153,7 @@ class ModelConfigurationView(BaseModel):
name: str
is_visible: bool | None = False
max_input_tokens: int | None = None
+ supports_image_input: bool
@classmethod
def from_model(
@@ -166,6 +168,10 @@ def from_model(
or get_max_input_tokens(
model_name=model_configuration_model.name, model_provider=provider_name
),
+ supports_image_input=model_supports_image_input(
+ model_name=model_configuration_model.name,
+ model_provider=provider_name,
+ ),
)
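
`model_supports_image_input` is imported from `onyx.llm.utils`, but its body is not part of this diff. A plausible sketch of such a capability lookup, assuming it delegates to litellm's `supports_vision` helper (the actual implementation may differ):

```python
# Hypothetical sketch only -- the real helper lives in onyx/llm/utils.py and
# is not shown in this diff. Assumes litellm's supports_vision() lookup.
import litellm


def model_supports_image_input(model_name: str, model_provider: str) -> bool:
    try:
        # litellm resolves capabilities from its model registry; the
        # "provider/model" form is its convention for qualified names.
        return litellm.supports_vision(model=f"{model_provider}/{model_name}")
    except Exception:
        # Treat unknown models as text-only rather than failing the request.
        return False
```

Computing the flag server-side lets every frontend consumer drop its own name-based heuristic, which is exactly what the `web/` changes below do.
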
diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py
index 073c28eba71..85d46ef2fe3 100644
--- a/backend/tests/integration/tests/llm_provider/test_llm_provider.py
+++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py
@@ -1,19 +1,18 @@
import uuid
+from typing import Any
import pytest
import requests
from requests.models import Response
from onyx.llm.utils import get_max_input_tokens
+from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
-_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
-
-
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
"""Utility function to fetch an LLM provider by ID"""
response = requests.get(
@@ -40,10 +39,10 @@ def assert_response_is_equivalent(
assert provider_data["default_model_name"] == default_model_name
- def fill_max_input_tokens_if_none(
+ def fill_max_input_tokens_and_supports_image_input(
req: ModelConfigurationUpsertRequest,
- ) -> ModelConfigurationUpsertRequest:
- return ModelConfigurationUpsertRequest(
+ ) -> dict[str, Any]:
+ filled_with_max_input_tokens = ModelConfigurationUpsertRequest(
name=req.name,
is_visible=req.is_visible,
max_input_tokens=req.max_input_tokens
@@ -51,13 +50,21 @@ def fill_max_input_tokens_if_none(
model_name=req.name, model_provider=default_model_name
),
)
+ return {
+ **filled_with_max_input_tokens.model_dump(),
+ "supports_image_input": model_supports_image_input(
+ req.name, created_provider["provider"]
+ ),
+ }
actual = set(
tuple(model_configuration.items())
for model_configuration in provider_data["model_configurations"]
)
expected = set(
- tuple(fill_max_input_tokens_if_none(model_configuration).dict().items())
+ tuple(
+ fill_max_input_tokens_and_supports_image_input(model_configuration).items()
+ )
for model_configuration in model_configurations
)
assert actual == expected
@@ -150,7 +157,7 @@ def test_create_llm_provider(
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": default_model_name,
"model_configurations": [
- model_configuration.dict()
+ model_configuration.model_dump()
for model_configuration in model_configurations
],
"is_public": True,
diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx
index b5b38972229..c525e79fdeb 100644
--- a/web/src/app/admin/assistants/AssistantEditor.tsx
+++ b/web/src/app/admin/assistants/AssistantEditor.tsx
@@ -25,8 +25,8 @@ import { getDisplayNameForModel, useLabels } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import {
- checkLLMSupportsImageInput,
destructureValue,
+ modelSupportsImageInput,
structureValue,
} from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
@@ -139,6 +139,7 @@ export function AssistantEditor({
admin?: boolean;
}) {
const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
+
const router = useRouter();
const searchParams = useSearchParams();
const isAdminPage = searchParams?.get("admin") === "true";
@@ -643,7 +644,8 @@ export function AssistantEditor({
// model must support image input for image generation
// to work
- const currentLLMSupportsImageOutput = checkLLMSupportsImageInput(
+ const currentLLMSupportsImageOutput = modelSupportsImageInput(
+ llmProviders,
values.llm_model_version_override || defaultModelName || ""
);
diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts
index 17c8e43838b..518967c1bcd 100644
--- a/web/src/app/admin/configuration/llm/interfaces.ts
+++ b/web/src/app/admin/configuration/llm/interfaces.ts
@@ -71,6 +71,7 @@ export interface ModelConfiguration {
name: string;
is_visible: boolean;
max_input_tokens: number | null;
+ supports_image_input: boolean;
}
export interface VisionProvider extends LLMProviderView {
diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx
index 01c09f6274a..7371991e5ea 100644
--- a/web/src/app/chat/ChatPage.tsx
+++ b/web/src/app/chat/ChatPage.tsx
@@ -90,8 +90,8 @@ import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import {
- checkLLMSupportsImageInput,
getFinalLLM,
+ modelSupportsImageInput,
structureValue,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
@@ -1952,7 +1952,7 @@ export function ChatPage({
liveAssistant,
llmManager.currentLlm
);
- const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
+ const llmAcceptsImages = modelSupportsImageInput(llmProviders, llmModel);
const imageFiles = acceptedFiles.filter((file) =>
file.type.startsWith("image/")
diff --git a/web/src/app/chat/input/LLMPopover.tsx b/web/src/app/chat/input/LLMPopover.tsx
index 476d4e61157..88dead2732a 100644
--- a/web/src/app/chat/input/LLMPopover.tsx
+++ b/web/src/app/chat/input/LLMPopover.tsx
@@ -6,7 +6,7 @@ import {
} from "@/components/ui/popover";
import { getDisplayNameForModel } from "@/lib/hooks";
import {
- checkLLMSupportsImageInput,
+ modelSupportsImageInput,
destructureValue,
structureValue,
} from "@/lib/llm/utils";
@@ -175,7 +175,10 @@ export default function LLMPopover({
>
{llmOptions.map(({ name, icon, value }, index) => {
- if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
+ if (
+ !requiresImageGeneration ||
+ modelSupportsImageInput(llmProviders, name)
+ ) {
return (