Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/any_llm/any_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class AnyLLM(ABC):
For example, in `gemini` provider, this could include `google.genai.types.Tool`.
"""

ANY_API_KEY = "ANY_API_KEY"

def __init__(self, api_key: str | None = None, api_base: str | None = None, **kwargs: Any) -> None:
self._verify_no_missing_packages()
self._init_client(
Expand Down Expand Up @@ -142,6 +144,11 @@ def _create_provider(
raise ImportError(msg) from e

provider_class: type[AnyLLM] = getattr(module, provider_class_name)
if any_api_key := os.getenv(cls.ANY_API_KEY):
from any_llm.providers.any_api import AnyAPI

return AnyAPI(provider_class, any_api_key=any_api_key, api_base=api_base, **kwargs)

return provider_class(api_key=api_key, api_base=api_base, **kwargs)

@classmethod
Expand Down
40 changes: 40 additions & 0 deletions src/any_llm/providers/any_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Iterator

from any_llm.any_llm import AnyLLM
from any_llm.types.completion import ChatCompletion, ChatCompletionChunk


class AnyAPI:
def __init__(self, any_api_key: str, provider_class: type[AnyLLM], api_base: str | None = None, **kwargs: Any) -> None:
self.any_api_key = any_api_key
self.provider_class = provider_class
self.api_base = api_base
self.kwargs = kwargs

self.provider_instance: AnyLLM | None = None
self.api_key = None

def _init_provider(self) -> None:
api_key = self.get_provider_key()
self.provider_instance = self.provider_class(api_key=api_key, api_base=self.api_base, **self.kwargs)

def get_provider_key(self) -> str:
raise NotImplementedError

def completion(
self,
**kwargs: Any,
) -> ChatCompletion | Iterator[ChatCompletionChunk]:
if self.provider_instance is None:
self._init_provider()
result = self.provider_instance.completion(**kwargs)
self.post_metadata(result)
return result

def post_metadata(self, result):
raise NotImplementedError
Loading