diff --git a/src/any_llm/any_llm.py b/src/any_llm/any_llm.py index b60fb10f..e4bd69ac 100644 --- a/src/any_llm/any_llm.py +++ b/src/any_llm/any_llm.py @@ -82,6 +82,8 @@ class AnyLLM(ABC): For example, in `gemini` provider, this could include `google.genai.types.Tool`. """ + ANY_API_KEY = "ANY_API_KEY" + def __init__(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None: self._verify_no_missing_packages() self._init_client( @@ -142,6 +144,11 @@ def _create_provider( raise ImportError(msg) from e provider_class: type[AnyLLM] = getattr(module, provider_class_name) + if any_api_key := os.getenv(cls.ANY_API_KEY): + from any_llm.providers.any_api import AnyAPI + + return AnyAPI(provider_class, any_api_key=any_api_key, api_base=api_base, **kwargs) + return provider_class(api_key=api_key, api_base=api_base, **kwargs) @classmethod diff --git a/src/any_llm/providers/any_api.py b/src/any_llm/providers/any_api.py new file mode 100644 index 00000000..bdcbeb49 --- /dev/null +++ b/src/any_llm/providers/any_api.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Iterator + + from any_llm.any_llm import AnyLLM + from any_llm.types.completion import ChatCompletion, ChatCompletionChunk + + +class AnyAPI: + def __init__(self, any_api_key: str, provider_class: type[AnyLLM], api_base: str | None = None, **kwargs: Any) -> None: + self.any_api_key = any_api_key + self.provider_class = provider_class + self.api_base = api_base + self.kwargs = kwargs + + self.provider_instance: AnyLLM | None = None + self.api_key = None + + def _init_provider(self) -> None: + api_key = self.get_provider_key() + self.provider_instance = self.provider_class(api_key=api_key, api_base=self.api_base, **self.kwargs) + + def get_provider_key(self) -> str: + raise NotImplementedError + + def completion( + self, + **kwargs: Any, + ) -> ChatCompletion | Iterator[ChatCompletionChunk]: + if self.provider_instance is None: + self._init_provider() + result = self.provider_instance.completion(**kwargs) + self.post_metadata(result) + return result + + def post_metadata(self, result): + raise NotImplementedError