Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,13 +483,18 @@ def _model_response_to_generate_content_response(
"""

message = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
finish_reason = None
if choices := response.get("choices"):
first_choice = choices[0]
message = first_choice.get("message", None)
finish_reason = first_choice.get("finish_reason", None)

if not message:
raise ValueError("No message in response")

llm_response = _message_to_generate_content_response(message)
if finish_reason:
llm_response.finish_reason = finish_reason
if response.get("usage", None):
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
Expand Down
8 changes: 6 additions & 2 deletions src/google/adk/models/llm_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Any
from typing import Optional
from typing import Union

from google.genai import types
from pydantic import alias_generators
Expand Down Expand Up @@ -77,8 +78,11 @@ class LlmResponse(BaseModel):
Only used for streaming mode.
"""

finish_reason: Optional[types.FinishReason] = None
"""The finish reason of the response."""
finish_reason: Optional[Union[types.FinishReason, str]] = None
"""The finish reason of the response.

Can be either a types.FinishReason enum (from Gemini) or a string (from LiteLLM).
"""

error_code: Optional[str] = None
"""Error code if the response is an error. Code varies by model."""
Expand Down
6 changes: 5 additions & 1 deletion src/google/adk/telemetry/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,13 @@ def trace_call_llm(
llm_response.usage_metadata.candidates_token_count,
)
if llm_response.finish_reason:
if isinstance(llm_response.finish_reason, types.FinishReason):
finish_reason_str = llm_response.finish_reason.name.lower()
else:
finish_reason_str = str(llm_response.finish_reason).lower()
span.set_attribute(
'gen_ai.response.finish_reasons',
[llm_response.finish_reason.value.lower()],
[finish_reason_str],
)


Expand Down
72 changes: 72 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,3 +1849,75 @@ def test_non_gemini_litellm_no_warning():
# Test with non-Gemini model
LiteLlm(model="openai/gpt-4o")
assert len(w) == 0


@pytest.mark.parametrize(
    "finish_reason,response_content,expected_content,has_tool_calls",
    [
        ("length", "Test response", "Test response", False),
        ("stop", "Complete response", "Complete response", False),
        ("tool_calls", "", "", True),
        ("content_filter", "", "", False),
    ],
    ids=["length", "stop", "tool_calls", "content_filter"],
)
@pytest.mark.asyncio
async def test_finish_reason_propagation(
    mock_acompletion,
    lite_llm_instance,
    finish_reason,
    response_content,
    expected_content,
    has_tool_calls,
):
    """Verify finish_reason from a LiteLLM response surfaces on the LlmResponse.

    Exercises plain-text completions ("length", "stop"), a tool-call
    completion ("tool_calls"), and an empty filtered completion
    ("content_filter"), asserting the finish_reason string is passed
    through unchanged.
    """
    # Only the tool_calls case attaches a tool call to the assistant message.
    calls = (
        [
            ChatCompletionMessageToolCall(
                type="function",
                id="test_id",
                function=Function(
                    name="test_function",
                    arguments='{"arg": "value"}',
                ),
            )
        ]
        if has_tool_calls
        else None
    )

    # Wire the mocked LiteLLM acompletion call to return a single choice
    # carrying the parametrized finish_reason.
    mock_acompletion.return_value = ModelResponse(
        choices=[
            Choices(
                message=ChatCompletionAssistantMessage(
                    role="assistant",
                    content=response_content,
                    tool_calls=calls,
                ),
                finish_reason=finish_reason,
            )
        ]
    )

    request = LlmRequest(
        contents=[
            types.Content(
                role="user", parts=[types.Part.from_text(text="Test prompt")]
            )
        ],
    )

    async for response in lite_llm_instance.generate_content_async(request):
        # finish_reason and the model role must round-trip unchanged.
        assert response.content.role == "model"
        assert response.finish_reason == finish_reason
        if expected_content:
            assert response.content.parts[0].text == expected_content
        if has_tool_calls:
            assert response.content.parts
            assert response.content.parts[-1].function_call.name == "test_function"

    mock_acompletion.assert_called_once()