diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index dce5ed7c4..7e142b8d4 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -23,6 +23,7 @@
 from typing import Dict
 from typing import Generator
 from typing import Iterable
+from typing import List
 from typing import Literal
 from typing import Optional
 from typing import Tuple
@@ -481,16 +482,22 @@ def _message_to_generate_content_response(
 
 def _get_completion_inputs(
     llm_request: LlmRequest,
-) -> tuple[Iterable[Message], Iterable[dict]]:
-  """Converts an LlmRequest to litellm inputs.
+) -> Tuple[
+    List[Message],
+    Optional[List[Dict]],
+    Optional[types.SchemaUnion],
+    Optional[Dict],
+]:
+  """Converts an LlmRequest to litellm inputs and extracts generation params.
 
   Args:
     llm_request: The LlmRequest to convert.
 
   Returns:
-    The litellm inputs (message list, tool dictionary and response format).
+    The litellm inputs (message list, tool dictionary, response format and generation params).
   """
-  messages = []
+  # 1. Construct messages
+  messages: List[Message] = []
   for content in llm_request.contents or []:
     message_param_or_list = _content_to_message_param(content)
     if isinstance(message_param_or_list, list):
@@ -507,7 +514,8 @@
         ),
     )
 
-  tools = None
+  # 2. Convert tool declarations
+  tools: Optional[List[Dict]] = None
   if (
       llm_request.config
       and llm_request.config.tools
@@ -518,12 +526,39 @@
         for tool in llm_request.config.tools[0].function_declarations
     ]
 
-  response_format = None
-
-  if llm_request.config.response_schema:
+  # 3. Handle response format
+  response_format: Optional[types.SchemaUnion] = None
+  if llm_request.config and llm_request.config.response_schema:
     response_format = llm_request.config.response_schema
 
-  return messages, tools, response_format
+  # 4. Extract generation parameters
+  generation_params: Optional[Dict] = None
+  if llm_request.config:
+    config_dict = llm_request.config.model_dump(exclude_none=True)
+    # Generate LiteLlm parameters here,
+    # Following https://docs.litellm.ai/docs/completion/input.
+    generation_params = {}
+    param_mapping = {
+        "max_output_tokens": "max_completion_tokens",
+        "stop_sequences": "stop",
+    }
+    for key in (
+        "temperature",
+        "max_output_tokens",
+        "top_p",
+        "top_k",
+        "stop_sequences",
+        "presence_penalty",
+        "frequency_penalty",
+    ):
+      if key in config_dict:
+        mapped_key = param_mapping.get(key, key)
+        generation_params[mapped_key] = config_dict[key]
+
+    if not generation_params:
+      generation_params = None
+
+  return messages, tools, response_format, generation_params
 
 
 def _build_function_declaration_log(
@@ -660,7 +695,9 @@ async def generate_content_async(
     self._maybe_append_user_content(llm_request)
     logger.debug(_build_request_log(llm_request))
 
-    messages, tools, response_format = _get_completion_inputs(llm_request)
+    messages, tools, response_format, generation_params = (
+        _get_completion_inputs(llm_request)
+    )
 
     completion_args = {
         "model": self.model,
@@ -670,6 +707,9 @@
     }
     completion_args.update(self._additional_args)
 
+    if generation_params:
+      completion_args.update(generation_params)
+
     if stream:
       text = ""
       # Track function calls by index
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 8b43cc48b..9246c70a4 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -1430,3 +1430,34 @@ async def test_generate_content_async_non_compliant_multiple_function_calls(
   assert final_response.content.parts[1].function_call.name == "function_2"
   assert final_response.content.parts[1].function_call.id == "1"
   assert final_response.content.parts[1].function_call.args == {"arg": "value2"}
+
+
+def test_get_completion_inputs_generation_params():
+  # Test that generation_params are extracted and mapped correctly
+  req = LlmRequest(
+      contents=[
+          types.Content(role="user", parts=[types.Part.from_text(text="hi")]),
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.33,
+          max_output_tokens=123,
+          top_p=0.88,
+          top_k=7,
+          stop_sequences=["foo", "bar"],
+          presence_penalty=0.1,
+          frequency_penalty=0.2,
+      ),
+  )
+  from google.adk.models.lite_llm import _get_completion_inputs
+
+  _, _, _, generation_params = _get_completion_inputs(req)
+  assert generation_params["temperature"] == 0.33
+  assert generation_params["max_completion_tokens"] == 123
+  assert generation_params["top_p"] == 0.88
+  assert generation_params["top_k"] == 7
+  assert generation_params["stop"] == ["foo", "bar"]
+  assert generation_params["presence_penalty"] == 0.1
+  assert generation_params["frequency_penalty"] == 0.2
+  # The original Gemini-style keys should not leak through unmapped.
+  assert "max_output_tokens" not in generation_params
+  assert "stop_sequences" not in generation_params
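
Illustrative note (not part of the diff): the sketch below mirrors the key-translation logic that `_get_completion_inputs` now applies, so the mapping from `GenerateContentConfig` fields to LiteLLM `completion()` keyword names can be read in isolation. The helper name `map_generation_params` and the module-level constants are hypothetical and exist only for this example; they are not added by the change.

```python
from google.genai import types

# Hypothetical standalone mirror of the mapping added in _get_completion_inputs:
# Gemini-style GenerateContentConfig keys -> LiteLLM completion() keyword names.
_PARAM_MAPPING = {
    "max_output_tokens": "max_completion_tokens",
    "stop_sequences": "stop",
}
_FORWARDED_KEYS = (
    "temperature",
    "max_output_tokens",
    "top_p",
    "top_k",
    "stop_sequences",
    "presence_penalty",
    "frequency_penalty",
)


def map_generation_params(config: types.GenerateContentConfig) -> dict:
  """Returns the kwargs that would be merged into completion_args."""
  config_dict = config.model_dump(exclude_none=True)
  return {
      _PARAM_MAPPING.get(key, key): config_dict[key]
      for key in _FORWARDED_KEYS
      if key in config_dict
  }


if __name__ == "__main__":
  cfg = types.GenerateContentConfig(temperature=0.2, max_output_tokens=64)
  print(map_generation_params(cfg))
  # Expected: {'temperature': 0.2, 'max_completion_tokens': 64}
```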