Commit cdb7d61

Rationalise Call llm (#509)
* use aio llm client to support anthropic
* start of single
* change all the typed
* pretty up
* add debugging logs
* bump version
* fix test
* redefine once outputs
1 parent 5a65576 commit cdb7d61

15 files changed (+448 −339 lines)


patchwork/common/client/llm/aio.py

+2

@@ -8,6 +8,7 @@
 from typing_extensions import Dict, Iterable, List, Optional, Union
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
+from patchwork.logger import logger
 
 
 class AioLlmClient(LlmClient):
@@ -46,6 +47,7 @@ def chat_completion(
     ) -> ChatCompletion:
         for client in self.__clients:
            if client.is_model_supported(model):
+                logger.debug(f"Using {client.__class__.__name__} for model {model}")
                return client.chat_completion(
                    messages,
                    model,
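
Below is a minimal, runnable sketch of the first-match dispatch this hunk instruments. AioLlmClient's constructor is not part of this diff, so the list-of-clients argument and the fallback error here are assumptions for illustration:

    import logging

    logger = logging.getLogger(__name__)

    class FirstMatchDispatcher:
        """Sketch of AioLlmClient-style routing: the first client that claims the model wins."""

        def __init__(self, clients):
            # Assumed constructor shape; not shown in the diff above.
            self._clients = clients

        def chat_completion(self, messages, model):
            for client in self._clients:
                if client.is_model_supported(model):
                    # The debug line added by this commit: records which
                    # underlying client actually served the request.
                    logger.debug(f"Using {client.__class__.__name__} for model {model}")
                    return client.chat_completion(messages, model)
            raise ValueError(f"No configured client supports model {model!r}")

Without the debug line, a request routed to the wrong provider (for example, an OpenAI-compatible proxy that optimistically accepts every model name, as in the openai.py change below) is hard to diagnose; logging the winning client's class name makes the routing decision visible.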

patchwork/common/client/llm/google.py

+22 −17

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
+import functools
 import time
-from functools import lru_cache
 
 from google.ai.generativelanguage_v1 import GenerateContentResponse
 from google.ai.generativelanguage_v1.services import generative_service, model_service
@@ -23,6 +23,25 @@
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 
 
+@functools.lru_cache
+def _cached_list_model_from_google(api_key):
+    model_client = model_service.ModelServiceClient(
+        client_options=dict(
+            api_key=api_key,
+            # quota_project_id="",
+        )
+    )
+
+    request = ListModelsRequest()
+    response = model_client.list_models(request)
+
+    models = set()
+    for page in response.pages:
+        models.update(map(lambda x: x.name, page.models))
+
+    return models
+
+
 class GoogleLlmClient(LlmClient):
     __SAFETY_SETTINGS = [
         dict(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
@@ -32,36 +51,22 @@ class GoogleLlmClient(LlmClient):
     ]
 
     def __init__(self, api_key: str):
-        self.model_client = model_service.ModelServiceClient(
-            client_options=dict(
-                api_key=api_key,
-                # quota_project_id="",
-            )
-        )
+        self.__api_key = api_key
         self.generative_client = generative_service.GenerativeServiceClient(
             client_options=dict(
                 api_key=api_key,
                 # quota_project_id="",
             )
         )
 
-    @lru_cache(maxsize=None)
     def __get_true_model_names(self) -> set[str]:
-        request = ListModelsRequest()
-        response = self.model_client.list_models(request)
-
-        models = set()
-        for page in response.pages:
-            models.update(map(lambda x: x.name, page.models))
-
-        return models
+        return _cached_list_model_from_google(self.__api_key)
 
     @staticmethod
     def __handle_model_name(model_name) -> str:
         _, _, model = model_name.rpartition("/")
         return model
 
-    @lru_cache(maxsize=None)
     def get_models(self) -> set[str]:
         models = self.__get_true_model_names()
         return set(map(self.__handle_model_name, models))
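
A note on why the cache moved: functools.lru_cache on an instance method keys the cache on self, which keeps every client instance alive for the lifetime of the process and means each new instance misses the cache even when constructed with an identical API key. Hoisting the lookup into a module-level function keyed on the API key lets all GoogleLlmClient instances with the same key share one cached model list. A self-contained sketch of the pattern, where the _fetch_model_names stub stands in for the real ListModelsRequest round trip and the model names are illustrative:

    import functools

    def _fetch_model_names(api_key: str) -> set[str]:
        # Stand-in for the network call; the real code pages through
        # ListModelsRequest responses.
        print(f"listing models for key ending ...{api_key[-4:]}")
        return {"models/gemini-pro", "models/gemini-1.5-flash"}

    @functools.lru_cache
    def _cached_list_models(api_key: str) -> frozenset[str]:
        # One network round trip per distinct api_key, shared process-wide.
        return frozenset(_fetch_model_names(api_key))

    class Client:
        def __init__(self, api_key: str):
            self.__api_key = api_key

        def get_models(self) -> set[str]:
            return set(_cached_list_models(self.__api_key))

Constructing two Client objects with the same key now triggers a single fetch; with the old method-level @lru_cache, each instance would have fetched independently and been pinned in the cache afterwards.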

patchwork/common/client/llm/openai.py

+27 −10

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from functools import lru_cache
+import functools
 
 from openai import OpenAI
 from openai.types.chat import (
@@ -13,22 +13,39 @@
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 
 
+@functools.lru_cache
+def _cached_list_models_from_openai(api_key):
+    client = OpenAI(api_key=api_key)
+    sync_page = client.models.list()
+
+    models = set()
+    for pages in sync_page.iter_pages():
+        models.update(map(lambda x: x.id, pages.data))
+
+    return models
+
+
 class OpenAiLlmClient(LlmClient):
-    def __init__(self, api_key: str):
+    def __init__(self, api_key: str, base_url=None):
         self.api_key = api_key
-        self.client = OpenAI(api_key=api_key)
+        self.base_url = base_url
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
 
-    @lru_cache(maxsize=None)
-    def get_models(self) -> set[str]:
-        sync_page = self.client.models.list()
+    def __is_not_openai_url(self):
+        # Some providers/apis only implement the chat completion endpoint.
+        # We mainly use this to skip using the model endpoints.
+        return self.base_url is not None and self.base_url != "https://api.openai.com/v1"
 
-        models = set()
-        for pages in sync_page.iter_pages():
-            models.update(map(lambda x: x.id, pages.data))
+    def get_models(self) -> set[str]:
+        if self.__is_not_openai_url():
+            return set()
 
-        return models
+        return _cached_list_models_from_openai(self.api_key)
 
     def is_model_supported(self, model: str) -> bool:
+        # might not implement model endpoint
+        if self.__is_not_openai_url():
+            return True
        return model in self.get_models()
 
     def chat_completion(
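
The practical effect of the new base_url parameter, assuming the class as patched above ("http://localhost:11434/v1" and "llama3" are illustrative stand-ins for an OpenAI-compatible endpoint, not something this commit configures):

    # Default: talks to api.openai.com; model support is checked against
    # the models endpoint (requires a real API key).
    openai_client = OpenAiLlmClient(api_key="sk-...")

    # Pointed at an OpenAI-compatible server that may not implement /models:
    local_client = OpenAiLlmClient(api_key="unused", base_url="http://localhost:11434/v1")
    local_client.get_models()                  # set(): model listing is skipped
    local_client.is_model_supported("llama3")  # True: any model name is accepted

The trade-off is deliberate: against a non-OpenAI base URL the client can no longer reject unknown model names up front, so a typo in the model name only surfaces when the chat completion request itself fails.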
