Commit e654062

Revert LiteLLM Router-based retries and upgrade poetry lock for litellm 1.51.0 (#1762)
* Revert LiteLLM Router-based retries and upgrade poetry lock for litellm 1.51.0
* Temporarily remove retry tests
* fix test
1 parent 1864a27 commit e654062
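
In effect, the commit swaps dspy's completion path from a cached litellm Router (configured with a RetryPolicy) back to plain litellm.completion and litellm.text_completion calls that receive num_retries directly. A minimal sketch of the restored call shape, with an illustrative model name and message:

    import litellm

    # After the revert, transient-error retries are delegated to litellm
    # itself via `num_retries`, not to a litellm.Router with a RetryPolicy.
    response = litellm.completion(
        model="openai/gpt-4o-mini",  # illustrative
        messages=[{"role": "user", "content": "Hello"}],
        num_retries=3,  # dspy's new default (previously 8)
    )
    print(response.choices[0].message.content)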

File tree

3 files changed: +167 −267 lines changed


dspy/clients/lm.py

Lines changed: 19 additions & 119 deletions
@@ -1,21 +1,19 @@
 import functools
+from .base_lm import BaseLM
 import logging
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional

+import litellm
 import ujson
-from litellm import Router
-from litellm.router import RetryPolicy

 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
 from dspy.utils.callback import BaseCallback, with_callbacks

-from .base_lm import BaseLM

 logger = logging.getLogger(__name__)

@@ -34,7 +32,7 @@ def __init__(
         cache: bool = True,
         launch_kwargs: Optional[Dict[str, Any]] = None,
         callbacks: Optional[List[BaseCallback]] = None,
-        num_retries: int = 8,
+        num_retries: int = 3,
         **kwargs,
     ):
         """
@@ -165,55 +163,6 @@ def copy(self, **kwargs):
         return new_instance


-@dataclass(frozen=True)
-class _ProviderAPIConfig:
-    """
-    API configurations for a provider (e.g. OpenAI, Azure OpenAI)
-    """
-
-    api_key: Optional[str]
-    api_base: Optional[str]
-    api_version: Optional[str]
-    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
-    # For all other providers, this field is empty
-    azure_ad_token: Optional[str]
-
-
-def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
-    """
-    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
-    provider corresponding to the given model.
-
-    Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
-    the input dictionary.
-    """
-    provider = _get_provider(model)
-    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
-    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
-    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
-    if "azure" in provider:
-        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
-    else:
-        azure_ad_token = None
-    return _ProviderAPIConfig(
-        api_key=api_key,
-        api_base=api_base,
-        api_version=api_version,
-        azure_ad_token=azure_ad_token,
-    )
-
-
-def _get_provider(model: str) -> str:
-    """
-    Extract the provider name from the model string of the format "<provider_name>/<model_name>",
-    e.g. "openai/gpt-4".
-
-    TODO: Not all the models are in the format of "provider/model"
-    """
-    model = model.split("/", 1)
-    return model[0] if len(model) > 1 else "openai"
-
-
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
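
The provider-parsing rule from the deleted _get_provider survives inline in litellm_text_completion further down; as a quick sketch of what it computed:

    # "provider/model" splits on the first slash; bare model names fall
    # back to the "openai" provider (per the removed _get_provider).
    parts = "openai/gpt-4".split("/", 1)
    provider = parts[0] if len(parts) > 1 else "openai"  # -> "openai"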
@@ -225,69 +174,13 @@ def cached_litellm_completion(request, num_retries: int):


 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
-    return router.completion(
+    return litellm.completion(
+        num_retries=num_retries,
         cache=cache,
         **kwargs,
     )


-@functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
-    """
-    Get a LiteLLM router for the given model with the specified number of retries
-    for transient errors.
-
-    Args:
-        model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4').
-        num_retries: The number of times to retry a request if it fails transiently due to
-                     network error, rate limiting, etc. Requests are retried with exponential
-                     backoff.
-        api_config: The API configurations (keys, base URL, etc.) for the provider
-                    (OpenAI, Azure OpenAI, etc.) corresponding to the given model.
-
-    Returns:
-        A LiteLLM router instance that can be used to query the given model.
-    """
-    retry_policy = RetryPolicy(
-        TimeoutErrorRetries=num_retries,
-        RateLimitErrorRetries=num_retries,
-        InternalServerErrorRetries=num_retries,
-        # We don't retry on errors that are unlikely to be transient
-        # (e.g. bad request, invalid auth credentials)
-        BadRequestErrorRetries=0,
-        AuthenticationErrorRetries=0,
-        ContentPolicyViolationErrorRetries=0,
-    )
-
-    # LiteLLM routers must specify a `model_list`, which maps model names passed
-    # to `completions()` into actual LiteLLM model names. For our purposes, the
-    # model name is the same as the LiteLLM model name, so we add a single
-    # entry to the `model_list` that maps the model name to itself
-    litellm_params = {
-        "model": model,
-    }
-    if api_config.api_key is not None:
-        litellm_params["api_key"] = api_config.api_key
-    if api_config.api_base is not None:
-        litellm_params["api_base"] = api_config.api_base
-    if api_config.api_version is not None:
-        litellm_params["api_version"] = api_config.api_version
-    if api_config.azure_ad_token is not None:
-        litellm_params["azure_ad_token"] = api_config.azure_ad_token
-    model_list = [
-        {
-            "model_name": model,
-            "litellm_params": litellm_params,
-        }
-    ]
-    return Router(
-        model_list=model_list,
-        retry_policy=retry_policy,
-    )
-
-
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request, num_retries: int):
     return litellm_text_completion(
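
Note that both cached_* wrappers memoize with functools.lru_cache, whose arguments must be hashable; this is why `request` is a ujson-serialized string rather than a dict. A sketch of the expected call shape, with illustrative values:

    import ujson

    # Serialize the kwargs to a string so lru_cache can key on them.
    request = ujson.dumps({
        "model": "openai/gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hi"}],
    })
    # cached_litellm_completion(request, num_retries=3)  # memoized on `request`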
@@ -299,18 +192,25 @@ def cached_litellm_text_completion(request, num_retries: int):


 def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    model = kwargs.pop("model")
-    api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
-    model_name = model.split("/", 1)[-1]
-    text_completion_model_name = f"text-completion-openai/{model_name}"
+
+    # Extract the provider and model from the model string.
+    # TODO: Not all the models are in the format of "provider/model"
+    model = kwargs.pop("model").split("/", 1)
+    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+
+    # Use the API key and base from the kwargs, or from the environment.
+    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

-    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
-    return router.text_completion(
+    return litellm.text_completion(
         cache=cache,
-        model=text_completion_model_name,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
         prompt=prompt,
+        num_retries=num_retries,
         **kwargs,
     )
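
The prompt construction above flattens chat messages into a single text-completion prompt; a small runnable sketch of exactly that join:

    messages = [
        {"role": "system", "content": "You are concise."},
        {"role": "user", "content": "Name a prime number."},
    ]
    prompt = "\n\n".join([x["content"] for x in messages] + ["BEGIN RESPONSE:"])
    # -> "You are concise.\n\nName a prime number.\n\nBEGIN RESPONSE:"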

poetry.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default.
