 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional

@@ -164,6 +165,55 @@ def copy(self, **kwargs):
         return new_instance


+@dataclass(frozen=True)
+class _ProviderAPIConfig:
+    """
+    API configurations for a provider (e.g. OpenAI, Azure OpenAI).
+    """
+
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
+    # For all other providers, this field is None.
+    azure_ad_token: Optional[str]
+
+
+def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
+    """
+    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for
+    the provider corresponding to the given model.
+
+    Note: If present, the API configurations are removed from `llm_kwargs`, mutating the input
+    dictionary.
+    """
+    provider = _get_provider(model)
+    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
+    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
+    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
+    if "azure" in provider:
+        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
+    else:
+        azure_ad_token = None
+    return _ProviderAPIConfig(
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        azure_ad_token=azure_ad_token,
+    )
+
+
+def _get_provider(model: str) -> str:
+    """
+    Extract the provider name from a model string of the format "<provider_name>/<model_name>",
+    e.g. "openai/gpt-4". Defaults to "openai" when no provider prefix is present.
+
+    TODO: Not all models follow the "provider/model" format.
+    """
+    parts = model.split("/", 1)
+    return parts[0] if len(parts) > 1 else "openai"
+
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
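
A quick sketch of how the new `_extract_provider_api_config` helper behaves (the model string, key values, and env var contents below are invented for illustration): explicit entries in `llm_kwargs` take precedence, provider-specific environment variables are the fallback, and matched keys are popped from the caller's dict.

```python
import os

kwargs = {"api_key": "sk-explicit", "temperature": 0.7}  # hypothetical values
config = _extract_provider_api_config(model="openai/gpt-4", llm_kwargs=kwargs)
assert config.api_key == "sk-explicit"
assert "api_key" not in kwargs  # popped: the input dict is mutated

os.environ["OPENAI_API_KEY"] = "sk-from-env"
config = _extract_provider_api_config(model="openai/gpt-4", llm_kwargs={})
assert config.api_key == "sk-from-env"  # env var used as fallback

assert _get_provider("gpt-4") == "openai"  # no "provider/" prefix defaults to openai
```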
@@ -175,15 +225,16 @@ def cached_litellm_completion(request, num_retries: int):

 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries)
+    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
+    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
     return router.completion(
         cache=cache,
         **kwargs,
     )


 @functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int) -> Router:
+def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
     """
     Get a LiteLLM router for the given model with the specified number of retries
     for transient errors.
@@ -193,6 +244,9 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         num_retries: The number of times to retry a request if it fails transiently due to
                      network error, rate limiting, etc. Requests are retried with exponential
                      backoff.
+        api_config: The API configurations (keys, base URL, etc.) for the provider
+                    (OpenAI, Azure OpenAI, etc.) corresponding to the given model.
+
     Returns:
         A LiteLLM router instance that can be used to query the given model.
     """
@@ -207,19 +261,29 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         ContentPolicyViolationErrorRetries=0,
     )

+    # LiteLLM routers must specify a `model_list`, which maps model names passed
+    # to `completions()` into actual LiteLLM model names. For our purposes, the
+    # model name is the same as the LiteLLM model name, so we add a single
+    # entry to the `model_list` that maps the model name to itself.
+    litellm_params = {
+        "model": model,
+    }
+    if api_config.api_key is not None:
+        litellm_params["api_key"] = api_config.api_key
+    if api_config.api_base is not None:
+        litellm_params["api_base"] = api_config.api_base
+    if api_config.api_version is not None:
+        litellm_params["api_version"] = api_config.api_version
+    if api_config.azure_ad_token is not None:
+        litellm_params["azure_ad_token"] = api_config.azure_ad_token
+    model_list = [
+        {
+            "model_name": model,
+            "litellm_params": litellm_params,
+        }
+    ]
     return Router(
-        # LiteLLM routers must specify a `model_list`, which maps model names passed
-        # to `completions()` into actual LiteLLM model names. For our purposes, the
-        # model name is the same as the LiteLLM model name, so we add a single
-        # entry to the `model_list` that maps the model name to itself
-        model_list=[
-            {
-                "model_name": model,
-                "litellm_params": {
-                    "model": model,
-                },
-            }
-        ],
+        model_list=model_list,
         retry_policy=retry_policy,
     )

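
To make the new router construction concrete, this is roughly the `model_list` that `Router` receives for an Azure OpenAI model when every optional field is set (all values below are placeholders):

```python
model_list = [
    {
        "model_name": "azure/gpt-4",
        "litellm_params": {
            "model": "azure/gpt-4",
            "api_key": "<azure-api-key>",                    # placeholder
            "api_base": "https://example.openai.azure.com",  # placeholder
            "api_version": "2024-02-01",                     # placeholder
            "azure_ad_token": "<azure-ad-token>",            # placeholder
        },
    }
]
```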
@@ -235,26 +299,18 @@ def cached_litellm_text_completion(request, num_retries: int):
235299
236300def litellm_text_completion (request , num_retries : int , cache = {"no-cache" : True , "no-store" : True }):
237301 kwargs = ujson .loads (request )
238-
239- # Extract the provider and model from the model string.
240- # TODO: Not all the models are in the format of "provider/model"
241- model = kwargs .pop ("model" ).split ("/" , 1 )
242- provider , model = model [0 ] if len (model ) > 1 else "openai" , model [- 1 ]
243- text_completion_model_name = f"text-completion-openai/{ model } "
244-
245- # Use the API key and base from the kwargs, or from the environment.
246- api_key = kwargs .pop ("api_key" , None ) or os .getenv (f"{ provider } _API_KEY" )
247- api_base = kwargs .pop ("api_base" , None ) or os .getenv (f"{ provider } _API_BASE" )
302+ model = kwargs .pop ("model" )
303+ api_config = _extract_provider_api_config (model = model , llm_kwargs = kwargs )
304+ model_name = model .split ("/" , 1 )[- 1 ]
305+ text_completion_model_name = f"text-completion-openai/{ model_name } "
248306
249307 # Build the prompt from the messages.
250308 prompt = "\n \n " .join ([x ["content" ] for x in kwargs .pop ("messages" )] + ["BEGIN RESPONSE:" ])
251309
252- router = _get_litellm_router (model = text_completion_model_name , num_retries = num_retries )
310+ router = _get_litellm_router (model = text_completion_model_name , num_retries = num_retries , api_config = api_config )
253311 return router .text_completion (
254312 cache = cache ,
255313 model = text_completion_model_name ,
256- api_key = api_key ,
257- api_base = api_base ,
258314 prompt = prompt ,
259315 ** kwargs ,
260316 )
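
Unchanged by this diff, but useful context for the lines above: the text-completion path flattens chat messages into a single prompt terminated by a `BEGIN RESPONSE:` sentinel. For a one-message input:

```python
messages = [{"role": "user", "content": "Say hello."}]  # hypothetical input
prompt = "\n\n".join([x["content"] for x in messages] + ["BEGIN RESPONSE:"])
assert prompt == "Say hello.\n\nBEGIN RESPONSE:"
```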