
Commit ecb9638

simonwei97 authored and copybara-github committed
fix: converts litellm generate config err
Merge #1509
Fixes: #1302
Previous PR: #1450

FUTURE_COPYBARA_INTEGRATE_REVIEW=#1509 from simonwei97:fix/litellm-gen-config-converting-err 3120887

PiperOrigin-RevId: 774803187
1 parent 6729edd commit ecb9638

2 files changed (+82, -10 lines)

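Before the diffs, a minimal standalone sketch of the parameter mapping this commit introduces: supported GenerateContentConfig fields are copied into LiteLLM completion kwargs, with max_output_tokens renamed to max_completion_tokens and stop_sequences renamed to stop. The plain dict stands in for llm_request.config.model_dump(exclude_none=True), and the helper name extract_generation_params is illustrative, not part of the change.

# Illustrative sketch only (not part of the commit): how the supported
# GenerateContentConfig fields map onto LiteLLM completion kwargs.
from typing import Any, Dict, Optional

_PARAM_MAPPING = {
    "max_output_tokens": "max_completion_tokens",
    "stop_sequences": "stop",
}
_SUPPORTED_KEYS = (
    "temperature",
    "max_output_tokens",
    "top_p",
    "top_k",
    "stop_sequences",
    "presence_penalty",
    "frequency_penalty",
)


def extract_generation_params(
    config_dict: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
  # config_dict stands in for llm_request.config.model_dump(exclude_none=True).
  params = {
      _PARAM_MAPPING.get(key, key): config_dict[key]
      for key in _SUPPORTED_KEYS
      if key in config_dict
  }
  return params or None


print(extract_generation_params({"temperature": 0.33, "max_output_tokens": 123}))
# -> {'temperature': 0.33, 'max_completion_tokens': 123}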

src/google/adk/models/lite_llm.py

Lines changed: 50 additions & 10 deletions
@@ -23,6 +23,7 @@
 from typing import Dict
 from typing import Generator
 from typing import Iterable
+from typing import List
 from typing import Literal
 from typing import Optional
 from typing import Tuple
@@ -485,16 +486,22 @@ def _message_to_generate_content_response(
 
 
 def _get_completion_inputs(
     llm_request: LlmRequest,
-) -> tuple[Iterable[Message], Iterable[dict]]:
-  """Converts an LlmRequest to litellm inputs.
+) -> Tuple[
+    List[Message],
+    Optional[List[Dict]],
+    Optional[types.SchemaUnion],
+    Optional[Dict],
+]:
+  """Converts an LlmRequest to litellm inputs and extracts generation params.
 
   Args:
     llm_request: The LlmRequest to convert.
 
   Returns:
-    The litellm inputs (message list, tool dictionary and response format).
+    The litellm inputs (message list, tool dictionary, response format and generation params).
   """
-  messages = []
+  # 1. Construct messages
+  messages: List[Message] = []
   for content in llm_request.contents or []:
     message_param_or_list = _content_to_message_param(content)
     if isinstance(message_param_or_list, list):
@@ -511,7 +518,8 @@ def _get_completion_inputs(
         ),
     )
 
-  tools = None
+  # 2. Convert tool declarations
+  tools: Optional[List[Dict]] = None
   if (
       llm_request.config
       and llm_request.config.tools
@@ -522,12 +530,39 @@ def _get_completion_inputs(
         for tool in llm_request.config.tools[0].function_declarations
     ]
 
-  response_format = None
-
-  if llm_request.config.response_schema:
+  # 3. Handle response format
+  response_format: Optional[types.SchemaUnion] = None
+  if llm_request.config and llm_request.config.response_schema:
     response_format = llm_request.config.response_schema
 
-  return messages, tools, response_format
+  # 4. Extract generation parameters
+  generation_params: Optional[Dict] = None
+  if llm_request.config:
+    config_dict = llm_request.config.model_dump(exclude_none=True)
+    # Generate LiteLlm parameters here,
+    # Following https://docs.litellm.ai/docs/completion/input.
+    generation_params = {}
+    param_mapping = {
+        "max_output_tokens": "max_completion_tokens",
+        "stop_sequences": "stop",
+    }
+    for key in (
+        "temperature",
+        "max_output_tokens",
+        "top_p",
+        "top_k",
+        "stop_sequences",
+        "presence_penalty",
+        "frequency_penalty",
+    ):
+      if key in config_dict:
+        mapped_key = param_mapping.get(key, key)
+        generation_params[mapped_key] = config_dict[key]
+
+    if not generation_params:
+      generation_params = None
+
+  return messages, tools, response_format, generation_params
 
 
 def _build_function_declaration_log(
@@ -664,7 +699,9 @@ async def generate_content_async(
     self._maybe_append_user_content(llm_request)
     logger.debug(_build_request_log(llm_request))
 
-    messages, tools, response_format = _get_completion_inputs(llm_request)
+    messages, tools, response_format, generation_params = (
+        _get_completion_inputs(llm_request)
+    )
 
     if "functions" in self._additional_args:
       # LiteLLM does not support both tools and functions together.
@@ -678,6 +715,9 @@ async def generate_content_async(
     }
     completion_args.update(self._additional_args)
 
+    if generation_params:
+      completion_args.update(generation_params)
+
    if stream:
       text = ""
       # Track function calls by index
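After this change, the extracted generation params are merged into the kwargs that generate_content_async hands to the LiteLLM client. A hedged illustration of that merge follows; the model string and message are made up for the example, and the actual client call is not shown in this diff.

# Hypothetical values for illustration; only the update() step mirrors the diff.
completion_args = {
    "model": "openai/gpt-4o",  # assumed model string, not from the commit
    "messages": [{"role": "user", "content": "hi"}],
    "tools": None,
    "response_format": None,
}
generation_params = {
    "temperature": 0.33,
    "max_completion_tokens": 123,
    "stop": ["foo", "bar"],
}

if generation_params:
  completion_args.update(generation_params)

print(sorted(completion_args))
# ['max_completion_tokens', 'messages', 'model', 'response_format', 'stop',
#  'temperature', 'tools']
# These merged kwargs are what the LiteLLM completion call ultimately receives.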

tests/unittests/models/test_litellm.py

Lines changed: 32 additions & 0 deletions
@@ -1447,3 +1447,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "1"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
+
+
+@pytest.mark.asyncio
+def test_get_completion_inputs_generation_params():
+  # Test that generation_params are extracted and mapped correctly
+  req = LlmRequest(
+      contents=[
+          types.Content(role="user", parts=[types.Part.from_text(text="hi")]),
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.33,
+          max_output_tokens=123,
+          top_p=0.88,
+          top_k=7,
+          stop_sequences=["foo", "bar"],
+          presence_penalty=0.1,
+          frequency_penalty=0.2,
+      ),
+  )
+  from google.adk.models.lite_llm import _get_completion_inputs
+
+  _, _, _, generation_params = _get_completion_inputs(req)
+  assert generation_params["temperature"] == 0.33
+  assert generation_params["max_completion_tokens"] == 123
+  assert generation_params["top_p"] == 0.88
+  assert generation_params["top_k"] == 7
+  assert generation_params["stop"] == ["foo", "bar"]
+  assert generation_params["presence_penalty"] == 0.1
+  assert generation_params["frequency_penalty"] == 0.2
+  # Should not include max_output_tokens
+  assert "max_output_tokens" not in generation_params
+  assert "stop_sequences" not in generation_params
