@@ -1,21 +1,19 @@
 import functools
+from .base_lm import BaseLM
 import logging
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional
 
+import litellm
 import ujson
-from litellm import Router
-from litellm.router import RetryPolicy
 
 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
 from dspy.utils.callback import BaseCallback, with_callbacks
 
-from .base_lm import BaseLM
 
 logger = logging.getLogger(__name__)
 
@@ -34,7 +32,7 @@ def __init__(
         cache: bool = True,
         launch_kwargs: Optional[Dict[str, Any]] = None,
         callbacks: Optional[List[BaseCallback]] = None,
-        num_retries: int = 8,
+        num_retries: int = 3,
         **kwargs,
     ):
         """
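A quick sketch of where the new default lands, assuming the usual `dspy.LM` entry point (the model name below is only an example, not pinned by this diff); callers who relied on the old default of 8 retries can still pass it explicitly:

```python
import dspy

# num_retries now defaults to 3; override it per instance if needed.
# "openai/gpt-4o-mini" is an example model name.
lm = dspy.LM("openai/gpt-4o-mini", num_retries=8)
```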
@@ -165,55 +163,6 @@ def copy(self, **kwargs):
         return new_instance
 
 
-@dataclass(frozen=True)
-class _ProviderAPIConfig:
-    """
-    API configurations for a provider (e.g. OpenAI, Azure OpenAI)
-    """
-
-    api_key: Optional[str]
-    api_base: Optional[str]
-    api_version: Optional[str]
-    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
-    # For all other providers, this field is empty
-    azure_ad_token: Optional[str]
-
-
-def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
-    """
-    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
-    provider corresponding to the given model.
-
-    Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
-    the input dictionary.
-    """
-    provider = _get_provider(model)
-    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
-    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
-    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
-    if "azure" in provider:
-        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
-    else:
-        azure_ad_token = None
-    return _ProviderAPIConfig(
-        api_key=api_key,
-        api_base=api_base,
-        api_version=api_version,
-        azure_ad_token=azure_ad_token,
-    )
-
-
-def _get_provider(model: str) -> str:
-    """
-    Extract the provider name from the model string of the format "<provider_name>/<model_name>",
-    e.g. "openai/gpt-4".
-
-    TODO: Not all the models are in the format of "provider/model"
-    """
-    model = model.split("/", 1)
-    return model[0] if len(model) > 1 else "openai"
-
-
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
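With the `Router` and its `RetryPolicy` gone, retry handling is delegated to LiteLLM itself through the `num_retries` argument. A minimal sketch of the equivalent direct call (example model and message, not from this diff):

```python
import litellm

# LiteLLM retries transient failures (rate limits, timeouts, 5xx) internally
# when num_retries is set; the per-error-class granularity of the removed
# RetryPolicy is no longer configured here.
response = litellm.completion(
    model="openai/gpt-4o-mini",  # example model
    messages=[{"role": "user", "content": "Hello"}],
    num_retries=3,
)
```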
@@ -225,69 +174,13 @@ def cached_litellm_completion(request, num_retries: int):
 
 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
-    return router.completion(
+    return litellm.completion(
+        num_retries=num_retries,
         cache=cache,
         **kwargs,
     )
 
 
-@functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
-    """
-    Get a LiteLLM router for the given model with the specified number of retries
-    for transient errors.
-
-    Args:
-        model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4').
-        num_retries: The number of times to retry a request if it fails transiently due to
-                     network error, rate limiting, etc. Requests are retried with exponential
-                     backoff.
-        api_config: The API configurations (keys, base URL, etc.) for the provider
-                    (OpenAI, Azure OpenAI, etc.) corresponding to the given model.
-
-    Returns:
-        A LiteLLM router instance that can be used to query the given model.
-    """
-    retry_policy = RetryPolicy(
-        TimeoutErrorRetries=num_retries,
-        RateLimitErrorRetries=num_retries,
-        InternalServerErrorRetries=num_retries,
-        # We don't retry on errors that are unlikely to be transient
-        # (e.g. bad request, invalid auth credentials)
-        BadRequestErrorRetries=0,
-        AuthenticationErrorRetries=0,
-        ContentPolicyViolationErrorRetries=0,
-    )
-
-    # LiteLLM routers must specify a `model_list`, which maps model names passed
-    # to `completions()` into actual LiteLLM model names. For our purposes, the
-    # model name is the same as the LiteLLM model name, so we add a single
-    # entry to the `model_list` that maps the model name to itself
-    litellm_params = {
-        "model": model,
-    }
-    if api_config.api_key is not None:
-        litellm_params["api_key"] = api_config.api_key
-    if api_config.api_base is not None:
-        litellm_params["api_base"] = api_config.api_base
-    if api_config.api_version is not None:
-        litellm_params["api_version"] = api_config.api_version
-    if api_config.azure_ad_token is not None:
-        litellm_params["azure_ad_token"] = api_config.azure_ad_token
-    model_list = [
-        {
-            "model_name": model,
-            "litellm_params": litellm_params,
-        }
-    ]
-    return Router(
-        model_list=model_list,
-        retry_policy=retry_policy,
-    )
-
-
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request, num_retries: int):
     return litellm_text_completion(
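Note that both cached entry points take the request as a ujson-encoded string rather than a dict: `functools.lru_cache` requires hashable arguments, and a dict is not hashable. A usage sketch under that assumption (example payload):

```python
import ujson

# Serialize the request so the lru_cache key is a plain (hashable) string.
request = ujson.dumps(
    {
        "model": "openai/gpt-4o-mini",  # example model
        "messages": [{"role": "user", "content": "Hi"}],
    }
)
response = cached_litellm_completion(request, num_retries=3)
```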
@@ -299,18 +192,25 @@ def cached_litellm_text_completion(request, num_retries: int):
 
 def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    model = kwargs.pop("model")
-    api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
-    model_name = model.split("/", 1)[-1]
-    text_completion_model_name = f"text-completion-openai/{model_name}"
+
+    # Extract the provider and model from the model string.
+    # TODO: Not all the models are in the format of "provider/model"
+    model = kwargs.pop("model").split("/", 1)
+    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+
+    # Use the API key and base from the kwargs, or from the environment.
+    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
 
-    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
-    return router.text_completion(
+    return litellm.text_completion(
         cache=cache,
-        model=text_completion_model_name,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
         prompt=prompt,
+        num_retries=num_retries,
         **kwargs,
     )
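The new inline parsing reproduces what the removed `_get_provider` helper did. Worked through on an example string (one observable difference: the env-var lookup now interpolates the provider segment verbatim, checking e.g. `openai_API_KEY`, whereas the removed helper upper-cased it):

```python
# Example: splitting a "provider/model" string as the code above does.
parts = "openai/gpt-3.5-turbo-instruct".split("/", 1)
provider, model = (parts[0] if len(parts) > 1 else "openai"), parts[-1]
assert (provider, model) == ("openai", "gpt-3.5-turbo-instruct")

# A bare model name (no "/") falls back to the "openai" provider:
parts = "davinci-002".split("/", 1)
assert (parts[0] if len(parts) > 1 else "openai") == "openai"

# The request is then issued as "text-completion-openai/gpt-3.5-turbo-instruct".
```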