|
17 | 17 | from typing import Callable
|
18 | 18 |
|
19 | 19 | from google.adk.agents import Agent
|
| 20 | +from google.adk.events.event import Event |
| 21 | +from google.adk.flows.llm_flows.functions import find_matching_function_call |
| 22 | +from google.adk.sessions.session import Session |
20 | 23 | from google.adk.tools import ToolContext
|
21 | 24 | from google.adk.tools.function_tool import FunctionTool
|
22 | 25 | from google.genai import types
|
@@ -256,3 +259,136 @@ def increase_by_one(x: int) -> int:
|
256 | 259 | assert part.function_response.id is None
|
257 | 260 | assert events[0].content.parts[0].function_call.id.startswith('adk-')
|
258 | 261 | assert events[1].content.parts[0].function_response.id.startswith('adk-')
|
| 262 | + |
| 263 | + |
| 264 | +def test_find_function_call_event_no_function_response_in_last_event(): |
| 265 | + """Test when last event has no function response.""" |
| 266 | + events = [ |
| 267 | + Event( |
| 268 | + invocation_id='inv1', |
| 269 | + author='user', |
| 270 | + content=types.Content(role='user', parts=[types.Part(text='Hello')]), |
| 271 | + ) |
| 272 | + ] |
| 273 | + |
| 274 | + result = find_matching_function_call(events) |
| 275 | + assert result is None |
| 276 | + |
| 277 | + |
| 278 | +def test_find_function_call_event_empty_session_events(): |
| 279 | + """Test when session has no events.""" |
| 280 | + events = [] |
| 281 | + |
| 282 | + result = find_matching_function_call(events) |
| 283 | + assert result is None |
| 284 | + |
| 285 | + |
| 286 | +def test_find_function_call_event_function_response_but_no_matching_call(): |
| 287 | + """Test when last event has function response but no matching call found.""" |
| 288 | + # Create a function response |
| 289 | + function_response = types.FunctionResponse( |
| 290 | + id='func_123', name='test_func', response={} |
| 291 | + ) |
| 292 | + |
| 293 | + events = [ |
| 294 | + Event( |
| 295 | + invocation_id='inv1', |
| 296 | + author='agent1', |
| 297 | + content=types.Content( |
| 298 | + role='model', |
| 299 | + parts=[types.Part(text='Some other response')], |
| 300 | + ), |
| 301 | + ), |
| 302 | + Event( |
| 303 | + invocation_id='inv2', |
| 304 | + author='user', |
| 305 | + content=types.Content( |
| 306 | + role='user', |
| 307 | + parts=[types.Part(function_response=function_response)], |
| 308 | + ), |
| 309 | + ), |
| 310 | + ] |
| 311 | + |
| 312 | + result = find_matching_function_call(events) |
| 313 | + assert result is None |
| 314 | + |
| 315 | + |
| 316 | +def test_find_function_call_event_function_response_with_matching_call(): |
| 317 | + """Test when last event has function response with matching function call.""" |
| 318 | + # Create a function call |
| 319 | + function_call = types.FunctionCall(id='func_123', name='test_func', args={}) |
| 320 | + |
| 321 | + # Create a function response with matching ID |
| 322 | + function_response = types.FunctionResponse( |
| 323 | + id='func_123', name='test_func', response={} |
| 324 | + ) |
| 325 | + |
| 326 | + call_event = Event( |
| 327 | + invocation_id='inv1', |
| 328 | + author='agent1', |
| 329 | + content=types.Content( |
| 330 | + role='model', parts=[types.Part(function_call=function_call)] |
| 331 | + ), |
| 332 | + ) |
| 333 | + |
| 334 | + response_event = Event( |
| 335 | + invocation_id='inv2', |
| 336 | + author='user', |
| 337 | + content=types.Content( |
| 338 | + role='user', parts=[types.Part(function_response=function_response)] |
| 339 | + ), |
| 340 | + ) |
| 341 | + |
| 342 | + events = [call_event, response_event] |
| 343 | + |
| 344 | + result = find_matching_function_call(events) |
| 345 | + assert result == call_event |
| 346 | + |
| 347 | + |
| 348 | +def test_find_function_call_event_multiple_function_responses(): |
| 349 | + """Test when last event has multiple function responses.""" |
| 350 | + # Create function calls |
| 351 | + function_call1 = types.FunctionCall(id='func_123', name='test_func1', args={}) |
| 352 | + function_call2 = types.FunctionCall(id='func_456', name='test_func2', args={}) |
| 353 | + |
| 354 | + # Create function responses |
| 355 | + function_response1 = types.FunctionResponse( |
| 356 | + id='func_123', name='test_func1', response={} |
| 357 | + ) |
| 358 | + function_response2 = types.FunctionResponse( |
| 359 | + id='func_456', name='test_func2', response={} |
| 360 | + ) |
| 361 | + |
| 362 | + call_event1 = Event( |
| 363 | + invocation_id='inv1', |
| 364 | + author='agent1', |
| 365 | + content=types.Content( |
| 366 | + role='model', parts=[types.Part(function_call=function_call1)] |
| 367 | + ), |
| 368 | + ) |
| 369 | + |
| 370 | + call_event2 = Event( |
| 371 | + invocation_id='inv2', |
| 372 | + author='agent2', |
| 373 | + content=types.Content( |
| 374 | + role='model', parts=[types.Part(function_call=function_call2)] |
| 375 | + ), |
| 376 | + ) |
| 377 | + |
| 378 | + response_event = Event( |
| 379 | + invocation_id='inv3', |
| 380 | + author='user', |
| 381 | + content=types.Content( |
| 382 | + role='user', |
| 383 | + parts=[ |
| 384 | + types.Part(function_response=function_response1), |
| 385 | + types.Part(function_response=function_response2), |
| 386 | + ], |
| 387 | + ), |
| 388 | + ) |
| 389 | + |
| 390 | + events = [call_event1, call_event2, response_event] |
| 391 | + |
| 392 | + # Should return the first matching function call event found |
| 393 | + result = find_matching_function_call(events) |
| 394 | + assert result == call_event1 # First match (func_123) |
0 commit comments