diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index dc36ce649cb..b5b52f59014 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -142,19 +142,20 @@ def put_llm_provider( detail=f"LLM Provider with name {llm_provider.name} already exists", ) - # Ensure default_model_name and fast_default_model_name are in display_model_names - # This is necessary for custom models and Bedrock/Azure models - if llm_provider.display_model_names is None: - llm_provider.display_model_names = [] - - if llm_provider.default_model_name not in llm_provider.display_model_names: - llm_provider.display_model_names.append(llm_provider.default_model_name) - - if ( - llm_provider.fast_default_model_name - and llm_provider.fast_default_model_name not in llm_provider.display_model_names - ): - llm_provider.display_model_names.append(llm_provider.fast_default_model_name) + if llm_provider.display_model_names is not None: + # Ensure default_model_name and fast_default_model_name are in display_model_names + # This is necessary for custom models and Bedrock/Azure models + if llm_provider.default_model_name not in llm_provider.display_model_names: + llm_provider.display_model_names.append(llm_provider.default_model_name) + + if ( + llm_provider.fast_default_model_name + and llm_provider.fast_default_model_name + not in llm_provider.display_model_names + ): + llm_provider.display_model_names.append( + llm_provider.fast_default_model_name + ) try: return upsert_llm_provider( diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 5eba1e66f87..ec50669d0bb 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -4,8 +4,12 @@ import pytest from sqlalchemy.orm import Session +from onyx.auth.schemas import UserRole from onyx.db.engine import get_session_context_manager from onyx.db.search_settings import get_current_search_settings +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.managers.user import build_email +from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.reset import reset_all from tests.integration.common_utils.reset import reset_all_multitenant @@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None: return None +@pytest.fixture +def admin_user() -> DATestUser | None: + try: + return UserManager.create(name="admin_user") + except Exception: + pass + + try: + return UserManager.login_as_user( + DATestUser( + id="", + email=build_email("admin_user"), + password=DEFAULT_PASSWORD, + headers=GENERAL_HEADERS, + role=UserRole.ADMIN, + is_active=True, + ) + ) + except Exception: + pass + + return None + + @pytest.fixture def reset_multitenant() -> None: reset_all_multitenant() diff --git a/backend/tests/integration/openai_assistants_api/conftest.py b/backend/tests/integration/openai_assistants_api/conftest.py index 37ada5cd87b..5fc6660ee62 100644 --- a/backend/tests/integration/openai_assistants_api/conftest.py +++ b/backend/tests/integration/openai_assistants_api/conftest.py @@ -7,40 +7,12 @@ from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.managers.llm_provider import LLMProviderManager -from tests.integration.common_utils.managers.user import build_email -from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD -from tests.integration.common_utils.managers.user import UserManager -from tests.integration.common_utils.managers.user import UserRole from tests.integration.common_utils.test_models import DATestLLMProvider from tests.integration.common_utils.test_models import DATestUser BASE_URL = f"{API_SERVER_URL}/openai-assistants" -@pytest.fixture -def admin_user() -> DATestUser | None: - try: - return UserManager.create("admin_user") - except Exception: - pass - - try: - return UserManager.login_as_user( - DATestUser( - id="", - email=build_email("admin_user"), - password=DEFAULT_PASSWORD, - headers=GENERAL_HEADERS, - role=UserRole.ADMIN, - is_active=True, - ) - ) - except Exception: - pass - - return None - - @pytest.fixture def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider: return LLMProviderManager.create(user_performing_action=admin_user) diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py new file mode 100644 index 00000000000..4540f24b239 --- /dev/null +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py @@ -0,0 +1,120 @@ +import uuid + +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.test_models import DATestUser + + +_DEFAULT_MODELS = ["gpt-4", "gpt-4o"] + + +def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None: + """Utility function to fetch an LLM provider by ID""" + response = requests.get( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + ) + assert response.status_code == 200 + providers = response.json() + return next((p for p in providers if p["id"] == provider_id), None) + + +def test_create_llm_provider_without_display_model_names( + admin_user: DATestUser, +) -> None: + """Test creating an LLM provider without specifying + display_model_names and verify it's null in response""" + # Create LLM provider without model_names + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "name": str(uuid.uuid4()), + "provider": "openai", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + + # Verify model_names is None/null + assert provider_data is not None + assert provider_data["model_names"] == _DEFAULT_MODELS + assert provider_data["default_model_name"] == _DEFAULT_MODELS[0] + assert provider_data["display_model_names"] is None + + +def test_update_llm_provider_model_names(admin_user: DATestUser) -> None: + """Test updating an LLM provider's model_names""" + # First create provider without model_names + name = str(uuid.uuid4()) + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "name": name, + "provider": "openai", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": [_DEFAULT_MODELS[0]], + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + + # Update with model_names + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "id": created_provider["id"], + "name": name, + "provider": created_provider["provider"], + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + + # Verify update + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is not None + assert provider_data["model_names"] == _DEFAULT_MODELS + + +def test_delete_llm_provider(admin_user: DATestUser) -> None: + """Test deleting an LLM provider""" + # Create a provider + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "name": "test-provider-delete", + "provider": "openai", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + + # Delete the provider + response = requests.delete( + f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}", + headers=admin_user.headers, + ) + assert response.status_code == 200 + + # Verify provider is deleted by checking it's not in the list + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is None