11from __future__ import annotations
22
3+ import functools
34import time
4- from functools import lru_cache
55
66from google .ai .generativelanguage_v1 import GenerateContentResponse
77from google .ai .generativelanguage_v1 .services import generative_service , model_service
2323from patchwork .common .client .llm .protocol import NOT_GIVEN , LlmClient , NotGiven
2424
2525
26+ @functools .lru_cache
27+ def _cached_list_model_from_google (api_key ):
28+ model_client = model_service .ModelServiceClient (
29+ client_options = dict (
30+ api_key = api_key ,
31+ # quota_project_id="",
32+ )
33+ )
34+
35+ request = ListModelsRequest ()
36+ response = model_client .list_models (request )
37+
38+ models = set ()
39+ for page in response .pages :
40+ models .update (map (lambda x : x .name , page .models ))
41+
42+ return models
43+
44+
2645class GoogleLlmClient (LlmClient ):
2746 __SAFETY_SETTINGS = [
2847 dict (category = "HARM_CATEGORY_HATE_SPEECH" , threshold = "BLOCK_NONE" ),
@@ -32,36 +51,22 @@ class GoogleLlmClient(LlmClient):
3251 ]
3352
3453 def __init__ (self , api_key : str ):
35- self .model_client = model_service .ModelServiceClient (
36- client_options = dict (
37- api_key = api_key ,
38- # quota_project_id="",
39- )
40- )
54+ self .__api_key = api_key
4155 self .generative_client = generative_service .GenerativeServiceClient (
4256 client_options = dict (
4357 api_key = api_key ,
4458 # quota_project_id="",
4559 )
4660 )
4761
48- @lru_cache (maxsize = None )
4962 def __get_true_model_names (self ) -> set [str ]:
50- request = ListModelsRequest ()
51- response = self .model_client .list_models (request )
52-
53- models = set ()
54- for page in response .pages :
55- models .update (map (lambda x : x .name , page .models ))
56-
57- return models
63+ return _cached_list_model_from_google (self .__api_key )
5864
5965 @staticmethod
6066 def __handle_model_name (model_name ) -> str :
6167 _ , _ , model = model_name .rpartition ("/" )
6268 return model
6369
64- @lru_cache (maxsize = None )
6570 def get_models (self ) -> set [str ]:
6671 models = self .__get_true_model_names ()
6772 return set (map (self .__handle_model_name , models ))
0 commit comments