Skip to content

Commit 26e0d64

Browse files
authored
feat: Mistral: support auth via managed identity (#265) (#266)
1 parent 32d8816 commit 26e0d64

File tree

3 files changed

+15
-26
lines changed

3 files changed

+15
-26
lines changed

aidial_adapter_openai/endpoints/chat_completion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ async def call_chat_completion(
116116
)
117117

118118
case ChatCompletionDeploymentType.MISTRAL:
119-
return await mistral_chat_completion(data, upstream_endpoint, creds)
119+
return await mistral_chat_completion(data, endpoint, creds)
120120
case ChatCompletionDeploymentType.DATABRICKS:
121121
return await databricks_chat_completion(data, endpoint, creds)
122122

aidial_adapter_openai/mistral.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
1-
from typing import Any
1+
from typing import Any, cast
22

3-
from openai import AsyncOpenAI, AsyncStream
3+
from openai import AsyncStream
44
from openai.types.chat.chat_completion import ChatCompletion
55
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
66

77
from aidial_adapter_openai.utils.auth import OpenAICreds
8-
from aidial_adapter_openai.utils.http_client import get_http_client
8+
from aidial_adapter_openai.utils.parsers import (
9+
AzureOpenAIEndpoint,
10+
OpenAIEndpoint,
11+
OpenAIParams,
12+
)
913
from aidial_adapter_openai.utils.reflection import call_with_extra_body
1014
from aidial_adapter_openai.utils.streaming import chunk_to_dict, map_stream
1115

1216

1317
async def chat_completion(
14-
data: Any, upstream_endpoint: str, creds: OpenAICreds
18+
data: Any,
19+
endpoint: AzureOpenAIEndpoint | OpenAIEndpoint,
20+
creds: OpenAICreds,
1521
):
16-
client = AsyncOpenAI(
17-
base_url=upstream_endpoint,
18-
api_key=creds.get("api_key"),
19-
http_client=get_http_client(),
20-
)
22+
client = endpoint.get_client(cast(OpenAIParams, creds))
2123

2224
response: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
2325
await call_with_extra_body(client.chat.completions.create, data)

aidial_adapter_openai/utils/parsers.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from fastapi import Request
77
from openai import AsyncAzureOpenAI, AsyncOpenAI, Timeout
88

9-
from aidial_adapter_openai.utils.auth import OpenAICreds
109
from aidial_adapter_openai.utils.http_client import get_http_client
1110
from aidial_adapter_openai.utils.pydantic import ExtraForbidModel
1211

@@ -44,33 +43,21 @@ def get_client(self, params: OpenAIParams) -> AsyncAzureOpenAI:
4443
http_client=get_http_client(),
4544
)
4645

47-
def get_auth_headers(self, creds: OpenAICreds) -> dict[str, str]:
48-
if key := creds.get("api_key"):
49-
return {"api-key": key}
50-
51-
if token := creds.get("azure_ad_token"):
52-
return {"Authorization": f"Bearer {token}"}
53-
54-
raise ValueError("Invalid credentials")
55-
5646

5747
class OpenAIEndpoint(ExtraForbidModel):
5848
base_url: str
5949

6050
def get_client(self, params: OpenAIParams) -> AsyncOpenAI:
51+
api_key = params.get("api_key") or params.get("azure_ad_token")
52+
6153
return AsyncOpenAI(
6254
base_url=self.base_url,
63-
api_key=params.get("api_key"),
55+
api_key=api_key,
6456
timeout=params.get("timeout"),
6557
max_retries=_MAX_RETRIES,
6658
http_client=get_http_client(),
6759
)
6860

69-
def get_auth_headers(self, creds: OpenAICreds) -> dict[str, str]:
70-
if key := (creds.get("api_key") or creds.get("azure_ad_token")):
71-
return {"Authorization": f"Bearer {key}"}
72-
raise ValueError("Invalid credentials")
73-
7461

7562
def _parse_endpoint(
7663
name: str | None, endpoint: str

0 commit comments

Comments
 (0)