Skip to content

Commit 1cbd853

Browse files
srdaspre-commit-ci[bot]dlqqq
authored
Add base API URL field for Ollama and OpenAI embedding models (#1136)
* Base API URL added for embedding models Jupyter AI currently allows the user to call a model at a URL (location) different from the default one by specifying a selected Base API URL. This can be done for Ollama, OpenAI provider models. However, for these providers, there is no way to change the API URL for embedding models when using the `/learn` command in RAG mode. This PR adds an extra field to make this feasible. Tested as follows for Ollama: [1] Start the Ollama system from port 11435 instead 11434 (the default): `OLLAMA_HOST=127.0.0.1:11435 ollama serve` [2] Set the Base API URL: [3] Check that the new API URL works: * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * allow embedding model fields to be saved * exclude empty str fields from config manager * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David L. Qiu <david@qiu.dev>
1 parent 5ffe481 commit 1cbd853

File tree

4 files changed

+65
-23
lines changed

4 files changed

+65
-23
lines changed

packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from langchain_ollama import ChatOllama, OllamaEmbeddings
22

33
from ..embedding_providers import BaseEmbeddingsProvider
4-
from ..providers import BaseProvider, EnvAuthStrategy, TextField
4+
from ..providers import BaseProvider, TextField
55

66

77
class OllamaProvider(BaseProvider, ChatOllama):
@@ -23,10 +23,14 @@ class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
2323
id = "ollama"
2424
name = "Ollama"
2525
# source: https://ollama.com/library
26+
model_id_key = "model"
2627
models = [
2728
"nomic-embed-text",
2829
"mxbai-embed-large",
2930
"all-minilm",
3031
"snowflake-arctic-embed",
3132
]
32-
model_id_key = "model"
33+
registry = True
34+
fields = [
35+
TextField(key="base_url", label="Base API URL (optional)", format="text"),
36+
]

packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
107107
model_id_key = "model"
108108
pypi_package_deps = ["langchain_openai"]
109109
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
110+
registry = True
111+
fields = [
112+
TextField(
113+
key="openai_api_base", label="Base API URL (optional)", format="text"
114+
),
115+
]
110116

111117

112118
class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings):
@@ -122,5 +128,7 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding
122128
auth_strategy = EnvAuthStrategy(
123129
name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key"
124130
)
125-
126131
registry = True
132+
fields = [
133+
TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"),
134+
]

packages/jupyter-ai/jupyter_ai/config_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,13 @@ def _provider_params(self, key, listing, completions: bool = False):
462462
else:
463463
fields = config.fields.get(model_uid, {})
464464

465+
# exclude empty fields
466+
# TODO: modify the config manager to never save empty fields in the
467+
# first place.
468+
for field_key in fields:
469+
if isinstance(fields[field_key], str) and not len(fields[field_key]):
470+
fields[field_key] = None
471+
465472
# get authn fields
466473
_, Provider = get_em_provider(model_uid, listing)
467474
authn_fields = {}

packages/jupyter-ai/src/components/chat-settings.tsx

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
8888
const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
8989
const [sendWse, setSendWse] = useState<boolean>(false);
9090
const [fields, setFields] = useState<Record<string, any>>({});
91+
const [embeddingModelFields, setEmbeddingModelFields] = useState<
92+
Record<string, any>
93+
>({});
9194

9295
const [isCompleterEnabled, setIsCompleterEnabled] = useState(
9396
props.completionProvider && props.completionProvider.isEnabled()
@@ -188,7 +191,15 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
188191
const currFields: Record<string, any> =
189192
server.config.fields?.[lmGlobalId] ?? {};
190193
setFields(currFields);
191-
}, [server, lmProvider]);
194+
195+
if (!emGlobalId) {
196+
return;
197+
}
198+
199+
const initEmbeddingModelFields: Record<string, any> =
200+
server.config.fields?.[emGlobalId] ?? {};
201+
setEmbeddingModelFields(initEmbeddingModelFields);
202+
}, [server, lmGlobalId, emGlobalId]);
192203

193204
const handleSave = async () => {
194205
// compress fields with JSON values
@@ -222,6 +233,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
222233
}),
223234
...(clmGlobalId && {
224235
[clmGlobalId]: fields
236+
}),
237+
...(emGlobalId && {
238+
[emGlobalId]: embeddingModelFields
225239
})
226240
}
227241
}),
@@ -376,26 +390,35 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
376390
{/* Embedding model section */}
377391
<h2 className="jp-ai-ChatSettings-header">Embedding model</h2>
378392
{server.emProviders.providers.length > 0 ? (
379-
<Select
380-
value={emGlobalId}
381-
label="Embedding model"
382-
onChange={e => {
383-
const emGid = e.target.value === 'null' ? null : e.target.value;
384-
setEmGlobalId(emGid);
385-
}}
386-
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
387-
>
388-
<MenuItem value="null">None</MenuItem>
389-
{server.emProviders.providers.map(emp =>
390-
emp.models
391-
.filter(em => em !== '*') // TODO: support registry providers
392-
.map(em => (
393-
<MenuItem value={`${emp.id}:${em}`}>
394-
{emp.name} :: {em}
395-
</MenuItem>
396-
))
393+
<Box>
394+
<Select
395+
value={emGlobalId}
396+
label="Embedding model"
397+
onChange={e => {
398+
const emGid = e.target.value === 'null' ? null : e.target.value;
399+
setEmGlobalId(emGid);
400+
}}
401+
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
402+
>
403+
<MenuItem value="null">None</MenuItem>
404+
{server.emProviders.providers.map(emp =>
405+
emp.models
406+
.filter(em => em !== '*') // TODO: support registry providers
407+
.map(em => (
408+
<MenuItem value={`${emp.id}:${em}`}>
409+
{emp.name} :: {em}
410+
</MenuItem>
411+
))
412+
)}
413+
</Select>
414+
{emGlobalId && (
415+
<ModelFields
416+
fields={emProvider?.fields}
417+
values={embeddingModelFields}
418+
onChange={setEmbeddingModelFields}
419+
/>
397420
)}
398-
</Select>
421+
</Box>
399422
) : (
400423
<p>No embedding models available.</p>
401424
)}

0 commit comments

Comments
 (0)