From e79651cd86ba3f0c998109f2140f1db2cab78708 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 26 Jun 2025 11:18:12 -0700 Subject: [PATCH] feat: Add A2A endpoints to fast api server when --a2a option is specified (WIP) PiperOrigin-RevId: 776211580 --- src/google/adk/cli/fast_api.py | 85 +++++++++++++- tests/unittests/cli/test_fast_api.py | 160 ++++++++++++++++++++++++++- 2 files changed, 240 insertions(+), 5 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 69d7c3a0e0..1360514793 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -16,6 +16,7 @@ import asyncio from contextlib import asynccontextmanager +import json import logging import os from pathlib import Path @@ -32,7 +33,6 @@ from fastapi import HTTPException from fastapi import Query from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse from fastapi.responses import RedirectResponse from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles @@ -53,7 +53,6 @@ from ..agents import RunConfig from ..agents.live_request_queue import LiveRequest from ..agents.live_request_queue import LiveRequestQueue -from ..agents.llm_agent import Agent from ..agents.run_config import StreamingMode from ..artifacts.gcs_artifact_service import GcsArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService @@ -65,8 +64,6 @@ from ..evaluation.eval_metrics import EvalMetricResult from ..evaluation.eval_metrics import EvalMetricResultPerInvocation from ..evaluation.eval_result import EvalSetResult -from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager -from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..events.event import Event @@ -965,6 +962,86 @@ async def _get_runner_async(app_name: str) -> Runner: runner_dict[app_name] = runner return runner + if a2a: + try: + from a2a.server.apps import A2AStarletteApplication + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryTaskStore + from a2a.types import AgentCard + + from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor + + except ImportError as e: + import sys + + if sys.version_info < (3, 10): + raise ImportError( + "A2A requires Python 3.10 or above. Please upgrade your Python" + " version." + ) from e + else: + raise e + # locate all a2a agent apps in the agents directory + base_path = Path.cwd() / agents_dir + # the root agents directory should be an existing folder + if base_path.exists() and base_path.is_dir(): + a2a_task_store = InMemoryTaskStore() + + def create_a2a_runner_loader(captured_app_name: str): + """Factory function to create A2A runner with proper closure.""" + + async def _get_a2a_runner_async() -> Runner: + return await _get_runner_async(captured_app_name) + + return _get_a2a_runner_async + + for p in base_path.iterdir(): + # only folders with an agent.json file representing agent card are valid + # a2a agents + if ( + p.is_file() + or p.name.startswith((".", "__pycache__")) + or not (p / "agent.json").is_file() + ): + continue + + app_name = p.name + logger.info("Setting up A2A agent: %s", app_name) + + try: + a2a_rpc_path = f"http://{host}:{port}/a2a/{app_name}" + + agent_executor = A2aAgentExecutor( + runner=create_a2a_runner_loader(app_name), + ) + + request_handler = DefaultRequestHandler( + agent_executor=agent_executor, task_store=a2a_task_store + ) + + with (p / "agent.json").open("r", encoding="utf-8") as f: + data = json.load(f) + agent_card = AgentCard(**data) + agent_card.url = a2a_rpc_path + + a2a_app = A2AStarletteApplication( + agent_card=agent_card, + http_handler=request_handler, + ) + + routes = a2a_app.routes( + rpc_url=f"/a2a/{app_name}", + agent_card_url=f"/a2a/{app_name}/.well-known/agent.json", + ) + + for new_route in routes: + app.router.routes.append(new_route) + + logger.info("Successfully configured A2A agent: %s", app_name) + + except Exception as e: + logger.error("Failed to setup A2A agent %s: %s", app_name, e) + # Continue with other agents even if one fails if web: import mimetypes diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index aec7a020b2..d4f9382e37 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -13,10 +13,14 @@ # limitations under the License. import asyncio +import json import logging +import os +from pathlib import Path +import sys +import tempfile import time from typing import Any -from typing import Optional from unittest.mock import MagicMock from unittest.mock import patch @@ -465,6 +469,9 @@ def test_app( artifact_service_uri="", memory_service_uri="", allow_origins=["*"], + a2a=False, # Disable A2A for most tests + host="127.0.0.1", + port=8000, ) # Create a TestClient that doesn't start a real server @@ -520,6 +527,134 @@ async def create_test_eval_set( return test_session_info +@pytest.fixture +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) +def temp_agents_dir_with_a2a(): + """Create a temporary agents directory with A2A agent configurations for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create test agent directory + agent_dir = Path(temp_dir) / "test_a2a_agent" + agent_dir.mkdir() + + # Create agent.json file + agent_card = { + "name": "test_a2a_agent", + "description": "Test A2A agent", + "version": "1.0.0", + "author": "test", + "capabilities": ["text"], + } + + with open(agent_dir / "agent.json", "w") as f: + json.dump(agent_card, f) + + # Create a simple agent.py file + agent_py_content = """ +from google.adk.agents.base_agent import BaseAgent + +class TestA2AAgent(BaseAgent): + def __init__(self): + super().__init__(name="test_a2a_agent") +""" + + with open(agent_dir / "agent.py", "w") as f: + f.write(agent_py_content) + + yield temp_dir + + +@pytest.fixture +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) +def test_app_with_a2a( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, +): + """Create a TestClient for the FastAPI app with A2A enabled.""" + + # Mock A2A related classes + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.InMemorySessionService", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryArtifactService", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryMemoryService", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + patch( + "google.adk.cli.cli_eval.run_evals", + new=mock_run_evals_for_fast_api, + ), + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store, + patch( + "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor" + ) as mock_executor, + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + # Configure mocks + mock_task_store.return_value = MagicMock() + mock_executor.return_value = MagicMock() + mock_handler.return_value = MagicMock() + + # Mock A2AStarletteApplication + mock_app_instance = MagicMock() + mock_app_instance.routes.return_value = ( + [] + ) # Return empty routes for testing + mock_a2a_app.return_value = mock_app_instance + + # Change to temp directory + original_cwd = os.getcwd() + os.chdir(temp_agents_dir_with_a2a) + + try: + app = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + host="127.0.0.1", + port=8000, + ) + + client = TestClient(app) + yield client + finally: + os.chdir(original_cwd) + + ################################################# # Test Cases ################################################# @@ -760,5 +895,28 @@ def test_debug_trace(test_app): logger.info("Debug trace test completed successfully") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) +def test_a2a_agent_discovery(test_app_with_a2a): + """Test that A2A agents are properly discovered and configured.""" + # This test mainly verifies that the A2A setup doesn't break the app + response = test_app_with_a2a.get("/list-apps") + assert response.status_code == 200 + logger.info("A2A agent discovery test passed") + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) +def test_a2a_disabled_by_default(test_app): + """Test that A2A functionality is disabled by default.""" + # The regular test_app fixture has a2a=False + # This test ensures no A2A routes are added + response = test_app.get("/list-apps") + assert response.status_code == 200 + logger.info("A2A disabled by default test passed") + + if __name__ == "__main__": pytest.main(["-xvs", __file__])