Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions backend/onyx/server/manage/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,19 +142,20 @@ def put_llm_provider(
detail=f"LLM Provider with name {llm_provider.name} already exists",
)

# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.display_model_names is None:
llm_provider.display_model_names = []

if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)

if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
if llm_provider.display_model_names is not None:
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)

if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name
not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(
llm_provider.fast_default_model_name
)

try:
return upsert_llm_provider(
Expand Down
28 changes: 28 additions & 0 deletions backend/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import pytest
from sqlalchemy.orm import Session

from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
Expand Down Expand Up @@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
return None


@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create(name="admin_user")
except Exception:
pass

try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
except Exception:
pass

return None


@pytest.fixture
def reset_multitenant() -> None:
reset_all_multitenant()
28 changes: 0 additions & 28 deletions backend/tests/integration/openai_assistants_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,12 @@
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user import UserRole
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser

BASE_URL = f"{API_SERVER_URL}/openai-assistants"


@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create("admin_user")
except Exception:
pass

try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
except Exception:
pass

return None


@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)
Expand Down
120 changes: 120 additions & 0 deletions backend/tests/integration/tests/llm_provider/test_llm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import uuid

import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser


_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]


def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
"""Utility function to fetch an LLM provider by ID"""
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
)
assert response.status_code == 200
providers = response.json()
return next((p for p in providers if p["id"] == provider_id), None)


def test_create_llm_provider_without_display_model_names(
admin_user: DATestUser,
) -> None:
"""Test creating an LLM provider without specifying
display_model_names and verify it's null in response"""
# Create LLM provider without model_names
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": str(uuid.uuid4()),
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()
provider_data = _get_provider_by_id(admin_user, created_provider["id"])

# Verify model_names is None/null
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
assert provider_data["display_model_names"] is None


def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
"""Test updating an LLM provider's model_names"""
# First create provider without model_names
name = str(uuid.uuid4())
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": name,
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": [_DEFAULT_MODELS[0]],
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()

# Update with model_names
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200

# Verify update
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS


def test_delete_llm_provider(admin_user: DATestUser) -> None:
"""Test deleting an LLM provider"""
# Create a provider
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": "test-provider-delete",
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()

# Delete the provider
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
headers=admin_user.headers,
)
assert response.status_code == 200

# Verify provider is deleted by checking it's not in the list
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is None
Loading