Skip to content

Commit 120cbab

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: Rename long util function name in runner.py and move it to functions.py
PiperOrigin-RevId: 774880990
1 parent 29cd183 commit 120cbab

File tree

4 files changed

+170
-206
lines changed

4 files changed

+170
-206
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,35 @@ def merge_parallel_function_response_events(
519519
# Use the base_event as the timestamp
520520
merged_event.timestamp = base_event.timestamp
521521
return merged_event
522+
523+
524+
def find_matching_function_call(
525+
events: list[Event],
526+
) -> Optional[Event]:
527+
"""Finds the function call event that matches the function response id of the last event."""
528+
if not events:
529+
return None
530+
531+
last_event = events[-1]
532+
if (
533+
last_event.content
534+
and last_event.content.parts
535+
and any(part.function_response for part in last_event.content.parts)
536+
):
537+
538+
function_call_id = next(
539+
part.function_response.id
540+
for part in last_event.content.parts
541+
if part.function_response
542+
)
543+
for i in range(len(events) - 2, -1, -1):
544+
event = events[i]
545+
# looking for the system long running request euc function call
546+
function_calls = event.get_function_calls()
547+
if not function_calls:
548+
continue
549+
550+
for function_call in function_calls:
551+
if function_call.id == function_call_id:
552+
return event
553+
return None

src/google/adk/runners.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .auth.credential_service.base_credential_service import BaseCredentialService
3737
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
3838
from .events.event import Event
39+
from .flows.llm_flows.functions import find_matching_function_call
3940
from .memory.base_memory_service import BaseMemoryService
4041
from .memory.in_memory_memory_service import InMemoryMemoryService
4142
from .platform.thread import create_thread
@@ -354,9 +355,7 @@ def _find_agent_to_run(
354355
# the agent that returned the corressponding function call regardless the
355356
# type of the agent. e.g. a remote a2a agent may surface a credential
356357
# request as a special long running function tool call.
357-
event = _find_function_call_event_if_last_event_is_function_response(
358-
session
359-
)
358+
event = find_matching_function_call(session.events)
360359
if event and event.author:
361360
return root_agent.find_agent(event.author)
362361
for event in filter(lambda e: e.author != 'user', reversed(session.events)):
@@ -538,35 +537,3 @@ def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'):
538537
session_service=self._in_memory_session_service,
539538
memory_service=InMemoryMemoryService(),
540539
)
541-
542-
543-
def _find_function_call_event_if_last_event_is_function_response(
544-
session: Session,
545-
) -> Optional[Event]:
546-
events = session.events
547-
if not events:
548-
return None
549-
550-
last_event = events[-1]
551-
if (
552-
last_event.content
553-
and last_event.content.parts
554-
and any(part.function_response for part in last_event.content.parts)
555-
):
556-
557-
function_call_id = next(
558-
part.function_response.id
559-
for part in last_event.content.parts
560-
if part.function_response
561-
)
562-
for i in range(len(events) - 2, -1, -1):
563-
event = events[i]
564-
# looking for the system long running request euc function call
565-
function_calls = event.get_function_calls()
566-
if not function_calls:
567-
continue
568-
569-
for function_call in function_calls:
570-
if function_call.id == function_call_id:
571-
return event
572-
return None

tests/unittests/flows/llm_flows/test_functions_simple.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from typing import Callable
1818

1919
from google.adk.agents import Agent
20+
from google.adk.events.event import Event
21+
from google.adk.flows.llm_flows.functions import find_matching_function_call
22+
from google.adk.sessions.session import Session
2023
from google.adk.tools import ToolContext
2124
from google.adk.tools.function_tool import FunctionTool
2225
from google.genai import types
@@ -256,3 +259,136 @@ def increase_by_one(x: int) -> int:
256259
assert part.function_response.id is None
257260
assert events[0].content.parts[0].function_call.id.startswith('adk-')
258261
assert events[1].content.parts[0].function_response.id.startswith('adk-')
262+
263+
264+
def test_find_function_call_event_no_function_response_in_last_event():
265+
"""Test when last event has no function response."""
266+
events = [
267+
Event(
268+
invocation_id='inv1',
269+
author='user',
270+
content=types.Content(role='user', parts=[types.Part(text='Hello')]),
271+
)
272+
]
273+
274+
result = find_matching_function_call(events)
275+
assert result is None
276+
277+
278+
def test_find_function_call_event_empty_session_events():
279+
"""Test when session has no events."""
280+
events = []
281+
282+
result = find_matching_function_call(events)
283+
assert result is None
284+
285+
286+
def test_find_function_call_event_function_response_but_no_matching_call():
287+
"""Test when last event has function response but no matching call found."""
288+
# Create a function response
289+
function_response = types.FunctionResponse(
290+
id='func_123', name='test_func', response={}
291+
)
292+
293+
events = [
294+
Event(
295+
invocation_id='inv1',
296+
author='agent1',
297+
content=types.Content(
298+
role='model',
299+
parts=[types.Part(text='Some other response')],
300+
),
301+
),
302+
Event(
303+
invocation_id='inv2',
304+
author='user',
305+
content=types.Content(
306+
role='user',
307+
parts=[types.Part(function_response=function_response)],
308+
),
309+
),
310+
]
311+
312+
result = find_matching_function_call(events)
313+
assert result is None
314+
315+
316+
def test_find_function_call_event_function_response_with_matching_call():
317+
"""Test when last event has function response with matching function call."""
318+
# Create a function call
319+
function_call = types.FunctionCall(id='func_123', name='test_func', args={})
320+
321+
# Create a function response with matching ID
322+
function_response = types.FunctionResponse(
323+
id='func_123', name='test_func', response={}
324+
)
325+
326+
call_event = Event(
327+
invocation_id='inv1',
328+
author='agent1',
329+
content=types.Content(
330+
role='model', parts=[types.Part(function_call=function_call)]
331+
),
332+
)
333+
334+
response_event = Event(
335+
invocation_id='inv2',
336+
author='user',
337+
content=types.Content(
338+
role='user', parts=[types.Part(function_response=function_response)]
339+
),
340+
)
341+
342+
events = [call_event, response_event]
343+
344+
result = find_matching_function_call(events)
345+
assert result == call_event
346+
347+
348+
def test_find_function_call_event_multiple_function_responses():
349+
"""Test when last event has multiple function responses."""
350+
# Create function calls
351+
function_call1 = types.FunctionCall(id='func_123', name='test_func1', args={})
352+
function_call2 = types.FunctionCall(id='func_456', name='test_func2', args={})
353+
354+
# Create function responses
355+
function_response1 = types.FunctionResponse(
356+
id='func_123', name='test_func1', response={}
357+
)
358+
function_response2 = types.FunctionResponse(
359+
id='func_456', name='test_func2', response={}
360+
)
361+
362+
call_event1 = Event(
363+
invocation_id='inv1',
364+
author='agent1',
365+
content=types.Content(
366+
role='model', parts=[types.Part(function_call=function_call1)]
367+
),
368+
)
369+
370+
call_event2 = Event(
371+
invocation_id='inv2',
372+
author='agent2',
373+
content=types.Content(
374+
role='model', parts=[types.Part(function_call=function_call2)]
375+
),
376+
)
377+
378+
response_event = Event(
379+
invocation_id='inv3',
380+
author='user',
381+
content=types.Content(
382+
role='user',
383+
parts=[
384+
types.Part(function_response=function_response1),
385+
types.Part(function_response=function_response2),
386+
],
387+
),
388+
)
389+
390+
events = [call_event1, call_event2, response_event]
391+
392+
# Should return the first matching function call event found
393+
result = find_matching_function_call(events)
394+
assert result == call_event1 # First match (func_123)

0 commit comments

Comments
 (0)