fix: converts litellm generate config err #1509

Open: wants to merge 5 commits into main

60 changes: 50 additions & 10 deletions src/google/adk/models/lite_llm.py
@@ -23,6 +23,7 @@
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
@@ -481,16 +482,22 @@ def _message_to_generate_content_response(

def _get_completion_inputs(
llm_request: LlmRequest,
) -> tuple[Iterable[Message], Iterable[dict]]:
"""Converts an LlmRequest to litellm inputs.
) -> Tuple[
List[Message],
Optional[List[Dict]],
Optional[types.SchemaUnion],
Optional[Dict],
]:
"""Converts an LlmRequest to litellm inputs and extracts generation params.

Args:
llm_request: The LlmRequest to convert.

Returns:
The litellm inputs (message list, tool dictionary and response format).
The litellm inputs (message list, tool dictionary, response format and generation params).
"""
messages = []
# 1. Construct messages
messages: List[Message] = []
for content in llm_request.contents or []:
message_param_or_list = _content_to_message_param(content)
if isinstance(message_param_or_list, list):
@@ -507,7 +514,8 @@ def _get_completion_inputs(
),
)

tools = None
# 2. Convert tool declarations
tools: Optional[List[Dict]] = None
if (
llm_request.config
and llm_request.config.tools
@@ -518,12 +526,39 @@ def _get_completion_inputs(
for tool in llm_request.config.tools[0].function_declarations
]

response_format = None

if llm_request.config.response_schema:
# 3. Handle response format
response_format: Optional[types.SchemaUnion] = None
if llm_request.config and llm_request.config.response_schema:
response_format = llm_request.config.response_schema

return messages, tools, response_format
# 4. Extract generation parameters
generation_params: Optional[Dict] = None
if llm_request.config:
config_dict = llm_request.config.model_dump(exclude_none=True)
# Build LiteLLM generation parameters,
# following https://docs.litellm.ai/docs/completion/input.
generation_params = {}
param_mapping = {
"max_output_tokens": "max_completion_tokens",
"stop_sequences": "stop",
}
for key in (
"temperature",
"max_output_tokens",
"top_p",
"top_k",
"stop_sequences",
"presence_penalty",
"frequency_penalty",
):
if key in config_dict:
mapped_key = param_mapping.get(key, key)
generation_params[mapped_key] = config_dict[key]

if not generation_params:
generation_params = None

return messages, tools, response_format, generation_params


def _build_function_declaration_log(
@@ -660,7 +695,9 @@ async def generate_content_async(
self._maybe_append_user_content(llm_request)
logger.debug(_build_request_log(llm_request))

messages, tools, response_format = _get_completion_inputs(llm_request)
messages, tools, response_format, generation_params = (
_get_completion_inputs(llm_request)
)

completion_args = {
"model": self.model,
@@ -670,6 +707,9 @@
}
completion_args.update(self._additional_args)

if generation_params:
completion_args.update(generation_params)

if stream:
text = ""
# Track function calls by index
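
For readers skimming the hunks above: the new behavior boils down to a small key translation from google.genai GenerateContentConfig fields to LiteLLM completion kwargs, following https://docs.litellm.ai/docs/completion/input. The snippet below is a standalone sketch of that mapping, not code from this PR; the helper name map_generation_params is made up for illustration, and the key list and renames (max_output_tokens to max_completion_tokens, stop_sequences to stop) are taken from the diff above.

# Standalone sketch, not part of this PR: mirrors the mapping done in
# _get_completion_inputs. The name map_generation_params is hypothetical.
from typing import Dict, Optional


def map_generation_params(config_dict: Dict) -> Optional[Dict]:
  """Translates GenerateContentConfig-style keys into LiteLLM completion kwargs."""
  param_mapping = {
      "max_output_tokens": "max_completion_tokens",
      "stop_sequences": "stop",
  }
  supported_keys = (
      "temperature",
      "max_output_tokens",
      "top_p",
      "top_k",
      "stop_sequences",
      "presence_penalty",
      "frequency_penalty",
  )
  # Keep only the supported keys, renaming where LiteLLM uses a different name.
  params = {
      param_mapping.get(key, key): config_dict[key]
      for key in supported_keys
      if key in config_dict
  }
  return params or None


if __name__ == "__main__":
  print(map_generation_params({"temperature": 0.2, "max_output_tokens": 64}))
  # Expected: {'temperature': 0.2, 'max_completion_tokens': 64}
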
32 changes: 32 additions & 0 deletions tests/unittests/models/test_litellm.py
@@ -1430,3 +1430,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
assert final_response.content.parts[1].function_call.name == "function_2"
assert final_response.content.parts[1].function_call.id == "1"
assert final_response.content.parts[1].function_call.args == {"arg": "value2"}


def test_get_completion_inputs_generation_params():
# Test that generation_params are extracted and mapped correctly
req = LlmRequest(
contents=[
types.Content(role="user", parts=[types.Part.from_text(text="hi")]),
],
config=types.GenerateContentConfig(
temperature=0.33,
max_output_tokens=123,
top_p=0.88,
top_k=7,
stop_sequences=["foo", "bar"],
presence_penalty=0.1,
frequency_penalty=0.2,
),
)
from google.adk.models.lite_llm import _get_completion_inputs

_, _, _, generation_params = _get_completion_inputs(req)
assert generation_params["temperature"] == 0.33
assert generation_params["max_completion_tokens"] == 123
assert generation_params["top_p"] == 0.88
assert generation_params["top_k"] == 7
assert generation_params["stop"] == ["foo", "bar"]
assert generation_params["presence_penalty"] == 0.1
assert generation_params["frequency_penalty"] == 0.2
# Should not include max_output_tokens
assert "max_output_tokens" not in generation_params
assert "stop_sequences" not in generation_params