1
1
import uuid
2
+ from typing import Any
2
3
3
4
import pytest
4
5
import requests
5
6
from requests .models import Response
6
7
7
8
from onyx .llm .utils import get_max_input_tokens
9
+ from onyx .llm .utils import model_supports_image_input
8
10
from onyx .server .manage .llm .models import ModelConfigurationUpsertRequest
9
11
from tests .integration .common_utils .constants import API_SERVER_URL
10
12
from tests .integration .common_utils .managers .user import UserManager
11
13
from tests .integration .common_utils .test_models import DATestUser
12
14
13
15
14
- _DEFAULT_MODELS = ["gpt-4" , "gpt-4o" ]
15
-
16
-
17
16
def _get_provider_by_id (admin_user : DATestUser , provider_id : str ) -> dict | None :
18
17
"""Utility function to fetch an LLM provider by ID"""
19
18
response = requests .get (
@@ -40,24 +39,32 @@ def assert_response_is_equivalent(
40
39
41
40
assert provider_data ["default_model_name" ] == default_model_name
42
41
43
- def fill_max_input_tokens_if_none (
42
+ def fill_max_input_tokens_and_supports_image_input (
44
43
req : ModelConfigurationUpsertRequest ,
45
- ) -> ModelConfigurationUpsertRequest :
46
- return ModelConfigurationUpsertRequest (
44
+ ) -> dict [ str , Any ] :
45
+ filled_with_max_input_tokens = ModelConfigurationUpsertRequest (
47
46
name = req .name ,
48
47
is_visible = req .is_visible ,
49
48
max_input_tokens = req .max_input_tokens
50
49
or get_max_input_tokens (
51
50
model_name = req .name , model_provider = default_model_name
52
51
),
53
52
)
53
+ return {
54
+ ** filled_with_max_input_tokens .model_dump (),
55
+ "supports_image_input" : model_supports_image_input (
56
+ req .name , created_provider ["provider" ]
57
+ ),
58
+ }
54
59
55
60
actual = set (
56
61
tuple (model_configuration .items ())
57
62
for model_configuration in provider_data ["model_configurations" ]
58
63
)
59
64
expected = set (
60
- tuple (fill_max_input_tokens_if_none (model_configuration ).dict ().items ())
65
+ tuple (
66
+ fill_max_input_tokens_and_supports_image_input (model_configuration ).items ()
67
+ )
61
68
for model_configuration in model_configurations
62
69
)
63
70
assert actual == expected
@@ -150,7 +157,7 @@ def test_create_llm_provider(
150
157
"api_key" : "sk-000000000000000000000000000000000000000000000000" ,
151
158
"default_model_name" : default_model_name ,
152
159
"model_configurations" : [
153
- model_configuration .dict ()
160
+ model_configuration .model_dump ()
154
161
for model_configuration in model_configurations
155
162
],
156
163
"is_public" : True ,
0 commit comments