Skip to content

Commit 939835d

Browse files
authored
Merge pull request #150 from oracle-samples/110-modelapi-list-generation
New endpoint for listing models
2 parents 7f66f35 + 0bfdf02 commit 939835d

File tree

4 files changed

+76
-7
lines changed

4 files changed

+76
-7
lines changed

src/client/content/config/models.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ def get_models(model_type: ModelTypeType = None, force: bool = False) -> dict[st
5858
state[enable_key] = {}
5959

6060

61+
@st.cache_data
62+
def get_model_apis(model_type: ModelTypeType = None) -> list:
63+
"""Get list of valid APIs; function for Streamlit caching"""
64+
response = api_call.get(
65+
endpoint="v1/models/api",
66+
params={"model_type": model_type},
67+
)
68+
return response
69+
70+
6171
def create_model(model: Model) -> None:
6272
"""Add either Language Model or Embed Model"""
6373
api_call.post(
@@ -118,10 +128,7 @@ def edit_model(model_type: ModelTypeType, action: Literal["add", "edit"], model_
118128
key="add_model_name",
119129
disabled=action == "edit",
120130
)
121-
if model_type == "ll":
122-
api_values = list({models["api"] for models in state.ll_model_config.values()})
123-
else:
124-
api_values = list({models["api"] for models in state.embed_model_config.values()})
131+
api_values = get_model_apis(model_type)
125132
api_index = next((i for i, item in enumerate(api_values) if item == model.api), None)
126133
model.api = st.selectbox(
127134
"API:",

src/common/schema.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# spell-checker:ignore ollama, hnsw, mult, ocid, testset
66

77
from typing import Optional, Literal, Union
8-
from pydantic import BaseModel, Field, PrivateAttr
8+
from pydantic import BaseModel, Field, PrivateAttr, model_validator
99

1010
from langchain_core.messages import ChatMessage
1111
import oracledb
@@ -18,6 +18,24 @@
1818
DistanceMetrics = Literal["COSINE", "EUCLIDEAN_DISTANCE", "DOT_PRODUCT"]
1919
IndexTypes = Literal["HNSW", "IVF"]
2020

21+
# ModelAPIs
22+
EmbedAPI = Literal[
23+
"OllamaEmbeddings",
24+
"OCIGenAIEmbeddings",
25+
"CompatOpenAIEmbeddings",
26+
"OpenAIEmbeddings",
27+
"CohereEmbeddings",
28+
"HuggingFaceEndpointEmbeddings",
29+
]
30+
LlAPI = Literal[
31+
"ChatOllama",
32+
"ChatOCIGenAI",
33+
"CompatOpenAI",
34+
"Perplexity",
35+
"OpenAI",
36+
"Cohere",
37+
]
38+
2139

2240
#####################################################
2341
# Database
@@ -110,6 +128,21 @@ class Model(ModelAccess, LanguageModelParameters, EmbeddingModelParameters):
110128
openai_compat: bool = Field(default=True, description="Is the API OpenAI compatible?")
111129
status: Statuses = Field(default="UNVERIFIED", description="Status (read-only)", readOnly=True)
112130

131+
@model_validator(mode="after")
132+
def check_api_matches_type(self):
133+
"""Validate valid API"""
134+
ll_apis = LlAPI.__args__
135+
embed_apis = EmbedAPI.__args__
136+
137+
if not self.api or self.api == "unset":
138+
return self
139+
140+
if self.type == "ll" and self.api not in ll_apis:
141+
raise ValueError(f"API '{self.api}' is not valid for type 'll'. Must be one of: {ll_apis}")
142+
if self.type == "embed" and self.api not in embed_apis:
143+
raise ValueError(f"API '{self.api}' is not valid for type 'embed'. Must be one of: {embed_apis}")
144+
return self
145+
113146

114147
#####################################################
115148
# Oracle Cloud Infrastructure

src/server/endpoints.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from urllib.parse import urlparse
1313
from pathlib import Path
1414
import shutil
15-
from typing import AsyncGenerator, Literal, Optional
15+
from typing import AsyncGenerator, Literal, Optional, get_args
1616
import time
1717
import requests
1818
from pydantic import HttpUrl
@@ -309,6 +309,19 @@ async def split_embed(
309309
#################################################
310310
# models Endpoints
311311
#################################################
312+
@auth.get("/v1/models/api", description="Get support model APIs", response_model=list)
313+
async def models_list_api(
314+
model_type: Optional[schema.ModelTypeType] = Query(None),
315+
) -> list[schema.Model]:
316+
"""List all models APIs after applying filters if specified"""
317+
logger.debug("Received models_list_api - type: %s", model_type)
318+
if model_type == "ll":
319+
return list(get_args(schema.LlAPI))
320+
elif model_type == "embed":
321+
return list(get_args(schema.EmbedAPI))
322+
else:
323+
return list()
324+
312325
@auth.get("/v1/models", description="Get all models", response_model=list[schema.Model])
313326
async def models_list(
314327
model_type: Optional[schema.ModelTypeType] = Query(None),

tests/server/test_endpoints_models.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
# spell-checker: disable
66
# pylint: disable=import-error
77

8-
from typing import Any, Dict
8+
from typing import Any, Dict, get_args
99
import pytest
1010
from fastapi.testclient import TestClient
1111
from conftest import TEST_HEADERS, TEST_BAD_HEADERS
12+
from common.schema import LlAPI, EmbedAPI
1213

1314

1415
#############################################################################
@@ -55,6 +56,21 @@ def test_no_auth(self, client: TestClient, test_case: Dict[str, Any]) -> None:
5556
class TestEndpoints:
5657
"""Test endpoints with AuthN"""
5758

59+
@pytest.mark.parametrize(
60+
"model_type,expected",
61+
[
62+
("ll", list(get_args(LlAPI))),
63+
("embed", list(get_args(EmbedAPI))),
64+
(None, []),
65+
],
66+
)
67+
def test_models_list_api(self, client: TestClient, model_type, expected):
68+
"""Get a list of model APIs to use with tests"""
69+
params = {"model_type": model_type} if model_type else {}
70+
response = client.get("/v1/models/api", headers=TEST_HEADERS, params=params)
71+
assert response.status_code == 200
72+
assert sorted(response.json()) == sorted(expected)
73+
5874
def models_list(self, client: TestClient):
5975
"""Get a list of bootstrapped models to use with tests"""
6076
response = client.get("/v1/models", headers=TEST_HEADERS)

0 commit comments

Comments
 (0)