
Commit d7d6fae

Authored by dbczumar
Fixes for OpenAI / Azure OpenAI compatibility with LiteLLM router (#1760)
* fix
* fix
* fix
* Fix

Signed-off-by: dbczumar <corey.zumar@databricks.com>
1 parent cadd619 commit d7d6fae
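
For context, a minimal usage sketch of what this fix enables. The keyword arguments mirror those exercised in the updated tests; the model name, endpoint, and key below are placeholders. Credentials supplied to the LM client, or read from AZURE_API_KEY / AZURE_API_BASE / AZURE_API_VERSION / AZURE_AD_TOKEN (or the OPENAI_* equivalents), are now threaded into the LiteLLM router configuration instead of being passed ad hoc with each request.

    # Sketch only -- placeholder model name and credentials; kwargs mirror the tests in this commit.
    from dspy.clients.lm import LM

    lm = LM(
        model="azure/gpt-4o",                                   # "<provider>/<model>" format
        model_type="chat",
        api_key="<azure-openai-key>",                           # or set AZURE_API_KEY
        api_base="https://<resource>.openai.azure.com",         # or set AZURE_API_BASE
        api_version="2024-02-01",                               # or set AZURE_API_VERSION
    )
    # lm("Hello, world!")  # routes through a litellm Router configured with the credentials above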

2 files changed: +141 -46

dspy/clients/lm.py

Lines changed: 83 additions & 27 deletions
@@ -3,6 +3,7 @@
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional

@@ -164,6 +165,55 @@ def copy(self, **kwargs):
         return new_instance


+@dataclass(frozen=True)
+class _ProviderAPIConfig:
+    """
+    API configurations for a provider (e.g. OpenAI, Azure OpenAI)
+    """
+
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
+    # For all other providers, this field is empty
+    azure_ad_token: Optional[str]
+
+
+def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
+    """
+    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
+    provider corresponding to the given model.
+
+    Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
+    the input dictionary.
+    """
+    provider = _get_provider(model)
+    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
+    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
+    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
+    if "azure" in provider:
+        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
+    else:
+        azure_ad_token = None
+    return _ProviderAPIConfig(
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        azure_ad_token=azure_ad_token,
+    )
+
+
+def _get_provider(model: str) -> str:
+    """
+    Extract the provider name from the model string of the format "<provider_name>/<model_name>",
+    e.g. "openai/gpt-4".
+
+    TODO: Not all the models are in the format of "provider/model"
+    """
+    model = model.split("/", 1)
+    return model[0] if len(model) > 1 else "openai"
+
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
@@ -175,15 +225,16 @@ def cached_litellm_completion(request, num_retries: int):

 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries)
+    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
+    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
     return router.completion(
         cache=cache,
         **kwargs,
     )


 @functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int) -> Router:
+def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
     """
     Get a LiteLLM router for the given model with the specified number of retries
     for transient errors.
@@ -193,6 +244,9 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         num_retries: The number of times to retry a request if it fails transiently due to
                      network error, rate limiting, etc. Requests are retried with exponential
                      backoff.
+        api_config: The API configurations (keys, base URL, etc.) for the provider
+                    (OpenAI, Azure OpenAI, etc.) corresponding to the given model.
+
     Returns:
         A LiteLLM router instance that can be used to query the given model.
     """
@@ -207,19 +261,29 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         ContentPolicyViolationErrorRetries=0,
     )

+    # LiteLLM routers must specify a `model_list`, which maps model names passed
+    # to `completions()` into actual LiteLLM model names. For our purposes, the
+    # model name is the same as the LiteLLM model name, so we add a single
+    # entry to the `model_list` that maps the model name to itself
+    litellm_params = {
+        "model": model,
+    }
+    if api_config.api_key is not None:
+        litellm_params["api_key"] = api_config.api_key
+    if api_config.api_base is not None:
+        litellm_params["api_base"] = api_config.api_base
+    if api_config.api_version is not None:
+        litellm_params["api_version"] = api_config.api_version
+    if api_config.azure_ad_token is not None:
+        litellm_params["azure_ad_token"] = api_config.azure_ad_token
+    model_list = [
+        {
+            "model_name": model,
+            "litellm_params": litellm_params,
+        }
+    ]
     return Router(
-        # LiteLLM routers must specify a `model_list`, which maps model names passed
-        # to `completions()` into actual LiteLLM model names. For our purposes, the
-        # model name is the same as the LiteLLM model name, so we add a single
-        # entry to the `model_list` that maps the model name to itself
-        model_list=[
-            {
-                "model_name": model,
-                "litellm_params": {
-                    "model": model,
-                },
-            }
-        ],
+        model_list=model_list,
         retry_policy=retry_policy,
     )

@@ -235,26 +299,18 @@ def cached_litellm_text_completion(request, num_retries: int):

 def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-
-    # Extract the provider and model from the model string.
-    # TODO: Not all the models are in the format of "provider/model"
-    model = kwargs.pop("model").split("/", 1)
-    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
-    text_completion_model_name = f"text-completion-openai/{model}"
-
-    # Use the API key and base from the kwargs, or from the environment.
-    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
-    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+    model = kwargs.pop("model")
+    api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
+    model_name = model.split("/", 1)[-1]
+    text_completion_model_name = f"text-completion-openai/{model_name}"

     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

-    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries)
+    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
    return router.text_completion(
         cache=cache,
         model=text_completion_model_name,
-        api_key=api_key,
-        api_base=api_base,
         prompt=prompt,
         **kwargs,
     )
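
The crux of the change above is where the credentials end up: per-deployment settings now live inside each `model_list` entry's `litellm_params`, which is where LiteLLM's Router expects them, rather than being passed alongside each request. A standalone sketch of the router configuration that `_get_litellm_router` now assembles (credential values are placeholders):

    # Sketch of the Router configuration built by _get_litellm_router above;
    # the model name and credential values are placeholders.
    from litellm.router import Router

    model = "azure/gpt-4o"
    router = Router(
        model_list=[
            {
                "model_name": model,   # the name later passed to router.completion(...)
                "litellm_params": {
                    "model": model,    # the actual LiteLLM model identifier
                    "api_key": "<azure-openai-key>",
                    "api_base": "https://<resource>.openai.azure.com",
                    "api_version": "2024-02-01",
                    "azure_ad_token": "<azure-ad-token>",  # included only when set (Azure AD auth)
                },
            }
        ],
    )
    # router.completion(model=model, messages=[{"role": "user", "content": "hi"}])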

tests/clients/test_lm.py

Lines changed: 58 additions & 19 deletions
@@ -1,20 +1,40 @@
 from unittest import mock

+import pytest
 from litellm.router import RetryPolicy

 from dspy.clients.lm import LM, _get_litellm_router


-def test_lm_chat_respects_max_retries():
+@pytest.mark.parametrize("keys_in_env_vars", [True, False])
+def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch):
     model_name = "openai/gpt4o"
     num_retries = 17
     temperature = 0.5
     max_tokens = 100
     prompt = "Hello, world!"
+    api_version = "2024-02-01"
+    api_key = "apikey"
+
+    lm_kwargs = {
+        "model": model_name,
+        "model_type": "chat",
+        "num_retries": num_retries,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+    }
+    if keys_in_env_vars:
+        api_base = "http://testfromenv.com"
+        monkeypatch.setenv("OPENAI_API_KEY", api_key)
+        monkeypatch.setenv("OPENAI_API_BASE", api_base)
+        monkeypatch.setenv("OPENAI_API_VERSION", api_version)
+    else:
+        api_base = "http://test.com"
+        lm_kwargs["api_key"] = api_key
+        lm_kwargs["api_base"] = api_base
+        lm_kwargs["api_version"] = api_version

-    lm = LM(
-        model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens
-    )
+    lm = LM(**lm_kwargs)

     MockRouter = mock.MagicMock()
     mock_completion = mock.MagicMock()
@@ -29,6 +49,9 @@ def test_lm_chat_respects_max_retries():
                 "model_name": model_name,
                 "litellm_params": {
                     "model": model_name,
+                    "api_key": api_key,
+                    "api_base": api_base,
+                    "api_version": api_version,
                 },
             }
         ],
@@ -50,25 +73,39 @@ def test_lm_chat_respects_max_retries():
     )


-def test_lm_completions_respects_max_retries():
-    model_name = "openai/gpt-3.5-turbo"
-    expected_model = "text-completion-" + model_name
+@pytest.mark.parametrize("keys_in_env_vars", [True, False])
+def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch):
+    model_name = "azure/gpt-3.5-turbo"
+    expected_model = "text-completion-openai/" + model_name.split("/")[-1]
     num_retries = 17
     temperature = 0.5
     max_tokens = 100
     prompt = "Hello, world!"
-    api_base = "http://test.com"
+    api_version = "2024-02-01"
     api_key = "apikey"
+    azure_ad_token = "adtoken"
+
+    lm_kwargs = {
+        "model": model_name,
+        "model_type": "text",
+        "num_retries": num_retries,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+    }
+    if keys_in_env_vars:
+        api_base = "http://testfromenv.com"
+        monkeypatch.setenv("AZURE_API_KEY", api_key)
+        monkeypatch.setenv("AZURE_API_BASE", api_base)
+        monkeypatch.setenv("AZURE_API_VERSION", api_version)
+        monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token)
+    else:
+        api_base = "http://test.com"
+        lm_kwargs["api_key"] = api_key
+        lm_kwargs["api_base"] = api_base
+        lm_kwargs["api_version"] = api_version
+        lm_kwargs["azure_ad_token"] = azure_ad_token

-    lm = LM(
-        model=model_name,
-        model_type="text",
-        num_retries=num_retries,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        api_base=api_base,
-        api_key=api_key,
-    )
+    lm = LM(**lm_kwargs)

     MockRouter = mock.MagicMock()
     mock_text_completion = mock.MagicMock()
@@ -83,6 +120,10 @@ def test_lm_completions_respects_max_retries():
                 "model_name": expected_model,
                 "litellm_params": {
                     "model": expected_model,
+                    "api_key": api_key,
+                    "api_base": api_base,
+                    "api_version": api_version,
+                    "azure_ad_token": azure_ad_token,
                 },
             }
         ],
@@ -100,7 +141,5 @@ def test_lm_completions_respects_max_retries():
         prompt=prompt + "\n\nBEGIN RESPONSE:",
         temperature=temperature,
         max_tokens=max_tokens,
-        api_key=api_key,
-        api_base=api_base,
         cache=mock.ANY,
     )
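
Both tests now use pytest's `parametrize` and `monkeypatch` fixtures to cover the two credential paths: explicit LM kwargs and provider environment variables. A minimal, self-contained illustration of that pattern (the EXAMPLE_API_KEY variable and the test name are made up, not part of the repo):

    # Standalone illustration of the kwargs-vs-environment pattern used above;
    # EXAMPLE_API_KEY and the test name are made up.
    import os
    import pytest

    @pytest.mark.parametrize("keys_in_env_vars", [True, False])
    def test_key_resolution(keys_in_env_vars, monkeypatch):
        kwargs = {}
        if keys_in_env_vars:
            monkeypatch.setenv("EXAMPLE_API_KEY", "from-env")  # reverted automatically after the test
        else:
            kwargs["api_key"] = "from-kwargs"

        # Same precedence as _extract_provider_api_config: explicit kwarg first, then the env var.
        resolved = kwargs.pop("api_key", None) or os.getenv("EXAMPLE_API_KEY")
        assert resolved == ("from-env" if keys_in_env_vars else "from-kwargs")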
