Skip to content

Commit 22aca0b

Browse files
committed
Merge remote-tracking branch 'origin/main' into 157-add-unstructured
2 parents 24775e4 + 939835d commit 22aca0b

File tree

7 files changed

+102
-8
lines changed

7 files changed

+102
-8
lines changed

docs/content/help/troubleshooting/_index.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,24 @@ Set a rate limit based on the API Key restrictions.
3535
During the Evaluation in the **Testbed**, a database error occurs: `DPY-4011: the database or network closed the connection`
3636

3737
**_Solution_**:
38-
Increase the memory of the vector_memory_size. If this is an Oracle Autonomous Database, scale up the CPU.
38+
Increase the memory of the vector_memory_size. If this is an Oracle Autonomous Database, scale up the CPU.
39+
40+
## Autonomous Database behind VPN
41+
42+
**_Problem_**:
43+
Connection to an Autonomous database while inside a VPN fails.
44+
45+
**_Solution_**:
46+
Update the database connection string to include a `https_proxy` and `https_proxy_port`.
47+
48+
For example:
49+
50+
```text
51+
myadb_high = (
52+
description=(
53+
address=
54+
(protocol=tcps)(port=1522)
55+
(https_proxy=<proxy_host>)(https_proxy_port=<proxy_port>) # <-- Add
56+
(host=<adb_host>)
57+
)(connect_data=(service_name=s<service_name>))(security=(ssl_server_dn_match=yes))
58+
)```

src/client/content/config/models.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ def get_models(model_type: ModelTypeType = None, force: bool = False) -> dict[st
5858
state[enable_key] = {}
5959

6060

61+
@st.cache_data
62+
def get_model_apis(model_type: ModelTypeType = None) -> list:
63+
"""Get list of valid APIs; function for Streamlit caching"""
64+
response = api_call.get(
65+
endpoint="v1/models/api",
66+
params={"model_type": model_type},
67+
)
68+
return response
69+
70+
6171
def create_model(model: Model) -> None:
6272
"""Add either Language Model or Embed Model"""
6373
api_call.post(
@@ -94,6 +104,10 @@ def delete_model(model: Model) -> None:
94104
logger.info("Model deleted: %s", model.name)
95105
get_models(model.type, force=True)
96106

107+
# If deleted model is the set model; unset the user settings
108+
if state.user_settings["ll_model"]["model"] == model.name:
109+
state.user_settings["ll_model"]["model"] = None
110+
97111

98112
@st.dialog("Model Configuration", width="large")
99113
def edit_model(model_type: ModelTypeType, action: Literal["add", "edit"], model_name: ModelNameType = None) -> None:
@@ -114,10 +128,7 @@ def edit_model(model_type: ModelTypeType, action: Literal["add", "edit"], model_
114128
key="add_model_name",
115129
disabled=action == "edit",
116130
)
117-
if model_type == "ll":
118-
api_values = list({models["api"] for models in state.ll_model_config.values()})
119-
else:
120-
api_values = list({models["api"] for models in state.embed_model_config.values()})
131+
api_values = get_model_apis(model_type)
121132
api_index = next((i for i, item in enumerate(api_values) if item == model.api), None)
122133
model.api = st.selectbox(
123134
"API:",

src/client/media/favicon.png

1.49 KB
Loading

src/common/schema.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# spell-checker:ignore ollama, hnsw, mult, ocid, testset
66

77
from typing import Optional, Literal, Union
8-
from pydantic import BaseModel, Field, PrivateAttr
8+
from pydantic import BaseModel, Field, PrivateAttr, model_validator
99

1010
from langchain_core.messages import ChatMessage
1111
import oracledb
@@ -18,6 +18,24 @@
1818
DistanceMetrics = Literal["COSINE", "EUCLIDEAN_DISTANCE", "DOT_PRODUCT"]
1919
IndexTypes = Literal["HNSW", "IVF"]
2020

21+
# ModelAPIs
22+
EmbedAPI = Literal[
23+
"OllamaEmbeddings",
24+
"OCIGenAIEmbeddings",
25+
"CompatOpenAIEmbeddings",
26+
"OpenAIEmbeddings",
27+
"CohereEmbeddings",
28+
"HuggingFaceEndpointEmbeddings",
29+
]
30+
LlAPI = Literal[
31+
"ChatOllama",
32+
"ChatOCIGenAI",
33+
"CompatOpenAI",
34+
"Perplexity",
35+
"OpenAI",
36+
"Cohere",
37+
]
38+
2139

2240
#####################################################
2341
# Database
@@ -110,6 +128,21 @@ class Model(ModelAccess, LanguageModelParameters, EmbeddingModelParameters):
110128
openai_compat: bool = Field(default=True, description="Is the API OpenAI compatible?")
111129
status: Statuses = Field(default="UNVERIFIED", description="Status (read-only)", readOnly=True)
112130

131+
@model_validator(mode="after")
132+
def check_api_matches_type(self):
133+
"""Validate valid API"""
134+
ll_apis = LlAPI.__args__
135+
embed_apis = EmbedAPI.__args__
136+
137+
if not self.api or self.api == "unset":
138+
return self
139+
140+
if self.type == "ll" and self.api not in ll_apis:
141+
raise ValueError(f"API '{self.api}' is not valid for type 'll'. Must be one of: {ll_apis}")
142+
if self.type == "embed" and self.api not in embed_apis:
143+
raise ValueError(f"API '{self.api}' is not valid for type 'embed'. Must be one of: {embed_apis}")
144+
return self
145+
113146

114147
#####################################################
115148
# Oracle Cloud Infrastructure

src/launch_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def main() -> None:
3838
"""Streamlit GUI"""
3939
st.set_page_config(
4040
page_title="Oracle AI Optimizer and Toolkit",
41+
page_icon="client/media/favicon.png",
4142
layout="wide",
4243
initial_sidebar_state="expanded",
4344
menu_items={

src/server/endpoints.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from urllib.parse import urlparse
1313
from pathlib import Path
1414
import shutil
15-
from typing import AsyncGenerator, Literal, Optional
15+
from typing import AsyncGenerator, Literal, Optional, get_args
1616
import time
1717
import requests
1818
from pydantic import HttpUrl
@@ -309,6 +309,19 @@ async def split_embed(
309309
#################################################
310310
# models Endpoints
311311
#################################################
312+
@auth.get("/v1/models/api", description="Get support model APIs", response_model=list)
313+
async def models_list_api(
314+
model_type: Optional[schema.ModelTypeType] = Query(None),
315+
) -> list[schema.Model]:
316+
"""List all models APIs after applying filters if specified"""
317+
logger.debug("Received models_list_api - type: %s", model_type)
318+
if model_type == "ll":
319+
return list(get_args(schema.LlAPI))
320+
elif model_type == "embed":
321+
return list(get_args(schema.EmbedAPI))
322+
else:
323+
return list()
324+
312325
@auth.get("/v1/models", description="Get all models", response_model=list[schema.Model])
313326
async def models_list(
314327
model_type: Optional[schema.ModelTypeType] = Query(None),

tests/server/test_endpoints_models.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
# spell-checker: disable
66
# pylint: disable=import-error
77

8-
from typing import Any, Dict
8+
from typing import Any, Dict, get_args
99
import pytest
1010
from fastapi.testclient import TestClient
1111
from conftest import TEST_HEADERS, TEST_BAD_HEADERS
12+
from common.schema import LlAPI, EmbedAPI
1213

1314

1415
#############################################################################
@@ -55,6 +56,21 @@ def test_no_auth(self, client: TestClient, test_case: Dict[str, Any]) -> None:
5556
class TestEndpoints:
5657
"""Test endpoints with AuthN"""
5758

59+
@pytest.mark.parametrize(
60+
"model_type,expected",
61+
[
62+
("ll", list(get_args(LlAPI))),
63+
("embed", list(get_args(EmbedAPI))),
64+
(None, []),
65+
],
66+
)
67+
def test_models_list_api(self, client: TestClient, model_type, expected):
68+
"""Get a list of model APIs to use with tests"""
69+
params = {"model_type": model_type} if model_type else {}
70+
response = client.get("/v1/models/api", headers=TEST_HEADERS, params=params)
71+
assert response.status_code == 200
72+
assert sorted(response.json()) == sorted(expected)
73+
5874
def models_list(self, client: TestClient):
5975
"""Get a list of bootstrapped models to use with tests"""
6076
response = client.get("/v1/models", headers=TEST_HEADERS)

0 commit comments

Comments
 (0)