Skip to content

Commit 8e40534

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat: Add A2A endpoints to fast api server when --a2a option is specified (WIP)
PiperOrigin-RevId: 774526985
1 parent 22629a1 commit 8e40534

File tree

2 files changed

+240
-5
lines changed

2 files changed

+240
-5
lines changed

src/google/adk/cli/fast_api.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import asyncio
1818
from contextlib import asynccontextmanager
19+
import json
1920
import logging
2021
import os
2122
from pathlib import Path
@@ -32,7 +33,6 @@
3233
from fastapi import HTTPException
3334
from fastapi import Query
3435
from fastapi.middleware.cors import CORSMiddleware
35-
from fastapi.responses import FileResponse
3636
from fastapi.responses import RedirectResponse
3737
from fastapi.responses import StreamingResponse
3838
from fastapi.staticfiles import StaticFiles
@@ -53,7 +53,6 @@
5353
from ..agents import RunConfig
5454
from ..agents.live_request_queue import LiveRequest
5555
from ..agents.live_request_queue import LiveRequestQueue
56-
from ..agents.llm_agent import Agent
5756
from ..agents.run_config import StreamingMode
5857
from ..artifacts.gcs_artifact_service import GcsArtifactService
5958
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
@@ -65,8 +64,6 @@
6564
from ..evaluation.eval_metrics import EvalMetricResult
6665
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
6766
from ..evaluation.eval_result import EvalSetResult
68-
from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
69-
from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
7067
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
7168
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
7269
from ..events.event import Event
@@ -965,6 +962,86 @@ async def _get_runner_async(app_name: str) -> Runner:
965962
runner_dict[app_name] = runner
966963
return runner
967964

965+
if a2a:
966+
try:
967+
from a2a.server.apps import A2AStarletteApplication
968+
from a2a.server.request_handlers import DefaultRequestHandler
969+
from a2a.server.tasks import InMemoryTaskStore
970+
from a2a.types import AgentCard
971+
972+
from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor
973+
974+
except ImportError as e:
975+
import sys
976+
977+
if sys.version_info < (3, 10):
978+
raise ImportError(
979+
"A2A requires Python 3.10 or above. Please upgrade your Python"
980+
" version."
981+
) from e
982+
else:
983+
raise e
984+
# locate all a2a agent apps in the agents directory
985+
base_path = Path.cwd() / agents_dir
986+
# the root agents directory should be an existing folder
987+
if base_path.exists() and base_path.is_dir():
988+
a2a_task_store = InMemoryTaskStore()
989+
990+
def create_a2a_runner_loader(captured_app_name: str):
991+
"""Factory function to create A2A runner with proper closure."""
992+
993+
async def _get_a2a_runner_async() -> Runner:
994+
return await _get_runner_async(captured_app_name)
995+
996+
return _get_a2a_runner_async
997+
998+
for p in base_path.iterdir():
999+
# only folders with an agent.json file representing agent card are valid
1000+
# a2a agents
1001+
if (
1002+
p.is_file()
1003+
or p.name.startswith((".", "__pycache__"))
1004+
or not (p / "agent.json").is_file()
1005+
):
1006+
continue
1007+
1008+
app_name = p.name
1009+
logger.info("Setting up A2A agent: %s", app_name)
1010+
1011+
try:
1012+
a2a_rpc_path = f"http://{host}:{port}/a2a/{app_name}"
1013+
1014+
agent_executor = A2aAgentExecutor(
1015+
runner=create_a2a_runner_loader(app_name),
1016+
)
1017+
1018+
request_handler = DefaultRequestHandler(
1019+
agent_executor=agent_executor, task_store=a2a_task_store
1020+
)
1021+
1022+
with (p / "agent.json").open("r", encoding="utf-8") as f:
1023+
data = json.load(f)
1024+
agent_card = AgentCard(**data)
1025+
agent_card.url = a2a_rpc_path
1026+
1027+
a2a_app = A2AStarletteApplication(
1028+
agent_card=agent_card,
1029+
http_handler=request_handler,
1030+
)
1031+
1032+
routes = a2a_app.routes(
1033+
rpc_url=f"/a2a/{app_name}",
1034+
agent_card_url=f"/a2a/{app_name}/.well-known/agent.json",
1035+
)
1036+
1037+
for new_route in routes:
1038+
app.router.routes.append(new_route)
1039+
1040+
logger.info("Successfully configured A2A agent: %s", app_name)
1041+
1042+
except Exception as e:
1043+
logger.error("Failed to setup A2A agent %s: %s", app_name, e)
1044+
# Continue with other agents even if one fails
9681045
if web:
9691046
import mimetypes
9701047

tests/unittests/cli/test_fast_api.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import json
1617
import logging
18+
import os
19+
from pathlib import Path
20+
import sys
21+
import tempfile
1722
import time
1823
from typing import Any
19-
from typing import Optional
2024
from unittest.mock import MagicMock
2125
from unittest.mock import patch
2226

@@ -465,6 +469,9 @@ def test_app(
465469
artifact_service_uri="",
466470
memory_service_uri="",
467471
allow_origins=["*"],
472+
a2a=False, # Disable A2A for most tests
473+
host="127.0.0.1",
474+
port=8000,
468475
)
469476

470477
# Create a TestClient that doesn't start a real server
@@ -520,6 +527,134 @@ async def create_test_eval_set(
520527
return test_session_info
521528

522529

530+
@pytest.fixture
531+
@pytest.mark.skipif(
532+
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
533+
)
534+
def temp_agents_dir_with_a2a():
535+
"""Create a temporary agents directory with A2A agent configurations for testing."""
536+
with tempfile.TemporaryDirectory() as temp_dir:
537+
# Create test agent directory
538+
agent_dir = Path(temp_dir) / "test_a2a_agent"
539+
agent_dir.mkdir()
540+
541+
# Create agent.json file
542+
agent_card = {
543+
"name": "test_a2a_agent",
544+
"description": "Test A2A agent",
545+
"version": "1.0.0",
546+
"author": "test",
547+
"capabilities": ["text"],
548+
}
549+
550+
with open(agent_dir / "agent.json", "w") as f:
551+
json.dump(agent_card, f)
552+
553+
# Create a simple agent.py file
554+
agent_py_content = """
555+
from google.adk.agents.base_agent import BaseAgent
556+
557+
class TestA2AAgent(BaseAgent):
558+
def __init__(self):
559+
super().__init__(name="test_a2a_agent")
560+
"""
561+
562+
with open(agent_dir / "agent.py", "w") as f:
563+
f.write(agent_py_content)
564+
565+
yield temp_dir
566+
567+
568+
@pytest.fixture
569+
@pytest.mark.skipif(
570+
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
571+
)
572+
def test_app_with_a2a(
573+
mock_session_service,
574+
mock_artifact_service,
575+
mock_memory_service,
576+
mock_agent_loader,
577+
mock_eval_sets_manager,
578+
mock_eval_set_results_manager,
579+
temp_agents_dir_with_a2a,
580+
):
581+
"""Create a TestClient for the FastAPI app with A2A enabled."""
582+
583+
# Mock A2A related classes
584+
with (
585+
patch("signal.signal", return_value=None),
586+
patch(
587+
"google.adk.cli.fast_api.InMemorySessionService",
588+
return_value=mock_session_service,
589+
),
590+
patch(
591+
"google.adk.cli.fast_api.InMemoryArtifactService",
592+
return_value=mock_artifact_service,
593+
),
594+
patch(
595+
"google.adk.cli.fast_api.InMemoryMemoryService",
596+
return_value=mock_memory_service,
597+
),
598+
patch(
599+
"google.adk.cli.fast_api.AgentLoader",
600+
return_value=mock_agent_loader,
601+
),
602+
patch(
603+
"google.adk.cli.fast_api.LocalEvalSetsManager",
604+
return_value=mock_eval_sets_manager,
605+
),
606+
patch(
607+
"google.adk.cli.fast_api.LocalEvalSetResultsManager",
608+
return_value=mock_eval_set_results_manager,
609+
),
610+
patch(
611+
"google.adk.cli.cli_eval.run_evals",
612+
new=mock_run_evals_for_fast_api,
613+
),
614+
patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store,
615+
patch(
616+
"google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"
617+
) as mock_executor,
618+
patch(
619+
"a2a.server.request_handlers.DefaultRequestHandler"
620+
) as mock_handler,
621+
patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app,
622+
):
623+
# Configure mocks
624+
mock_task_store.return_value = MagicMock()
625+
mock_executor.return_value = MagicMock()
626+
mock_handler.return_value = MagicMock()
627+
628+
# Mock A2AStarletteApplication
629+
mock_app_instance = MagicMock()
630+
mock_app_instance.routes.return_value = (
631+
[]
632+
) # Return empty routes for testing
633+
mock_a2a_app.return_value = mock_app_instance
634+
635+
# Change to temp directory
636+
original_cwd = os.getcwd()
637+
os.chdir(temp_agents_dir_with_a2a)
638+
639+
try:
640+
app = get_fast_api_app(
641+
agents_dir=".",
642+
web=True,
643+
session_service_uri="",
644+
artifact_service_uri="",
645+
memory_service_uri="",
646+
allow_origins=["*"],
647+
a2a=True,
648+
host="127.0.0.1",
649+
port=8000,
650+
)
651+
652+
client = TestClient(app)
653+
yield client
654+
finally:
655+
os.chdir(original_cwd)
656+
657+
523658
#################################################
524659
# Test Cases
525660
#################################################
@@ -760,5 +895,28 @@ def test_debug_trace(test_app):
760895
logger.info("Debug trace test completed successfully")
761896

762897

898+
@pytest.mark.skipif(
899+
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
900+
)
901+
def test_a2a_agent_discovery(test_app_with_a2a):
902+
"""Test that A2A agents are properly discovered and configured."""
903+
# This test mainly verifies that the A2A setup doesn't break the app
904+
response = test_app_with_a2a.get("/list-apps")
905+
assert response.status_code == 200
906+
logger.info("A2A agent discovery test passed")
907+
908+
909+
@pytest.mark.skipif(
910+
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
911+
)
912+
def test_a2a_disabled_by_default(test_app):
913+
"""Test that A2A functionality is disabled by default."""
914+
# The regular test_app fixture has a2a=False
915+
# This test ensures no A2A routes are added
916+
response = test_app.get("/list-apps")
917+
assert response.status_code == 200
918+
logger.info("A2A disabled by default test passed")
919+
920+
763921
if __name__ == "__main__":
764922
pytest.main(["-xvs", __file__])

0 commit comments

Comments
 (0)